From 23109a7f48d83a518e1d6c9a04e0614d26506682 Mon Sep 17 00:00:00 2001 From: Oleksandr Sanin Date: Wed, 20 May 2026 09:29:26 +0000 Subject: [PATCH 1/2] fix(losses): register buffers in GlobalMutualInformationLoss When kernel_type="gaussian", `preterm` and `bin_centers` were stored as plain tensor attributes via simple assignment. This means they are not registered in PyTorch's module buffer system, so calling `loss.to("cuda")` or `loss.cuda()` does not move these tensors to the target device. Each forward pass had to call `.to(img)` to patch the device mismatch at runtime, which is both redundant and misleading. Use `register_buffer(..., persistent=False)` so that both tensors are properly tracked by the module and automatically move with `.to()` / `.cuda()` / `.cpu()` calls, consistent with the pattern already used by `LocalNormalizedCrossCorrelationLoss`. The `.to(img)` calls in `parzen_windowing_gaussian` are retained for dtype coercion (e.g. float16 inference). Adds `TestGlobalMutualInformationLossBuffers` to verify buffer registration and that b-spline mode does not create gaussian buffers. Closes #8819 Signed-off-by: Oleksandr Sanin --- monai/losses/image_dissimilarity.py | 6 +++-- .../test_global_mutual_information_loss.py | 23 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index de10e4b2d4..9fa439e86e 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -233,9 +233,11 @@ def __init__( self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) self.num_bins = num_bins self.kernel_type = kernel_type + self.preterm: torch.Tensor + self.bin_centers: torch.Tensor if self.kernel_type == "gaussian": - self.preterm = 1 / (2 * sigma**2) - self.bin_centers = bin_centers[None, None, ...] 
+ self.register_buffer("preterm", 1 / (2 * sigma**2), persistent=False) + self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) diff --git a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py index ff7851ed1c..cb93bfa992 100644 --- a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +++ b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py @@ -145,5 +145,28 @@ def test_ill_opts(self, num_bins, reduction, expected_exception, expected_messag GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target) +class TestGlobalMutualInformationLossBuffers(unittest.TestCase): + def test_gaussian_kernel_registers_buffers(self): + loss = GlobalMutualInformationLoss(kernel_type="gaussian") + # preterm and bin_centers must be registered buffers so .to() moves them + self.assertIn("preterm", loss._buffers) + self.assertIn("bin_centers", loss._buffers) + self.assertFalse(loss.preterm.requires_grad) + self.assertFalse(loss.bin_centers.requires_grad) + self.assertEqual(loss.bin_centers.ndim, 3) + + def test_bspline_kernel_has_no_gaussian_buffers(self): + loss = GlobalMutualInformationLoss(kernel_type="b-spline") + self.assertNotIn("preterm", loss._buffers) + self.assertNotIn("bin_centers", loss._buffers) + + def test_gaussian_kernel_forward_correct(self): + pred = torch.rand(2, 1, 8, 8, dtype=torch.float32) + target = torch.rand(2, 1, 8, 8, dtype=torch.float32) + loss = GlobalMutualInformationLoss(kernel_type="gaussian") + result = loss(pred, target) + self.assertEqual(result.shape, torch.Size([])) + + if __name__ == "__main__": unittest.main() From 379b8a8c93ae05b7201c41a0fe4fa352534ae310 Mon Sep 17 00:00:00 2001 From: Oleksandr Sanin Date: Mon, 25 May 2026 17:33:11 +0000 Subject: [PATCH 2/2] fix(losses): address review comments on 
GlobalMutualInformationLoss buffer fix Use register_buffer("preterm", None) / register_buffer("bin_centers", None) unconditionally so that both buffers are always present in _buffers (with None for b-spline). This avoids a KeyError that occurred when plain instance attribute assignment conflicted with a subsequent register_buffer call. Also add docstrings to the new test methods and a device-movement test that verifies buffers follow the module when .cuda() is called. Signed-off-by: Oleksandr Sanin --- monai/losses/image_dissimilarity.py | 4 ++-- .../test_global_mutual_information_loss.py | 23 ++++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 9fa439e86e..f19ff82a0f 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -233,8 +233,8 @@ def __init__( self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) self.num_bins = num_bins self.kernel_type = kernel_type - self.preterm: torch.Tensor - self.bin_centers: torch.Tensor + self.register_buffer("preterm", None, persistent=False) + self.register_buffer("bin_centers", None, persistent=False) if self.kernel_type == "gaussian": self.register_buffer("preterm", 1 / (2 * sigma**2), persistent=False) self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) diff --git a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py index cb93bfa992..f6c81bcc9c 100644 --- a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +++ b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py @@ -147,8 +147,8 @@ def test_ill_opts(self, num_bins, reduction, expected_exception, expected_messag class TestGlobalMutualInformationLossBuffers(unittest.TestCase): def test_gaussian_kernel_registers_buffers(self): + """preterm and 
bin_centers are registered as non-persistent buffers for gaussian kernel.""" loss = GlobalMutualInformationLoss(kernel_type="gaussian") - # preterm and bin_centers must be registered buffers so .to() moves them self.assertIn("preterm", loss._buffers) self.assertIn("bin_centers", loss._buffers) self.assertFalse(loss.preterm.requires_grad) @@ -156,17 +156,34 @@ def test_gaussian_kernel_registers_buffers(self): self.assertEqual(loss.bin_centers.ndim, 3) def test_bspline_kernel_has_no_gaussian_buffers(self): + """b-spline kernel leaves the gaussian-specific buffers registered as None.""" loss = GlobalMutualInformationLoss(kernel_type="b-spline") - self.assertNotIn("preterm", loss._buffers) - self.assertNotIn("bin_centers", loss._buffers) + self.assertIsNone(loss.preterm) + self.assertIsNone(loss.bin_centers) def test_gaussian_kernel_forward_correct(self): + """Gaussian kernel forward pass returns a scalar loss.""" pred = torch.rand(2, 1, 8, 8, dtype=torch.float32) target = torch.rand(2, 1, 8, 8, dtype=torch.float32) loss = GlobalMutualInformationLoss(kernel_type="gaussian") result = loss(pred, target) self.assertEqual(result.shape, torch.Size([])) + def test_gaussian_buffers_move_with_module(self): + """Buffers start on CPU and follow the module to the GPU on .cuda() (CUDA checks run only if available).""" + loss = GlobalMutualInformationLoss(kernel_type="gaussian") + self.assertEqual(loss.preterm.device.type, "cpu") + self.assertEqual(loss.bin_centers.device.type, "cpu") + if not torch.cuda.is_available(): + return + loss = loss.cuda() + self.assertEqual(loss.preterm.device.type, "cuda") + self.assertEqual(loss.bin_centers.device.type, "cuda") + pred = torch.rand(2, 1, 8, 8, device="cuda") + target = torch.rand(2, 1, 8, 8, device="cuda") + result = loss(pred, target) + self.assertEqual(result.device.type, "cuda") + if __name__ == "__main__": unittest.main()