diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py index cb5a9e8c7a..1f7ce53c2c 100644 --- a/monai/losses/cldice.py +++ b/monai/losses/cldice.py @@ -92,7 +92,7 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor: return skel -def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor: +def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0, smooth_dr: float = 1e-7) -> torch.Tensor: """ Function to compute soft dice loss @@ -102,12 +102,17 @@ def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) - Args: y_true: the shape should be BCH(WD) y_pred: the shape should be BCH(WD) + smooth: Smoothing parameter for numerator and denominator. + smooth_dr: Small epsilon added to the denominator to prevent 0/0 when smooth=0 + and inputs (or skeletonizations) are entirely zero. Defaults to 1e-7. Returns: dice loss """ intersection = torch.sum((y_true * y_pred)[:, 1:, ...]) - coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth) + coeff = (2.0 * intersection + smooth) / ( + torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth + smooth_dr + ) soft_dice: torch.Tensor = 1.0 - coeff return soft_dice @@ -123,26 +128,30 @@ class SoftclDiceLoss(_Loss): https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7 """ - def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None: + def __init__(self, iter_: int = 3, smooth: float = 1.0, smooth_dr: float = 1e-7) -> None: """ Args: iter_: Number of iterations for skeletonization. Defaults to 3. smooth: Smoothing parameter. Defaults to 1.0. + smooth_dr: Small epsilon added to each ratio denominator and to the + harmonic mean denominator to prevent 0/0 when smooth=0 and the + skeleton is entirely zero. Defaults to 1e-7. """ super().__init__() self.iter = iter_ self.smooth = smooth + self.smooth_dr = smooth_dr def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: skel_pred = soft_skel(y_pred, self.iter) skel_true = soft_skel(y_true, self.iter) tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( - torch.sum(skel_pred[:, 1:, ...]) + self.smooth + torch.sum(skel_pred[:, 1:, ...]) + self.smooth + self.smooth_dr ) tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( - torch.sum(skel_true[:, 1:, ...]) + self.smooth + torch.sum(skel_true[:, 1:, ...]) + self.smooth + self.smooth_dr ) - cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + self.smooth_dr) return cl_dice @@ -157,28 +166,31 @@ class SoftDiceclDiceLoss(_Loss): https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38 """ - def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None: + def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0, smooth_dr: float = 1e-7) -> None: """ Args: iter_: Number of iterations for skeletonization. Defaults to 3. alpha: Weighing factor for cldice. Defaults to 0.5. smooth: Smoothing parameter. Defaults to 1.0. + smooth_dr: Small epsilon added to each ratio denominator to prevent + 0/0 when smooth=0 and inputs are entirely zero. Defaults to 1e-7. """ super().__init__() self.iter = iter_ self.smooth = smooth self.alpha = alpha + self.smooth_dr = smooth_dr def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - dice = soft_dice(y_true, y_pred, self.smooth) + dice = soft_dice(y_true, y_pred, self.smooth, self.smooth_dr) skel_pred = soft_skel(y_pred, self.iter) skel_true = soft_skel(y_true, self.iter) tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( - torch.sum(skel_pred[:, 1:, ...]) + self.smooth + torch.sum(skel_pred[:, 1:, ...]) + self.smooth + self.smooth_dr ) tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( - torch.sum(skel_true[:, 1:, ...]) + self.smooth + torch.sum(skel_true[:, 1:, ...]) + self.smooth + self.smooth_dr ) - cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + self.smooth_dr) total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice return total_loss diff --git a/tests/losses/test_cldice_loss.py b/tests/losses/test_cldice_loss.py index 14d3575e3b..33430d7480 100644 --- a/tests/losses/test_cldice_loss.py +++ b/tests/losses/test_cldice_loss.py @@ -46,6 +46,31 @@ def test_with_cuda(self): np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + def test_zero_input_no_nan(self): + """Test that zero-valued inputs do not produce NaN loss (division by zero guard).""" + loss = SoftclDiceLoss(smooth=0.0) + loss_dice = SoftDiceclDiceLoss(smooth=0.0) + y_pred = torch.zeros((2, 3, 8, 8)) + y_true = torch.zeros((2, 3, 8, 8)) + result = loss(y_true, y_pred) + result_dice = loss_dice(y_true, y_pred) + self.assertFalse(torch.isnan(result).any(), "SoftclDiceLoss produced NaN for zero inputs with smooth=0") + self.assertFalse(torch.isnan(result_dice).any(), "SoftDiceclDiceLoss produced NaN for zero inputs with smooth=0") + + def test_no_overlap_no_nan(self): + """Test that non-overlapping pred/target do not produce NaN loss.""" + loss = SoftclDiceLoss(smooth=0.0) + loss_dice = SoftDiceclDiceLoss(smooth=0.0) + # Create non-overlapping predictions and ground truth + y_pred = torch.zeros((2, 3, 16, 16)) + y_true = torch.zeros((2, 3, 16, 16)) + y_pred[:, 1:, :8, :] = 1.0 # prediction in left half + y_true[:, 1:, 8:, :] = 1.0 # ground truth in right half + result = loss(y_true, y_pred) + result_dice = loss_dice(y_true, y_pred) + self.assertFalse(torch.isnan(result).any(), "SoftclDiceLoss produced NaN for non-overlapping inputs") + self.assertFalse(torch.isnan(result_dice).any(), "SoftDiceclDiceLoss produced NaN for non-overlapping inputs") + if __name__ == "__main__": unittest.main()