Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
b1e4a50
Perceptual loss changes.
Nov 26, 2025
0aeb4d9
Merge branch '8627-perceptual-loss-errors-out-after-hitting-the-maxim…
Nov 26, 2025
fa0639b
Fixes
Nov 26, 2025
685aee2
Merge branch 'dev' into 8627-perceptual-loss-errors-out-after-hitting…
virginiafdez Dec 2, 2025
915de5f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 2, 2025
5594bfe
Unnecessary import
Nov 26, 2025
b065de7
Merge branch '8627-perceptual-loss-errors-out-after-hitting-the-maxim…
Dec 5, 2025
c99e16e
Add check of network name
Dec 5, 2025
717b99b
Update monai/losses/perceptual.py
ericspod Dec 5, 2025
b276f3c
Update monai/losses/perceptual.py
ericspod Dec 5, 2025
2156b84
Update monai/losses/perceptual.py
ericspod Dec 5, 2025
e2b982e
Update monai/losses/perceptual.py
virginiafdez Dec 9, 2025
e3be8de
Bug
Dec 9, 2025
b02053b
Merge branch '8627-perceptual-loss-errors-out-after-hitting-the-maxim…
Dec 9, 2025
6dfc209
DCO Remediation Commit for Virginia Fernandez <virginia.fernandez@kcl…
Dec 9, 2025
d258390
Reformatting
Dec 9, 2025
2520920
fix no space left issue
KumoLiu Dec 26, 2025
d2ab308
fix typo in error message
KumoLiu Dec 26, 2025
081a673
try fix no space left in packaging pipeline
KumoLiu Dec 26, 2025
cff03d6
fix setup
KumoLiu Dec 26, 2025
f140c16
Merge branch 'dev' into 8627-perceptual-loss-errors-out-after-hitting…
garciadias Mar 19, 2026
11cd023
Merge branch 'dev' into 8627-perceptual-loss-errors-out-after-hitting…
ericspod Mar 30, 2026
2eb05ff
Merge branch 'dev' into 8627-perceptual-loss-errors-out-after-hitting…
ericspod Mar 30, 2026
85a5030
Removing old file
ericspod May 11, 2026
ef3ae98
Merge branch 'dev' into 8627-perceptual-loss-errors-out-after-hitting…
ericspod May 11, 2026
5094b3d
Merge branch 'dev' into 8627-perceptual-loss-errors-out-after-hitting…
virginiafdez May 12, 2026
76f5432
Fix lower() bug
May 21, 2026
44ebf2c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2026
b07e9fc
stack_level and suggested tests
May 21, 2026
cf3941a
Merge branch '8627-perceptual-loss-errors-out-after-hitting-the-maxim…
May 21, 2026
0c4185f
Merge branch 'dev' into 8627-perceptual-loss-errors-out-after-hitting…
virginiafdez May 21, 2026
7ff0b86
remove | operator on Enum
May 21, 2026
3e34a8f
Merge branch '8627-perceptual-loss-errors-out-after-hitting-the-maxim…
May 21, 2026
fb49109
autofix run
May 21, 2026
cfbdaaf
remove stacklevel
May 21, 2026
0de0889
ruff
May 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 57 additions & 22 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@
from monai.utils import optional_import
from monai.utils.enums import StrEnum

# Valid model names that can be downloaded from the "Project-MONAI/perceptual-models" repository
HF_MONAI_MODELS = frozenset(
("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50")
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

LPIPS, _ = optional_import("lpips", name="LPIPS")
torchvision, _ = optional_import("torchvision")


class PercetualNetworkType(StrEnum):
class PerceptualNetworkType(StrEnum):
"""Types of neural networks that are supported by perceptual loss."""

alex = "alex"
vgg = "vgg"
squeeze = "squeeze"
Expand Down Expand Up @@ -81,7 +88,7 @@ class PerceptualLoss(nn.Module):
def __init__(
self,
spatial_dims: int,
network_type: str = PercetualNetworkType.alex,
network_type: str = PerceptualNetworkType.alex,
is_fake_3d: bool = True,
fake_3d_ratio: float = 0.5,
cache_dir: str | None = None,
Expand All @@ -95,18 +102,26 @@ def __init__(
if spatial_dims not in [2, 3]:
raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")

if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
raise ValueError(
"MedicalNet networks are only compatible with ``spatial_dims=3``."
"Argument is_fake_3d must be set to False."
)

if channel_wise and "medicalnet_" not in network_type:
network_type = network_type.lower()

# Strict validation for MedicalNet
if "medicalnet_" in network_type:
if spatial_dims == 2 or is_fake_3d:
Comment thread
virginiafdez marked this conversation as resolved.
raise ValueError(
"MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False."
)
if not channel_wise:
warnings.warn(
"MedicalNet networks support channel-wise loss. Consider setting channel_wise=True.", stacklevel=2
)

# Channel-wise only for MedicalNet
elif channel_wise:
raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.")

if network_type.lower() not in list(PercetualNetworkType):
if network_type.lower() not in list(PerceptualNetworkType):
raise ValueError(
f"Unrecognised criterion entered for Perceptual Loss. Must be one in: {', '.join(PercetualNetworkType)}"
f"Unrecognised criterion entered for Perceptual Loss. Must be one in: {', '.join(PerceptualNetworkType)}"
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if cache_dir:
torch.hub.set_dir(cache_dir)
Expand All @@ -117,12 +132,16 @@ def __init__(

self.spatial_dims = spatial_dims
self.perceptual_function: nn.Module

# True 3D models (spatial_dims=3 with is_fake_3d=False) are only provided by MedicalNet; every other network requires spatial_dims=2 or is_fake_3d=True.
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualSimilarity(
net=network_type, verbose=False, channel_wise=channel_wise
net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir
)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
self.perceptual_function = RadImageNetPerceptualSimilarity(
net=network_type, verbose=False, cache_dir=cache_dir
)
elif network_type == "resnet50":
self.perceptual_function = TorchvisionModelPerceptualSimilarity(
net=network_type,
Expand All @@ -131,7 +150,9 @@ def __init__(
pretrained_state_dict_key=pretrained_state_dict_key,
)
else:
# VGG, AlexNet and SqueezeNet are independently handled by LPIPS.
self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)

self.is_fake_3d = is_fake_3d
self.fake_3d_ratio = fake_3d_ratio
self.channel_wise = channel_wise
Expand Down Expand Up @@ -203,22 +224,31 @@ class MedicalNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
"Warvito/MedicalNet-models".
"Project-MONAI/perceptual-models".

Args:
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
verbose: if false, mute messages from torch Hub load function.
channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
Defaults to ``False``.
Defaults to ``False``.
cache_dir: path to cache directory to save the pretrained network weights.
"""

def __init__(
Comment thread
virginiafdez marked this conversation as resolved.
self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
self,
net: str = "medicalnet_resnet10_23datasets",
verbose: bool = False,
channel_wise: bool = False,
cache_dir: str | None = None,
) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True)
if net not in HF_MONAI_MODELS:
raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.")

self.model = torch.hub.load(
Comment thread
virginiafdez marked this conversation as resolved.
"Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True
)
self.eval()

self.channel_wise = channel_wise
Expand Down Expand Up @@ -267,7 +297,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
for i in range(input.shape[1]):
l_idx = i * feats_per_ch
r_idx = (i + 1) * feats_per_ch
results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1)
results[:, i, ...] = feats_diff[:, l_idx:r_idx, ...].sum(dim=1)
else:
results = feats_diff.sum(dim=1, keepdim=True)

Expand Down Expand Up @@ -296,17 +326,22 @@ class RadImageNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
uses torch Hub to download the networks from "Warvito/radimagenet-models".
uses torch Hub to download the networks from "Project-MONAI/perceptual-models".

Args:
net: {``"radimagenet_resnet50"``}
Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
verbose: if false, mute messages from torch Hub load function.
cache_dir: path to cache directory to save the pretrained network weights.
"""

def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None:
super().__init__()
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True)
if net not in HF_MONAI_MODELS:
raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.")
self.model = torch.hub.load(
"Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True
)
self.eval()

for param in self.parameters():
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import sys
import warnings
from typing import Any, cast

from packaging import version
from setuptools import find_packages, setup
Expand Down Expand Up @@ -146,6 +147,6 @@ def get_cmds():
cmdclass=get_cmds(),
packages=find_packages(exclude=("docs", "examples", "tests", "tests.*")),
zip_safe=False,
package_data={"monai": ["py.typed", *jit_extension_source]}, # type: ignore[arg-type]
package_data=cast(Any, {"monai": ["py.typed", *jit_extension_source]}),
ext_modules=get_extensions(),
)
10 changes: 10 additions & 0 deletions tests/losses/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ def test_medicalnet_on_2d_data(self, network_type):
with self.assertRaises(ValueError):
PerceptualLoss(spatial_dims=2, network_type=network_type)

@parameterized.expand(["squeeze", "alex", "vgg", "radimagenet_resnet50", "resnet50"])
def test_channel_wise_with_non_medicalnet(self, network_type):
with self.assertRaises(ValueError):
PerceptualLoss(spatial_dims=2, network_type=network_type, channel_wise=True)

@parameterized.expand(["squeeze", "alex", "vgg", "radimagenet_resnet50", "resnet50"])
def test_non_medicalnet_3d_without_fake_3d(self, network_type):
with self.assertRaises(ValueError):
PerceptualLoss(spatial_dims=3, network_type=network_type, is_fake_3d=False)


if __name__ == "__main__":
unittest.main()
Loading