From d3e1571b4d1ba7a4c9ceaa1a51a952e1ea753b4b Mon Sep 17 00:00:00 2001 From: ugbotueferhire Date: Sun, 31 May 2026 12:07:57 +0100 Subject: [PATCH] Register GlobalMutualInformationLoss bin_centers as buffer Signed-off-by: ugbotueferhire --- monai/losses/image_dissimilarity.py | 6 +++++- .../test_global_mutual_information_loss.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index de10e4b2d4..37c78fae60 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -233,9 +233,11 @@ def __init__( self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) self.num_bins = num_bins self.kernel_type = kernel_type + self.bin_centers: torch.Tensor | None + self.register_buffer("bin_centers", None, persistent=False) if self.kernel_type == "gaussian": self.preterm = 1 / (2 * sigma**2) - self.bin_centers = bin_centers[None, None, ...] + self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) @@ -314,6 +316,8 @@ def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, to """ img = torch.clamp(img, 0, 1) img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1) + if self.bin_centers is None: + raise ValueError("bin_centers must be defined for gaussian parzen windowing.") weight = torch.exp( -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2 ) # (batch, num_sample, num_bin) diff --git a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py index ff7851ed1c..a16499ac11 100644 --- a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +++ b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py @@ -116,6 +116,25 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. class TestGlobalMutualInformationLossIll(unittest.TestCase): + def test_gaussian_bin_centers_registered_buffer(self): + loss = GlobalMutualInformationLoss(kernel_type="gaussian", num_bins=16) + + self.assertIn("bin_centers", dict(loss.named_buffers())) + self.assertIsNotNone(loss.bin_centers) + self.assertFalse(loss.bin_centers.requires_grad) + + loss = loss.to(dtype=torch.float64) + self.assertEqual(loss.bin_centers.dtype, torch.float64) + + if torch.cuda.is_available(): + loss = loss.to(device="cuda:0") + self.assertEqual(loss.bin_centers.device, torch.device("cuda:0")) + + def test_b_spline_bin_centers_exists_as_none(self): + loss = GlobalMutualInformationLoss(kernel_type="b-spline") + + self.assertIsNone(loss.bin_centers) + @parameterized.expand( [ (torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)), # mismatched_simple_dims