Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions monai/losses/cldice.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
return skel


def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0, smooth_dr: float = 1e-7) -> torch.Tensor:
"""
Function to compute soft dice loss

Expand All @@ -102,12 +102,17 @@ def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -
Args:
y_true: the shape should be BCH(WD)
y_pred: the shape should be BCH(WD)
smooth: Smoothing parameter for numerator and denominator.
smooth_dr: Small epsilon added to the denominator to prevent 0/0 when smooth=0
and inputs (or skeletonizations) are entirely zero. Defaults to 1e-7.

Returns:
dice loss
"""
intersection = torch.sum((y_true * y_pred)[:, 1:, ...])
coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth)
coeff = (2.0 * intersection + smooth) / (
torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth + smooth_dr
)
soft_dice: torch.Tensor = 1.0 - coeff
return soft_dice

Expand All @@ -123,26 +128,30 @@ class SoftclDiceLoss(_Loss):
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7
"""

def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None:
def __init__(self, iter_: int = 3, smooth: float = 1.0, smooth_dr: float = 1e-7) -> None:
"""
Args:
iter_: Number of iterations for skeletonization. Defaults to 3.
smooth: Smoothing parameter. Defaults to 1.0.
smooth_dr: Small epsilon added to each ratio denominator and to the
harmonic mean denominator to prevent 0/0 when smooth=0 and the
skeleton is entirely zero. Defaults to 1e-7.
"""
super().__init__()
self.iter = iter_
self.smooth = smooth
self.smooth_dr = smooth_dr

def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
skel_pred = soft_skel(y_pred, self.iter)
skel_true = soft_skel(y_true, self.iter)
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
torch.sum(skel_pred[:, 1:, ...]) + self.smooth + self.smooth_dr
)
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
torch.sum(skel_true[:, 1:, ...]) + self.smooth + self.smooth_dr
)
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + self.smooth_dr)
return cl_dice


Expand All @@ -157,28 +166,31 @@ class SoftDiceclDiceLoss(_Loss):
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38
"""

def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None:
def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0, smooth_dr: float = 1e-7) -> None:
"""
Args:
iter_: Number of iterations for skeletonization. Defaults to 3.
alpha: Weighing factor for cldice. Defaults to 0.5.
smooth: Smoothing parameter. Defaults to 1.0.
smooth_dr: Small epsilon added to each ratio denominator to prevent
0/0 when smooth=0 and inputs are entirely zero. Defaults to 1e-7.
Comment on lines +175 to +176
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Document the harmonic-mean denominator here too.

This docstring says smooth_dr is added to “each ratio denominator”, but Line 194 also applies it to the final tprec + tsens + self.smooth_dr term. The parameter docs should match the implementation.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/cldice.py` around lines 175 - 176, Update the docstring for the
parameter smooth_dr in clDice (referencing smooth_dr and the harmonic mean
computation using tprec and tsens) to explicitly state that smooth_dr is added
not only to each individual ratio denominator but also to the harmonic-mean
denominator (the final tprec + tsens + self.smooth_dr term) so the docs match
the implementation; locate the parameter description in the clDice/clDiceLoss
docstring and reword it to mention both the per-ratio denominators and the
combined harmonic denominator.

"""
super().__init__()
self.iter = iter_
self.smooth = smooth
self.alpha = alpha
self.smooth_dr = smooth_dr

def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
dice = soft_dice(y_true, y_pred, self.smooth)
dice = soft_dice(y_true, y_pred, self.smooth, self.smooth_dr)
skel_pred = soft_skel(y_pred, self.iter)
skel_true = soft_skel(y_true, self.iter)
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
torch.sum(skel_pred[:, 1:, ...]) + self.smooth + self.smooth_dr
)
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
torch.sum(skel_true[:, 1:, ...]) + self.smooth + self.smooth_dr
)
cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + self.smooth_dr)
total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
return total_loss
25 changes: 25 additions & 0 deletions tests/losses/test_cldice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,31 @@ def test_with_cuda(self):
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)

def test_zero_input_no_nan(self):
"""Test that zero-valued inputs do not produce NaN loss (division by zero guard)."""
loss = SoftclDiceLoss(smooth=0.0)
loss_dice = SoftDiceclDiceLoss(smooth=0.0)
y_pred = torch.zeros((2, 3, 8, 8))
y_true = torch.zeros((2, 3, 8, 8))
result = loss(y_true, y_pred)
result_dice = loss_dice(y_true, y_pred)
self.assertFalse(torch.isnan(result).any(), "SoftclDiceLoss produced NaN for zero inputs with smooth=0")
self.assertFalse(torch.isnan(result_dice).any(), "SoftDiceclDiceLoss produced NaN for zero inputs with smooth=0")

def test_no_overlap_no_nan(self):
"""Test that non-overlapping pred/target do not produce NaN loss."""
loss = SoftclDiceLoss(smooth=0.0)
loss_dice = SoftDiceclDiceLoss(smooth=0.0)
# Create non-overlapping predictions and ground truth
y_pred = torch.zeros((2, 3, 16, 16))
y_true = torch.zeros((2, 3, 16, 16))
y_pred[:, 1:, :8, :] = 1.0 # prediction in left half
y_true[:, 1:, 8:, :] = 1.0 # ground truth in right half
result = loss(y_true, y_pred)
result_dice = loss_dice(y_true, y_pred)
self.assertFalse(torch.isnan(result).any(), "SoftclDiceLoss produced NaN for non-overlapping inputs")
self.assertFalse(torch.isnan(result_dice).any(), "SoftDiceclDiceLoss produced NaN for non-overlapping inputs")


if __name__ == "__main__":
unittest.main()