fix: use register_buffer for kernel and kernel_vol in LocalNormalizedCrossCorrelationLoss #8818
Zeesejo wants to merge 1 commit into Project-MONAI:dev
Conversation
…CrossCorrelationLoss: Refactor kernel initialization to use register_buffer for better state management.
Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com>
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/losses/image_dissimilarity.py`:
- Around line 114-115: Add unit tests that verify the new buffer semantics for
the ImageDissimilarity module: instantiate the module (using the class that
calls self.register_buffer for "kernel" and "kernel_vol"), assert both keys
appear in module.state_dict(), assert they are not in module.parameters() (i.e.,
non-parameter buffers), and test device movement by .to(device) or .cuda() to
confirm the tensors in module.kernel and module.kernel_vol move to the target
device; also include a test that the values persist across state_dict save/load
operations to ensure correct buffer behavior. Ensure you reference the
module/class name that defines get_kernel_vol and register_buffer when adding
the tests.
📒 Files selected for processing (1)
monai/losses/image_dissimilarity.py
```python
self.register_buffer("kernel", _kernel(self.kernel_size))
self.register_buffer("kernel_vol", self.get_kernel_vol())
```
Add regression tests for the new buffer semantics.
Line 114 and Line 115 change state behavior, but current tests only validate numerics/error paths. Please add assertions that kernel and kernel_vol are present in state_dict, remain non-parameter buffers, and move with module device changes.
As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."
Pull request overview
This PR fixes the initialization of constant kernel tensors in LocalNormalizedCrossCorrelationLoss by switching to the standard PyTorch nn.Module.register_buffer pattern to improve device placement and (optionally) serialization behavior.
Changes:
- Register kernel as a module buffer instead of a plain attribute.
- Register kernel_vol as a module buffer instead of a plain attribute.
```diff
  _kernel = look_up_option(kernel_type, kernel_dict)
- self.kernel = _kernel(self.kernel_size)
- self.kernel.require_grads = False
- self.kernel_vol = self.get_kernel_vol()
+ self.register_buffer("kernel", _kernel(self.kernel_size))
+ self.register_buffer("kernel_vol", self.get_kernel_vol())
```
Registering kernel/kernel_vol as persistent buffers changes this module’s state_dict() schema (older checkpoints without these keys will fail load_state_dict(strict=True) with missing keys). Consider whether backward compatibility is required; if so, handle missing keys (e.g., custom _load_from_state_dict) and/or make derived kernel_vol non-persistent and recompute it from kernel on load.
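One backward-compatible variant, sketched below under the assumption of a module shaped like the one in this PR (the `CompatLoss` class is hypothetical, not MONAI's code): register the derived kernel_vol with persistent=False so it never enters state_dict, and backfill a missing "kernel" key in _load_from_state_dict so old checkpoints still load with strict=True.

```python
import torch
from torch import nn


class CompatLoss(nn.Module):
    """Hypothetical sketch of a backward-compatible buffer registration."""

    def __init__(self, kernel_size: int = 3) -> None:
        super().__init__()
        self.register_buffer("kernel", torch.ones(kernel_size))
        # Derived value: recomputed at construction, excluded from state_dict.
        self.register_buffer("kernel_vol", self.kernel.sum(), persistent=False)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Old checkpoints predate the buffer; reuse the freshly built kernel
        # so strict loading does not report a missing key.
        key = prefix + "kernel"
        if key not in state_dict:
            state_dict[key] = self.kernel
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


module = CompatLoss()
module.load_state_dict({})  # old-style checkpoint without buffer keys: no error
assert "kernel" in module.state_dict()
assert "kernel_vol" not in module.state_dict()  # non-persistent
```

Whether this complexity is warranted depends on how many existing checkpoints contain this loss module's state.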
```python
self.register_buffer("kernel", _kernel(self.kernel_size))
self.register_buffer("kernel_vol", self.get_kernel_vol())
```
This change is intended to improve device placement / serialization behavior, but there’s no regression test asserting (1) kernel/kernel_vol appear in state_dict() and/or (2) they follow .to(device/dtype) without per-forward casting. Adding a small unit test would help prevent this from regressing again.
Description

This PR fixes an incorrect kernel initialization pattern in LocalNormalizedCrossCorrelationLoss in monai/losses/image_dissimilarity.py.

Bug found: The original code used plain attribute assignment with a typo: require_grads (with an 's') is not a valid PyTorch tensor attribute. The correct attribute is requires_grad. This means the kernel was silently tracking gradients throughout every forward pass, wasting memory and computation.

Furthermore, assigning the kernel as a plain attribute (self.kernel = ...) rather than registering it as a buffer means:
- the kernel does not move when .to(device) / .cuda() / .half() is called on the loss module
- the kernel is not included in state_dict() / load_state_dict()

Fix: Replace both assignments with register_buffer, which is the correct PyTorch nn.Module pattern for constant tensors. This removes the need for any manual requires_grad management and ensures proper device placement.

Types of changes
- ./runtests.sh -f -u --net --coverage
- ./runtests.sh --quick --unittests --disttests
- make html command in the docs/ folder

Signed-off-by: Zeesejo <92383127+Zeesejo@users.noreply.github.com>
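The typo described above is easy to miss because PyTorch tensors accept arbitrary Python attributes: assigning to a misspelled name fails silently instead of raising. A small illustrative sketch (the tensors here are made up for demonstration, not MONAI's actual kernels):

```python
import torch

# A tensor derived from something that requires grad tracks gradients itself.
base = torch.ones(3, requires_grad=True)
kernel = base * 2.0
assert kernel.requires_grad  # gradient tracking is on

# The typo'd assignment does not raise; it merely attaches a new,
# unused Python attribute to the tensor object.
kernel.require_grads = False
assert kernel.requires_grad  # unchanged: the typo had no effect

# Detaching actually removes the tensor from the autograd graph
# (register_buffer makes even this unnecessary for constant tensors).
kernel = kernel.detach()
assert not kernel.requires_grad
```

This is also why no test caught the bug: the broken line executes without error.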