
fix: use register_buffer for kernel and kernel_vol in LocalNormalizedCrossCorrelationLoss#8818

Open
Zeesejo wants to merge 1 commit into Project-MONAI:dev from Zeesejo:fix/lncc-register-buffer-kernel

Conversation


@Zeesejo Zeesejo commented Apr 11, 2026

Description

This PR fixes an incorrect kernel initialization pattern in LocalNormalizedCrossCorrelationLoss in monai/losses/image_dissimilarity.py.

Bug found: The original code used plain attribute assignment with a typo:

self.kernel = _kernel(self.kernel_size)
self.kernel.require_grads = False  # typo: 'require_grads' is NOT a valid tensor attribute
self.kernel_vol = self.get_kernel_vol()

require_grads (with an 's') is not a valid PyTorch tensor attribute; the correct name is requires_grad. Because tensors accept arbitrary attribute assignment, the typo raised no error — it silently created a new, unused attribute, so the intended requires_grad = False never took effect and the kernel kept tracking gradients through every forward pass, wasting memory and computation.
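The failure mode is easy to reproduce in plain Python: like most Python objects, torch.Tensor accepts assignment to arbitrary attribute names, so the misspelling raises no error. A minimal sketch with a hypothetical stand-in class (no torch required):

```python
class FakeTensor:
    """Hypothetical stand-in: torch.Tensor also accepts arbitrary attributes."""
    def __init__(self):
        self.requires_grad = True  # suppose grad tracking is on

t = FakeTensor()
t.require_grads = False  # typo: creates a brand-new attribute, raises nothing
assert t.requires_grad is True   # the real flag is still on
assert t.require_grads is False  # the misspelled attribute just sits unused
```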

Furthermore, assigning the kernel as a plain attribute (self.kernel = ...) rather than registering it as a buffer means:

  • The kernel tensor does not automatically move to the correct device when .to(device) / .cuda() / .half() is called on the loss module
  • The kernel is not saved/restored correctly via state_dict() / load_state_dict()

Fix: Replace both assignments with register_buffer, which is the correct PyTorch nn.Module pattern for constant tensors:

self.register_buffer("kernel", _kernel(self.kernel_size))
self.register_buffer("kernel_vol", self.get_kernel_vol())

This removes the need for any manual requires_grad management and ensures proper device placement.
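The consequences listed above, and the behavior after the fix, can be checked with a small stand-in module. This is a sketch with hypothetical names, not MONAI code; a dtype move (.double()) exercises the same nn.Module machinery as .to(device) / .cuda() / .half():

```python
import torch
import torch.nn as nn

class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.kernel = torch.ones(3)  # plain attribute: invisible to nn.Module

class Buffered(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("kernel", torch.ones(3))  # tracked by the module

plain, buf = PlainAttr().double(), Buffered().double()
assert plain.kernel.dtype == torch.float32   # did not follow the module
assert buf.kernel.dtype == torch.float64     # moved with the module
assert "kernel" in buf.state_dict()          # serialized with the module
assert "kernel" not in plain.state_dict()    # lost on save/restore
assert len(list(buf.parameters())) == 0      # a buffer, not a parameter
```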

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Signed-off-by: Zeesejo <92383127+Zeesejo@users.noreply.github.com>

…CrossCorrelationLoss

Refactor kernel initialization to use register_buffer for better state management.

Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com>
Copilot AI review requested due to automatic review settings April 11, 2026 17:47

coderabbitai bot commented Apr 11, 2026

📝 Walkthrough

The change refactors LocalNormalizedCrossCorrelationLoss.__init__ to register the precomputed kernel tensor and its scalar volume as PyTorch buffers using register_buffer() instead of assigning them as plain tensor attributes. This integrates the tensors with module state management (device moves, state dict serialization) while maintaining the existing forward method usage pattern with .to(pred) calls.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~3 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title clearly and concisely describes the main change: using register_buffer for kernel tensors in LocalNormalizedCrossCorrelationLoss.
  • Description check — ✅ Passed: the description provides thorough context on the bug (typo in require_grads), explains its consequences, and details the fix. Template sections are completed with appropriate checkboxes marked.


@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@monai/losses/image_dissimilarity.py`:
- Around line 114-115: Add unit tests that verify the new buffer semantics for
the ImageDissimilarity module: instantiate the module (using the class that
calls self.register_buffer for "kernel" and "kernel_vol"), assert both keys
appear in module.state_dict(), assert they are not in module.parameters() (i.e.,
non-parameter buffers), and test device movement by .to(device) or .cuda() to
confirm the tensors in module.kernel and module.kernel_vol move to the target
device; also include a test that the values persist across state_dict save/load
operations to ensure correct buffer behavior. Ensure you reference the
module/class name that defines get_kernel_vol and register_buffer when adding
the tests.
ℹ️ Review info
⚙️ Run configuration

Configuration used: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9cc75dbd-673c-4854-8448-3356e29fa3a5

📥 Commits

Reviewing files that changed from the base of the PR and between cc92126 and 6df895b.

📒 Files selected for processing (1)
  • monai/losses/image_dissimilarity.py

Comment on lines +114 to +115
self.register_buffer("kernel", _kernel(self.kernel_size))
self.register_buffer("kernel_vol", self.get_kernel_vol())

⚠️ Potential issue | 🟡 Minor

Add regression tests for the new buffer semantics.

Line 114 and Line 115 change state behavior, but current tests only validate numerics/error paths. Please add assertions that kernel and kernel_vol are present in state_dict, remain non-parameter buffers, and move with module device changes.

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."


Copilot AI left a comment

Pull request overview

This PR fixes the initialization of constant kernel tensors in LocalNormalizedCrossCorrelationLoss by switching to the standard PyTorch nn.Module.register_buffer pattern to improve device placement and (optionally) serialization behavior.

Changes:

  • Register kernel as a module buffer instead of a plain attribute.
  • Register kernel_vol as a module buffer instead of a plain attribute.


Comment on lines 113 to +115

  _kernel = look_up_option(kernel_type, kernel_dict)
- self.kernel = _kernel(self.kernel_size)
- self.kernel.require_grads = False
- self.kernel_vol = self.get_kernel_vol()
+ self.register_buffer("kernel", _kernel(self.kernel_size))
+ self.register_buffer("kernel_vol", self.get_kernel_vol())
Copilot AI Apr 11, 2026

Registering kernel/kernel_vol as persistent buffers changes this module’s state_dict() schema (older checkpoints without these keys will fail load_state_dict(strict=True) with missing keys). Consider whether backward compatibility is required; if so, handle missing keys (e.g., custom _load_from_state_dict) and/or make derived kernel_vol non-persistent and recompute it from kernel on load.
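If backward compatibility matters, one option — a sketch under the assumption that kernel_vol is always derivable from the kernel at init time — is to keep the derived buffer out of the checkpoint with persistent=False:

```python
import torch
import torch.nn as nn

class LNCCCompat(nn.Module):
    """Hypothetical sketch: derived buffer excluded from checkpoints."""
    def __init__(self, kernel_size: int = 3):
        super().__init__()
        kernel = torch.ones(kernel_size)  # placeholder for _kernel(...)
        self.register_buffer("kernel", kernel)
        # persistent=False: never enters state_dict(), so checkpoints saved
        # before this change still load under strict=True
        self.register_buffer("kernel_vol", kernel.sum(), persistent=False)

m = LNCCCompat()
assert "kernel_vol" not in m.state_dict()
m.load_state_dict({"kernel": torch.zeros(3)})  # no missing-key error
```

Note that a non-persistent kernel_vol would go stale if the loaded kernel values ever differed from the init-time ones, so this only works if kernel_vol is recomputed from kernel on load (or the kernel is a true constant).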

Comment on lines +114 to +115
self.register_buffer("kernel", _kernel(self.kernel_size))
self.register_buffer("kernel_vol", self.get_kernel_vol())
Copilot AI Apr 11, 2026

This change is intended to improve device placement / serialization behavior, but there’s no regression test asserting (1) kernel/kernel_vol appear in state_dict() and/or (2) they follow .to(device/dtype) without per-forward casting. Adding a small unit test would help prevent this from regressing again.


Zeesejo commented Apr 11, 2026

Related issue documenting the full bug analysis and 1→10→100 chain of improvements: #8819

The issue also identifies two follow-up areas to audit:

  1. GlobalMutualInformationLoss: bin_centers should also be a registered buffer
  2. Broader test coverage for device movement of all loss modules

Fixes #8819

