diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index de10e4b2d4..f19ff82a0f 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -233,9 +233,11 @@ def __init__( self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) self.num_bins = num_bins self.kernel_type = kernel_type + self.register_buffer("preterm", None, persistent=False) + self.register_buffer("bin_centers", None, persistent=False) if self.kernel_type == "gaussian": - self.preterm = 1 / (2 * sigma**2) - self.bin_centers = bin_centers[None, None, ...] + self.register_buffer("preterm", 1 / (2 * sigma**2), persistent=False) + self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) diff --git a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py index ff7851ed1c..f6c81bcc9c 100644 --- a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +++ b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py @@ -145,5 +145,45 @@ def test_ill_opts(self, num_bins, reduction, expected_exception, expected_messag GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target) +class TestGlobalMutualInformationLossBuffers(unittest.TestCase): + def test_gaussian_kernel_registers_buffers(self): + """preterm and bin_centers are registered as non-persistent buffers for gaussian kernel.""" + loss = GlobalMutualInformationLoss(kernel_type="gaussian") + self.assertIn("preterm", loss._buffers) + self.assertIn("bin_centers", loss._buffers) + self.assertFalse(loss.preterm.requires_grad) + self.assertFalse(loss.bin_centers.requires_grad) + self.assertEqual(loss.bin_centers.ndim, 3) + + def test_bspline_kernel_has_no_gaussian_buffers(self): + """b-spline kernel does not register gaussian-specific buffers.""" + loss = GlobalMutualInformationLoss(kernel_type="b-spline") + self.assertIsNone(loss.preterm) + self.assertIsNone(loss.bin_centers) + + def test_gaussian_kernel_forward_correct(self): + """Gaussian kernel forward pass returns a scalar loss.""" + pred = torch.rand(2, 1, 8, 8, dtype=torch.float32) + target = torch.rand(2, 1, 8, 8, dtype=torch.float32) + loss = GlobalMutualInformationLoss(kernel_type="gaussian") + result = loss(pred, target) + self.assertEqual(result.shape, torch.Size([])) + + def test_gaussian_buffers_move_with_module(self): + """Buffers move to the correct device when the module is moved with .to().""" + loss = GlobalMutualInformationLoss(kernel_type="gaussian") + self.assertEqual(loss.preterm.device.type, "cpu") + self.assertEqual(loss.bin_centers.device.type, "cpu") + if not torch.cuda.is_available(): + return + loss = loss.cuda() + self.assertEqual(loss.preterm.device.type, "cuda") + self.assertEqual(loss.bin_centers.device.type, "cuda") + pred = torch.rand(2, 1, 8, 8, device="cuda") + target = torch.rand(2, 1, 8, 8, device="cuda") + result = loss(pred, target) + self.assertEqual(result.device.type, "cuda") + + if __name__ == "__main__": unittest.main()