fix(losses): register buffers in GlobalMutualInformationLoss #8872
AlexanderSanin wants to merge 2 commits
Walkthrough
GlobalMutualInformationLoss now registers preterm and bin_centers as non-persistent buffers before the kernel-type dispatch and populates them when kernel_type == "gaussian". Tests were added to verify buffer registration and properties for the gaussian kernel, absence for b-spline, that a gaussian forward returns a scalar tensor, and that the gaussian buffers move with the module to CUDA.
Estimated code review effort: 2 (Simple) | ~8 minutes
Pre-merge checks: 5 passed
Actionable comments posted: 1
🧹 Nitpick comments (4)
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py (3)

158-161: Quick win: Add docstring per coding guidelines.
Suggested docstring:
      def test_bspline_kernel_has_no_gaussian_buffers(self):
    +     """Verify b-spline kernel does not register gaussian-specific buffers."""
          loss = GlobalMutualInformationLoss(kernel_type="b-spline")
Prompt for AI Agents:
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py` around lines 158 - 161, The test function test_bspline_kernel_has_no_gaussian_buffers is missing a docstring; add a short descriptive docstring at the top of the function explaining that it verifies GlobalMutualInformationLoss(kernel_type="b-spline") does not populate Gaussian-specific buffers (specifically asserting "preterm" and "bin_centers" are not in loss._buffers). Keep it concise and follow existing test docstring style.

163-168: Quick win: Add docstring per coding guidelines.
Suggested docstring:
      def test_gaussian_kernel_forward_correct(self):
    +     """Verify gaussian kernel forward pass returns scalar loss tensor."""
          pred = torch.rand(2, 1, 8, 8, dtype=torch.float32)
Prompt for AI Agents:
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py` around lines 163 - 168, Add a docstring to the unit test function test_gaussian_kernel_forward_correct that briefly describes what the test verifies (e.g., that GlobalMutualInformationLoss with kernel_type="gaussian" returns a scalar tensor and preserves shape), placing it directly under the def line in that function; reference the function name test_gaussian_kernel_forward_correct and the class/constructor GlobalMutualInformationLoss(kernel_type="gaussian") so reviewers can locate and confirm the new docstring.

149-156: Quick win: Add docstring per coding guidelines.
Docstrings required for all test methods describing purpose and expectations. As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
Suggested docstring:
      def test_gaussian_kernel_registers_buffers(self):
    +     """Verify gaussian kernel registers preterm and bin_centers as non-trainable buffers."""
          loss = GlobalMutualInformationLoss(kernel_type="gaussian")
Prompt for AI Agents:
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py` around lines 149 - 156, Add a Google-style docstring to the test method test_gaussian_kernel_registers_buffers describing what is being tested (that GlobalMutualInformationLoss with kernel_type="gaussian" registers preterm and bin_centers as non-trainable buffers, that they move with .to(), and that bin_centers has ndim == 3), including a short "Args" if needed and an "Expected" or "Raises" note for the assertions; update the docstring inside the test function definition (test_gaussian_kernel_registers_buffers) so it clearly states the purpose and the expected conditions checked by the assertions.

monai/losses/image_dissimilarity.py (1)

236-237: Low value: Type annotations declared unconditionally but attributes are conditionally assigned.
These annotations are defined outside the gaussian conditional block, but the actual attributes are only created when `kernel_type == "gaussian"`. While runtime behavior is correct (attributes only accessed in gaussian path), static type checkers may flag a potential `AttributeError` for b-spline mode.
Consider either:
- Moving annotations inside the conditional, or
- Initializing to `None` and using the `Optional[torch.Tensor]` type
Prompt for AI Agents:
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/losses/image_dissimilarity.py` around lines 236 - 237, The attributes self.preterm and self.bin_centers are only created when kernel_type == "gaussian" but currently annotated unconditionally; update their declarations to reflect conditional creation by typing them as Optional[torch.Tensor] and initialize them to None in the non-gaussian branch (or before the conditional) so static type checkers know they may be absent, and ensure any gaussian-only use sites (e.g., inside the gaussian branch) treat them as non-None; reference the attributes self.preterm, self.bin_centers and the kernel_type == "gaussian" conditional when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 148-169: Add a test that verifies Gaussian kernel buffers actually
move with the module: instantiate
GlobalMutualInformationLoss(kernel_type="gaussian"), if
torch.cuda.is_available() call loss_cuda = loss.to("cuda") (or loss.cuda()),
then assert loss_cuda.preterm.device.type == "cuda" and
loss_cuda.bin_centers.device.type == "cuda", create CUDA tensors for pred and
target and run result = loss_cuda(pred, target) and assert result.device.type ==
"cuda"; reference the GlobalMutualInformationLoss class and its buffers preterm
and bin_centers and add this as a new test method (e.g.,
test_gaussian_buffers_move_with_module) alongside the existing tests.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 22b238ee-90d5-4b39-9f22-3df62dfea05d
📒 Files selected for processing (2)
monai/losses/image_dissimilarity.py
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py
🧹 Nitpick comments (2)
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py (2)
149-185: Quick win: Use full Google-style docstrings for new test methods.
Current one-line docstrings don't include the required sections (`Args`, `Returns`, `Raises`) from the repo guideline.
As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
Prompt for AI Agents:
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py` around lines 149 - 185, The tests use one-line docstrings; update each test's docstring (test_gaussian_kernel_registers_buffers, test_bspline_kernel_has_no_gaussian_buffers, test_gaussian_kernel_forward_correct, test_gaussian_buffers_move_with_module) to full Google-style docstrings that include a short summary plus Args (describe pred/target shapes or when the test constructs the loss), Returns (what the test asserts, e.g., None or scalar loss), and Raises (any expected exceptions, if none state "None"); keep the existing descriptive text as the summary and add the three sections to meet the repo guideline.

149-163: Quick win: Assert non-persistent buffer contract explicitly.
Please also verify `preterm` and `bin_centers` are excluded from `state_dict()` to lock in `persistent=False` behavior.
Proposed test additions:
      def test_gaussian_kernel_registers_buffers(self):
          """preterm and bin_centers are registered as non-persistent buffers for gaussian kernel."""
          loss = GlobalMutualInformationLoss(kernel_type="gaussian")
          self.assertIn("preterm", loss._buffers)
          self.assertIn("bin_centers", loss._buffers)
          self.assertFalse(loss.preterm.requires_grad)
          self.assertFalse(loss.bin_centers.requires_grad)
          self.assertEqual(loss.bin_centers.ndim, 3)
    +     state = loss.state_dict()
    +     self.assertNotIn("preterm", state)
    +     self.assertNotIn("bin_centers", state)

      def test_bspline_kernel_has_no_gaussian_buffers(self):
          """b-spline kernel does not register gaussian-specific buffers."""
          loss = GlobalMutualInformationLoss(kernel_type="b-spline")
          self.assertIsNone(loss.preterm)
          self.assertIsNone(loss.bin_centers)
    +     state = loss.state_dict()
    +     self.assertNotIn("preterm", state)
    +     self.assertNotIn("bin_centers", state)
Prompt for AI Agents:
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py` around lines 149 - 163, Update the tests for GlobalMutualInformationLoss to assert the non-persistent buffer contract by checking that gaussian-specific buffers do not appear in the module state dict: in test_gaussian_kernel_registers_buffers (for kernel_type="gaussian") after asserting preterm and bin_centers exist and have correct properties, also call loss.state_dict() and assert "preterm" and "bin_centers" are not keys; similarly, in test_bspline_kernel_has_no_gaussian_buffers (for kernel_type="b-spline") confirm state_dict() also does not contain those keys (and that loss.preterm and loss.bin_centers remain None) so persistent=False behavior is enforced.
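The non-persistent contract requested in this comment can be illustrated without MONAI. The following is a minimal sketch, assuming only stock PyTorch; `BufferDemo`, `saved`, and `derived` are hypothetical names used for illustration and do not appear in the PR.

```python
import torch
from torch import nn


class BufferDemo(nn.Module):
    """Hypothetical module with one persistent and one non-persistent buffer."""

    def __init__(self) -> None:
        super().__init__()
        # Persistent by default: saved to and loaded from state_dict().
        self.register_buffer("saved", torch.zeros(2))
        # Non-persistent: tracked by the module but left out of state_dict().
        self.register_buffer("derived", torch.ones(2), persistent=False)


demo = BufferDemo()
state_keys = sorted(demo.state_dict().keys())
buffer_names = sorted(name for name, _ in demo.named_buffers())
print(state_keys)    # only the persistent buffer
print(buffer_names)  # both buffers are still tracked by the module
```

Both buffers remain visible through `named_buffers()`, so they still follow the module across devices; only the checkpoint contents differ.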
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 13a2592f-5cd6-4641-a30a-def8e3d5df80
📒 Files selected for processing (2)
monai/losses/image_dissimilarity.py
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py
Happy to go through this. Any idea why the CI is failing?
When kernel_type="gaussian", `preterm` and `bin_centers` were stored
as plain tensor attributes via simple assignment. This means they are
not registered in PyTorch's module buffer system, so calling
`loss.to("cuda")` or `loss.cuda()` does not move these tensors to the
target device. Each forward pass had to call `.to(img)` to patch the
device mismatch at runtime, which is both redundant and misleading.
Use `register_buffer(..., persistent=False)` so that both tensors are
properly tracked by the module and automatically move with `.to()` /
`.cuda()` / `.cpu()` calls, consistent with the pattern already used
by `LocalNormalizedCrossCorrelationLoss`.
The `.to(img)` calls in `parzen_windowing_gaussian` are retained for
dtype coercion (e.g. float16 inference).
Adds `TestGlobalMutualInformationLossBuffers` to verify buffer
registration and that b-spline mode does not create gaussian buffers.
Closes Project-MONAI#8819
Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
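The difference described in this commit message can be reproduced in isolation. This is a minimal sketch rather than the MONAI code: `PlainAttribute` and `RegisteredBuffer` are hypothetical stand-ins, and a dtype conversion is used in place of `.cuda()` so the example runs on a CPU-only machine. The same module traversal handles device moves.

```python
import torch
from torch import nn


class PlainAttribute(nn.Module):
    """Hypothetical stand-in: constant stored by simple assignment (pre-fix pattern)."""

    def __init__(self) -> None:
        super().__init__()
        self.bin_centers = torch.linspace(0.0, 1.0, 4)


class RegisteredBuffer(nn.Module):
    """Hypothetical stand-in: same constant registered as a non-persistent buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("bin_centers", torch.linspace(0.0, 1.0, 4), persistent=False)


# Module.to() only visits parameters and registered buffers.
plain = PlainAttribute().to(torch.float64)
buffered = RegisteredBuffer().to(torch.float64)

plain_dtype = plain.bin_centers.dtype        # untouched by .to()
buffered_dtype = buffered.bin_centers.dtype  # converted along with the module
print(plain_dtype, buffered_dtype)
```

The plain attribute keeps its original dtype (and would keep its original device), which is why the pre-fix code needed a `.to(img)` call on every forward pass.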
…uffer fix
Use register_buffer("preterm", None) / register_buffer("bin_centers", None)
unconditionally so that both buffers are always present in _buffers (with None
for b-spline). This avoids a KeyError that occurred when plain instance
attribute assignment conflicted with a subsequent register_buffer call.
Also add docstrings to the new test methods and a device-movement test that
verifies buffers follow the module when .cuda() is called.
Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
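The `KeyError` mentioned in this commit message, and the unconditional `register_buffer(name, None)` pattern that avoids it, can be sketched as follows. The classes `Conflicting` and `Unconditional` and the buffer name `weights` are hypothetical and chosen for illustration; only stock PyTorch behavior is assumed.

```python
import torch
from torch import nn


class Conflicting(nn.Module):
    """Hypothetical: plain assignment followed by register_buffer on the same name."""

    def __init__(self) -> None:
        super().__init__()
        self.weights = None  # stored as an ordinary attribute, not in _buffers
        # Raises KeyError because the attribute already exists outside _buffers.
        self.register_buffer("weights", torch.ones(3), persistent=False)


class Unconditional(nn.Module):
    """Hypothetical: register a None placeholder first, fill it in only when needed."""

    def __init__(self, use_weights: bool) -> None:
        super().__init__()
        self.register_buffer("weights", None, persistent=False)
        if use_weights:
            # The name is already in _buffers, so assignment updates the buffer.
            self.weights = torch.ones(3)


try:
    Conflicting()
    conflict_error = None
except KeyError as exc:
    conflict_error = exc

empty = Unconditional(use_weights=False)
filled = Unconditional(use_weights=True)
print(type(conflict_error).__name__, empty.weights, filled.weights)
```

With the placeholder approach the attribute always exists, so callers can test `module.weights is None` instead of catching `AttributeError`, and the buffer stays out of `state_dict()` in both modes.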
20c702a to 379b8a8
Hi @aymuos15! The CI failures are pre-existing on the [...]. The root cause is that [...]. You can verify by checking the most recent [...]. All tests specific to this PR pass in CI (confirmed in the [...]). I've also rebased the branch on the latest [...].
Okay! Thanks a lot for the detailed reply. Let's wait for that to get an upstream fix then? I think that is being tracked and worked on in real time.
Status & Blockers
All code changes are complete and tests pass. Two things are blocking merge:
1. CI Failures (external regression, not our code)
The failing CI jobs [...]
The failure chain is: [...]
This is pre-existing on the [...]
2. Awaiting Required Review
This PR needs at least one approving review from a code owner before it can be merged. @KumoLiu @ericspod @Nic-Ma, could one of you take a look when you get a chance? The change is in [...]

Hey, thanks for the fix there. But I think that is a deeper fix and one the maintainers should do. Will get to both when they are done.
Summary
- `GlobalMutualInformationLoss` stored `preterm` and `bin_centers` as plain tensor attributes when `kernel_type="gaussian"`, so calling `loss.to("cuda")` or `loss.cuda()` did not move them to the target device
- Use `register_buffer(..., persistent=False)`, consistent with the pattern already applied to `LocalNormalizedCrossCorrelationLoss` in "fix: use register_buffer for kernel and kernel_vol in LocalNormalizedCrossCorrelationLoss" #8818
- `.to(img)` calls in `parzen_windowing_gaussian` are retained for dtype coercion (e.g. float16 inference)
Test plan
- `python -m pytest tests/losses/image_dissimilarity/test_global_mutual_information_loss.py -v`: all existing tests still pass
- `TestGlobalMutualInformationLossBuffers::test_gaussian_kernel_registers_buffers`: `preterm` and `bin_centers` are in `_buffers` and have `requires_grad=False`
- `TestGlobalMutualInformationLossBuffers::test_bspline_kernel_has_no_gaussian_buffers`: b-spline mode is unaffected
- `TestGlobalMutualInformationLossBuffers::test_gaussian_kernel_forward_correct`: forward pass returns a scalar loss
Closes #8819