diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index de10e4b2d4..1d1acafba1 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -233,9 +233,15 @@ def __init__( self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) self.num_bins = num_bins self.kernel_type = kernel_type + # declared as buffers so they move with the module (e.g. ``.to(device)``); only populated for the + # gaussian kernel, hence the ``Tensor`` annotation reflects the type at the use sites in that path. + self.preterm: torch.Tensor + self.bin_centers: torch.Tensor + self.register_buffer("preterm", None, persistent=False) + self.register_buffer("bin_centers", None, persistent=False) if self.kernel_type == "gaussian": - self.preterm = 1 / (2 * sigma**2) - self.bin_centers = bin_centers[None, None, ...] + self.register_buffer("preterm", 1 / (2 * sigma**2), persistent=False) + self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) diff --git a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py index ff7851ed1c..07e605d669 100644 --- a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +++ b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py @@ -145,5 +145,51 @@ def test_ill_opts(self, num_bins, reduction, expected_exception, expected_messag GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target) +class TestGlobalMutualInformationLossBuffers(unittest.TestCase): + def test_gaussian_kernel_registers_buffers(self): + """Verify gaussian kernel registers preterm and bin_centers as non-trainable, non-persistent buffers.""" + loss = GlobalMutualInformationLoss(kernel_type="gaussian") + self.assertIn("preterm", loss._buffers) + self.assertIn("bin_centers", loss._buffers) + self.assertFalse(loss.preterm.requires_grad) + self.assertFalse(loss.bin_centers.requires_grad) + self.assertEqual(loss.bin_centers.ndim, 3) + state = loss.state_dict() + self.assertNotIn("preterm", state) + self.assertNotIn("bin_centers", state) + + def test_bspline_kernel_has_no_gaussian_buffers(self): + """Verify b-spline kernel does not populate gaussian-specific buffers.""" + loss = GlobalMutualInformationLoss(kernel_type="b-spline") + self.assertIsNone(loss.preterm) + self.assertIsNone(loss.bin_centers) + state = loss.state_dict() + self.assertNotIn("preterm", state) + self.assertNotIn("bin_centers", state) + + def test_gaussian_kernel_forward_correct(self): + """Verify gaussian kernel forward pass returns a scalar loss tensor.""" + pred = torch.rand(2, 1, 8, 8, dtype=torch.float32) + target = torch.rand(2, 1, 8, 8, dtype=torch.float32) + loss = GlobalMutualInformationLoss(kernel_type="gaussian") + result = loss(pred, target) + self.assertEqual(result.shape, torch.Size([])) + + def test_gaussian_buffers_move_with_module(self): + """Verify preterm and bin_centers buffers move to the target device with the module.""" + loss = GlobalMutualInformationLoss(kernel_type="gaussian") + self.assertEqual(loss.preterm.device.type, "cpu") + self.assertEqual(loss.bin_centers.device.type, "cpu") + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + loss = loss.cuda() + self.assertEqual(loss.preterm.device.type, "cuda") + self.assertEqual(loss.bin_centers.device.type, "cuda") + pred = torch.rand(2, 1, 8, 8, device="cuda") + target = torch.rand(2, 1, 8, 8, device="cuda") + result = loss(pred, target) + self.assertEqual(result.device.type, "cuda") + + if __name__ == "__main__": unittest.main()