Skip to content

fix(losses): register buffers in GlobalMutualInformationLoss#8872

Open
AlexanderSanin wants to merge 2 commits into
Project-MONAI:devfrom
AlexanderSanin:fix/global-mutual-information-register-buffer
Open

fix(losses): register buffers in GlobalMutualInformationLoss#8872
AlexanderSanin wants to merge 2 commits into
Project-MONAI:devfrom
AlexanderSanin:fix/global-mutual-information-register-buffer

Conversation

@AlexanderSanin
Copy link
Copy Markdown
Contributor

Summary

  • GlobalMutualInformationLoss stored preterm and bin_centers as plain tensor attributes when kernel_type="gaussian", so calling loss.to("cuda") or loss.cuda() did not move them to the target device
  • Replace the plain assignments with register_buffer(..., persistent=False), consistent with the pattern already applied to LocalNormalizedCrossCorrelationLoss in fix: use register_buffer for kernel and kernel_vol in LocalNormalizedCrossCorrelationLoss #8818
  • The .to(img) calls in parzen_windowing_gaussian are retained for dtype coercion (e.g. float16 inference)

Test plan

  • python -m pytest tests/losses/image_dissimilarity/test_global_mutual_information_loss.py -v — all existing tests still pass
  • TestGlobalMutualInformationLossBuffers::test_gaussian_kernel_registers_bufferspreterm and bin_centers are in _buffers and have requires_grad=False
  • TestGlobalMutualInformationLossBuffers::test_bspline_kernel_has_no_gaussian_buffers — b-spline mode is unaffected
  • TestGlobalMutualInformationLossBuffers::test_gaussian_kernel_forward_correct — forward pass returns a scalar loss

Closes #8819

@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

Hey @ericspod @aymuos15. Could you, please, have a look at this?

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 25, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 60e193db-900b-46fa-8c6b-1ca9d39a50c0

📥 Commits

Reviewing files that changed from the base of the PR and between 20c702a and 379b8a8.

📒 Files selected for processing (2)
  • monai/losses/image_dissimilarity.py
  • tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

📝 Walkthrough

Walkthrough

GlobalMutualInformationLoss now registers preterm and bin_centers as non-persistent buffers before the kernel-type dispatch and populates them when kernel_type == "gaussian". Tests were added to verify buffer registration and properties for the gaussian kernel, absence for b-spline, that a gaussian forward returns a scalar tensor, and that the gaussian buffers move with the module to CUDA.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed Title accurately summarizes the main change: registering buffers in GlobalMutualInformationLoss to fix device placement.
Description check ✅ Passed Description provides context (related pattern from #8818), test plan with specific test names, and issue closure link, but lacks 'Fixes # 8819' at the top and some template sections.
Linked Issues check ✅ Passed Changes fully satisfy #8819 objectives: register preterm and bin_centers as non-persistent buffers so they move with module device operations and have requires_grad=False.
Out of Scope Changes check ✅ Passed All changes are in-scope: buffer registration in GlobalMutualInformationLoss and comprehensive test coverage for buffer setup and device movement.
Docstring Coverage ✅ Passed Docstring coverage is 83.33% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (4)
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py (3)

158-161: ⚡ Quick win

Add docstring per coding guidelines.

📝 Suggested docstring
 def test_bspline_kernel_has_no_gaussian_buffers(self):
+    """Verify b-spline kernel does not register gaussian-specific buffers."""
     loss = GlobalMutualInformationLoss(kernel_type="b-spline")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 158 - 161, The test function
test_bspline_kernel_has_no_gaussian_buffers is missing a docstring; add a short
descriptive docstring at the top of the function explaining that it verifies
GlobalMutualInformationLoss(kernel_type="b-spline") does not populate
Gaussian-specific buffers (specifically asserting "preterm" and "bin_centers"
are not in loss._buffers). Keep it concise and follow existing test docstring
style.

163-168: ⚡ Quick win

Add docstring per coding guidelines.

📝 Suggested docstring
 def test_gaussian_kernel_forward_correct(self):
+    """Verify gaussian kernel forward pass returns scalar loss tensor."""
     pred = torch.rand(2, 1, 8, 8, dtype=torch.float32)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 163 - 168, Add a docstring to the unit test function
test_gaussian_kernel_forward_correct that briefly describes what the test
verifies (e.g., that GlobalMutualInformationLoss with kernel_type="gaussian"
returns a scalar tensor and preserves shape), placing it directly under the def
line in that function; reference the function name
test_gaussian_kernel_forward_correct and the class/constructor
GlobalMutualInformationLoss(kernel_type="gaussian") so reviewers can locate and
confirm the new docstring.

149-156: ⚡ Quick win

Add docstring per coding guidelines.

Docstrings required for all test methods describing purpose and expectations. As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

📝 Suggested docstring
 def test_gaussian_kernel_registers_buffers(self):
+    """Verify gaussian kernel registers preterm and bin_centers as non-trainable buffers."""
     loss = GlobalMutualInformationLoss(kernel_type="gaussian")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 149 - 156, Add a Google-style docstring to the test method
test_gaussian_kernel_registers_buffers describing what is being tested (that
GlobalMutualInformationLoss with kernel_type="gaussian" registers preterm and
bin_centers as non-trainable buffers, that they move with .to(), and that
bin_centers has ndim == 3), including a short "Args" if needed and an "Expected"
or "Raises" note for the assertions; update the docstring inside the test
function definition (test_gaussian_kernel_registers_buffers) so it clearly
states the purpose and the expected conditions checked by the assertions.
monai/losses/image_dissimilarity.py (1)

236-237: 💤 Low value

Type annotations declared unconditionally but attributes are conditionally assigned.

These annotations are defined outside the gaussian conditional block, but the actual attributes are only created when kernel_type == "gaussian". While runtime behavior is correct (attributes only accessed in gaussian path), static type checkers may flag potential AttributeError for b-spline mode.

Consider either:

  • Moving annotations inside the conditional, or
  • Initializing to None and using Optional[torch.Tensor] type
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/image_dissimilarity.py` around lines 236 - 237, The attributes
self.preterm and self.bin_centers are only created when kernel_type ==
"gaussian" but currently annotated unconditionally; update their declarations to
reflect conditional creation by typing them as Optional[torch.Tensor] and
initialize them to None in the non-gaussian branch (or before the conditional)
so static type checkers know they may be absent, and ensure any gaussian-only
use sites (e.g., inside the gaussian branch) treat them as non-None; reference
the attributes self.preterm, self.bin_centers and the kernel_type == "gaussian"
conditional when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 148-169: Add a test that verifies Gaussian kernel buffers actually
move with the module: instantiate
GlobalMutualInformationLoss(kernel_type="gaussian"), if
torch.cuda.is_available() call loss_cuda = loss.to("cuda") (or loss.cuda()),
then assert loss_cuda.preterm.device.type == "cuda" and
loss_cuda.bin_centers.device.type == "cuda", create CUDA tensors for pred and
target and run result = loss_cuda(pred, target) and assert result.device.type ==
"cuda"; reference the GlobalMutualInformationLoss class and its buffers preterm
and bin_centers and add this as a new test method (e.g.,
test_gaussian_buffers_move_with_module) alongside the existing tests.

---

Nitpick comments:
In `@monai/losses/image_dissimilarity.py`:
- Around line 236-237: The attributes self.preterm and self.bin_centers are only
created when kernel_type == "gaussian" but currently annotated unconditionally;
update their declarations to reflect conditional creation by typing them as
Optional[torch.Tensor] and initialize them to None in the non-gaussian branch
(or before the conditional) so static type checkers know they may be absent, and
ensure any gaussian-only use sites (e.g., inside the gaussian branch) treat them
as non-None; reference the attributes self.preterm, self.bin_centers and the
kernel_type == "gaussian" conditional when making the change.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 158-161: The test function
test_bspline_kernel_has_no_gaussian_buffers is missing a docstring; add a short
descriptive docstring at the top of the function explaining that it verifies
GlobalMutualInformationLoss(kernel_type="b-spline") does not populate
Gaussian-specific buffers (specifically asserting "preterm" and "bin_centers"
are not in loss._buffers). Keep it concise and follow existing test docstring
style.
- Around line 163-168: Add a docstring to the unit test function
test_gaussian_kernel_forward_correct that briefly describes what the test
verifies (e.g., that GlobalMutualInformationLoss with kernel_type="gaussian"
returns a scalar tensor and preserves shape), placing it directly under the def
line in that function; reference the function name
test_gaussian_kernel_forward_correct and the class/constructor
GlobalMutualInformationLoss(kernel_type="gaussian") so reviewers can locate and
confirm the new docstring.
- Around line 149-156: Add a Google-style docstring to the test method
test_gaussian_kernel_registers_buffers describing what is being tested (that
GlobalMutualInformationLoss with kernel_type="gaussian" registers preterm and
bin_centers as non-trainable buffers, that they move with .to(), and that
bin_centers has ndim == 3), including a short "Args" if needed and an "Expected"
or "Raises" note for the assertions; update the docstring inside the test
function definition (test_gaussian_kernel_registers_buffers) so it clearly
states the purpose and the expected conditions checked by the assertions.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 22b238ee-90d5-4b39-9f22-3df62dfea05d

📥 Commits

Reviewing files that changed from the base of the PR and between 0a8d945 and f20d3f6.

📒 Files selected for processing (2)
  • monai/losses/image_dissimilarity.py
  • tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py (2)

149-185: ⚡ Quick win

Use full Google-style docstrings for new test methods.

Current one-line docstrings don’t include the required sections (Args, Returns, Raises) from the repo guideline.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 149 - 185, The tests use one-line docstrings; update each test's
docstring (test_gaussian_kernel_registers_buffers,
test_bspline_kernel_has_no_gaussian_buffers,
test_gaussian_kernel_forward_correct, test_gaussian_buffers_move_with_module) to
full Google-style docstrings that include a short summary plus Args (describe
pred/target shapes or when the test constructs the loss), Returns (what the test
asserts, e.g., None or scalar loss), and Raises (any expected exceptions, if
none state "None"); keep the existing descriptive text as the summary and add
the three sections to meet the repo guideline.

149-163: ⚡ Quick win

Assert non-persistent buffer contract explicitly.

Please also verify preterm and bin_centers are excluded from state_dict() to lock in persistent=False behavior.

Proposed test additions
 def test_gaussian_kernel_registers_buffers(self):
     """preterm and bin_centers are registered as non-persistent buffers for gaussian kernel."""
     loss = GlobalMutualInformationLoss(kernel_type="gaussian")
     self.assertIn("preterm", loss._buffers)
     self.assertIn("bin_centers", loss._buffers)
     self.assertFalse(loss.preterm.requires_grad)
     self.assertFalse(loss.bin_centers.requires_grad)
     self.assertEqual(loss.bin_centers.ndim, 3)
+    state = loss.state_dict()
+    self.assertNotIn("preterm", state)
+    self.assertNotIn("bin_centers", state)

 def test_bspline_kernel_has_no_gaussian_buffers(self):
     """b-spline kernel does not register gaussian-specific buffers."""
     loss = GlobalMutualInformationLoss(kernel_type="b-spline")
     self.assertIsNone(loss.preterm)
     self.assertIsNone(loss.bin_centers)
+    state = loss.state_dict()
+    self.assertNotIn("preterm", state)
+    self.assertNotIn("bin_centers", state)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 149 - 163, Update the tests for GlobalMutualInformationLoss to
assert the non-persistent buffer contract by checking that gaussian-specific
buffers do not appear in the module state dict: in
test_gaussian_kernel_registers_buffers (for kernel_type="gaussian") after
asserting preterm and bin_centers exist and have correct properties, also call
loss.state_dict() and assert "preterm" and "bin_centers" are not keys;
similarly, in test_bspline_kernel_has_no_gaussian_buffers (for
kernel_type="b-spline") confirm state_dict() also does not contain those keys
(and that loss.preterm and loss.bin_centers remain None) so persistent=False
behavior is enforced.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 149-185: The tests use one-line docstrings; update each test's
docstring (test_gaussian_kernel_registers_buffers,
test_bspline_kernel_has_no_gaussian_buffers,
test_gaussian_kernel_forward_correct, test_gaussian_buffers_move_with_module) to
full Google-style docstrings that include a short summary plus Args (describe
pred/target shapes or when the test constructs the loss), Returns (what the test
asserts, e.g., None or scalar loss), and Raises (any expected exceptions, if
none state "None"); keep the existing descriptive text as the summary and add
the three sections to meet the repo guideline.
- Around line 149-163: Update the tests for GlobalMutualInformationLoss to
assert the non-persistent buffer contract by checking that gaussian-specific
buffers do not appear in the module state dict: in
test_gaussian_kernel_registers_buffers (for kernel_type="gaussian") after
asserting preterm and bin_centers exist and have correct properties, also call
loss.state_dict() and assert "preterm" and "bin_centers" are not keys;
similarly, in test_bspline_kernel_has_no_gaussian_buffers (for
kernel_type="b-spline") confirm state_dict() also does not contain those keys
(and that loss.preterm and loss.bin_centers remain None) so persistent=False
behavior is enforced.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 13a2592f-5cd6-4641-a30a-def8e3d5df80

📥 Commits

Reviewing files that changed from the base of the PR and between f20d3f6 and 20c702a.

📒 Files selected for processing (2)
  • monai/losses/image_dissimilarity.py
  • tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

@aymuos15
Copy link
Copy Markdown
Contributor

Happy to go through this. Any idea why the CI is failing?

When kernel_type="gaussian", `preterm` and `bin_centers` were stored
as plain tensor attributes via simple assignment. This means they are
not registered in PyTorch's module buffer system, so calling
`loss.to("cuda")` or `loss.cuda()` does not move these tensors to the
target device. Each forward pass had to call `.to(img)` to patch the
device mismatch at runtime, which is both redundant and misleading.

Use `register_buffer(..., persistent=False)` so that both tensors are
properly tracked by the module and automatically move with `.to()` /
`.cuda()` / `.cpu()` calls, consistent with the pattern already used
by `LocalNormalizedCrossCorrelationLoss`.

The `.to(img)` calls in `parzen_windowing_gaussian` are retained for
dtype coercion (e.g. float16 inference).

Adds `TestGlobalMutualInformationLossBuffers` to verify buffer
registration and that b-spline mode does not create gaussian buffers.

Closes Project-MONAI#8819

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
…uffer fix

Use register_buffer("preterm", None) / register_buffer("bin_centers", None)
unconditionally so that both buffers are always present in _buffers (with None
for b-spline). This avoids a KeyError that occurred when plain instance
attribute assignment conflicted with a subsequent register_buffer call.

Also add docstrings to the new test methods and a device-movement test that
verifies buffers follow the module when .cuda() is called.

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
@AlexanderSanin AlexanderSanin force-pushed the fix/global-mutual-information-register-buffer branch from 20c702a to 379b8a8 Compare May 26, 2026 07:06
@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

Hi @aymuos15! The CI failures are pre-existing on the dev branch itself and are not related to this PR.

The root cause is that monai/networks/trt_compiler.py imports polygraphy.backend.common, which transitively imports cupy.testing._random, which does import pytest. In the packaging / full-dep (ubuntu-latest) CI jobs, pytest is not installed in the target environment (it's only in requirements-dev.txt), so any import of monai.networks triggers a ModuleNotFoundError and causes cascade failures across ~6759 tests.

You can verify by checking the most recent dev branch CI run — it shows the exact same FAILED (failures=17, errors=6759, skipped=1020) result on the same jobs.

All tests specific to this PR pass in CI (confirmed in the packaging job log):

test_bspline_kernel_has_no_gaussian_buffers ... ok
test_gaussian_buffers_move_with_module      ... ok
test_gaussian_kernel_forward_correct        ... ok
test_gaussian_kernel_registers_buffers      ... ok
test_ill_opts_{0..3}                        ... ok
test_ill_shape_{0..1}                       ... ok

I've also rebased the branch on the latest dev just now.

@aymuos15
Copy link
Copy Markdown
Contributor

aymuos15 commented May 26, 2026

Okay! Thanks a lot for the detailed reply.

Let's wait for that to get an upstream fix then? I think that is being tracked and worked on in real time.

@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

Okay! Thanks a lot for the detailed reply.

Let's wait for that to get an upstream fix then? I think that is being tracked and worked on in real time.

Hi @aymuos15 ,

I've addressed that in another PR: #8873
Please review , thanks

@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

AlexanderSanin commented May 26, 2026

Status & Blockers

All code changes are complete and tests pass. Two things are blocking merge:

1. CI Failures (external regression — not our code)

The failing CI jobs (full-dep (ubuntu-latest), static-checks (pytype), mypy, codeformat) are caused by a cupy-cuda12x 14.1.0 regression released on May 26, 2026. That version introduced import pytest at module load time in cupy/testing/_random.py, which breaks any environment where pytest is not installed.

The failure chain is:
monai.networkstrt_compilerpolygraphy.backend.commonpolygraphy.util.utilcupy.testing._randomimport pytestModuleNotFoundError

This is pre-existing on the dev branch — the last fully green dev CI run (#26286623163) was on May 22 with cupy 14.0.1. All our PR-specific tests pass in the packaging job logs.

2. Awaiting Required Review

This PR needs at least one approving review from a code owner before it can be merged.

@KumoLiu @ericspod @Nic-Ma — could one of you take a look when you get a chance? The change is in monai/losses/image_dissimilarity.py and registers preterm / bin_centers as non-persistent buffers so they move correctly with .to(device) / .cuda(). Thanks!

@aymuos15
Copy link
Copy Markdown
Contributor

Hey, thanks for the fix there. But I think that is a deeper fix and one the maintainers should do. Will get to both when they are done.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] LocalNormalizedCrossCorrelationLoss: kernel not registered as buffer — silent gradient tracking + wrong device placement

2 participants