Skip to content
98 changes: 65 additions & 33 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,25 +355,6 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
data = data.reshape(-1, 1)
return torch.stack([data.mean()])

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
    for each batch item and for each channel of those items.

    Args:
        y_pred: input predictions with shape HW[D]. It is used as a mask over ``y``; a tensor that is not
            already boolean is converted with ``.bool()``, so any non-zero value counts as foreground.
        y: ground truth with shape HW[D].

    Returns:
        A scalar tensor. When the ground truth sum is positive this is
        ``2 * sum(y[y_pred]) / (sum(y) + sum(y_pred))``. Otherwise it is NaN if ``self.ignore_empty`` is set,
        else 1.0 when the prediction is empty as well and 0.0 when it is not.
    """
    # torch.masked_select only accepts a boolean mask, so coerce other dtypes up front.
    mask = y_pred if y_pred.dtype == torch.bool else y_pred.bool()
    y_o = torch.sum(y)
    if y_o > 0:
        return (2.0 * torch.sum(torch.masked_select(y, mask))) / (y_o + torch.sum(mask))
    if self.ignore_empty:
        return torch.tensor(float("nan"), device=y_o.device)
    # Ground truth is empty here: a perfect score only if nothing was predicted either.
    denorm = y_o + torch.sum(mask)
    if denorm <= 0:
        return torch.tensor(1.0, device=y_o.device)
    return torch.tensor(0.0, device=y_o.device)

def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Compute the metric for the given prediction and ground truth.
Expand Down Expand Up @@ -413,21 +394,72 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
"(B, 2, H, W) or (B, 2, D, H, W). "
f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
)
data = torch.stack(
[
self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0)).reshape(-1)
for b in range(y_pred.shape[0])
],
dim=0,
).contiguous()
f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
return (f, not_nans) if self.get_not_nans else f

# Vectorized computation (replaces nested loops for better performance)
batch_size = y_pred.shape[0]
device = y_pred.device

# Convert y_pred to boolean (handle single-channel class indices vs multi-channel one-hot independently)
if y_pred.shape[1] == 1 and n_pred_ch > 1:
y_pred_bool = torch.zeros(batch_size, n_pred_ch, *y_pred.shape[2:], dtype=torch.bool, device=device)
for c in range(n_pred_ch):
y_pred_bool[:, c] = y_pred[:, 0] == c
else:
y_pred_bool = y_pred.bool()

# Convert y: single-channel class indices → one-hot bool; multi-channel → preserve raw values
if y.shape[1] == 1 and n_pred_ch > 1:
y_expanded = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.float32, device=device)
for c in range(n_pred_ch):
y_expanded[:, c] = (y[:, 0] == c).float()
else:
y_expanded = y

first_ch = 0 if self.include_background and not self.per_component else 1
data = []
for b in range(y_pred.shape[0]):
if self.per_component:
data.append(self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0)).reshape(-1))
continue
c_list = []
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
c_list.append(self.compute_channel(x_pred, x))
data.append(torch.stack(c_list))

data = torch.stack(data, dim=0).contiguous() # type: ignore
# Flatten spatial dimensions for vectorized computation: (batch, channels, -1)
y_pred_flat = y_pred_bool.reshape(batch_size, n_pred_ch, -1).float()
y_flat = y_expanded.reshape(batch_size, n_pred_ch, -1).float()

# Compute Dice per (batch, channel) vectorized: all reductions at once
intersection = torch.sum(y_pred_flat * y_flat, dim=-1) # (batch, n_pred_ch)
pred_sum = torch.sum(y_pred_flat, dim=-1) # (batch, n_pred_ch)
y_sum = torch.sum(y_flat, dim=-1) # (batch, n_pred_ch)

# Dice formula: 2 * intersection / (pred_sum + y_sum)
union = pred_sum + y_sum
dice = (2.0 * intersection) / union # (batch, n_pred_ch)

# Handle empty ground truth cases
if self.ignore_empty:
# Set NaN where ground truth is empty
dice = torch.where(y_sum > 0, dice, torch.tensor(float("nan"), device=device, dtype=dice.dtype))
else:
# Set 1.0 if both empty, 0.0 if only pred is non-empty
empty_mask = y_sum == 0
dice = torch.where(
empty_mask,
torch.where(
pred_sum == 0,
torch.tensor(1.0, device=device, dtype=dice.dtype),
torch.tensor(0.0, device=device, dtype=dice.dtype),
),
dice,
)

# Select channels: exclude background if requested
first_ch = 0 if self.include_background else 1
if n_pred_ch > 1:
data = dice[:, first_ch:] # (batch, num_classes_selected)
else:
data = dice # (batch, 1)

f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
return (f, not_nans) if self.get_not_nans else f
55 changes: 53 additions & 2 deletions tests/metrics/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,15 +291,66 @@
]


# Mixed layout: y holds class indices in a single channel, y_pred is one-hot over three channels.
TEST_CASE_MIXED_1 = [
    {
        # y_pred shape (1, 3, 2, 2), one channel per class
        "y_pred": torch.tensor(
            [
                [
                    [[0.0, 1.0], [0.0, 0.0]],  # class 0 prediction
                    [[0.0, 0.0], [0.0, 1.0]],  # class 1 prediction
                    [[1.0, 0.0], [1.0, 0.0]],  # class 2 prediction
                ]
            ]
        ),
        # y shape (1, 1, 2, 2), each voxel stores its class index
        "y": torch.tensor([[[[0.0, 1.0], [2.0, 1.0]]]]),
        "include_background": True,
    },
    # Expected dice per class:
    #   class 0: gt=[[1,0],[0,0]] vs pred=[[0,1],[0,0]] -> no overlap        -> 0.0
    #   class 1: gt=[[0,1],[0,1]] vs pred=[[0,0],[0,1]] -> 2*1 / (2 + 1)     -> 2/3
    #   class 2: gt=[[0,0],[1,0]] vs pred=[[1,0],[1,0]] -> 2*1 / (1 + 2)     -> 2/3
    [[0.0, 0.6667, 0.6667]],
]

# Mixed layout, reversed: y_pred is a single argmaxed channel (num_classes supplied), y is one-hot.
TEST_CASE_MIXED_2 = [
    {
        # y_pred shape (1, 1, 2, 2): every voxel predicted as class 2
        "y_pred": torch.full((1, 1, 2, 2), 2.0),
        # y shape (1, 3, 2, 2): every voxel labelled as class 0 (background)
        "y": torch.tensor(
            [
                [
                    [[1.0, 1.0], [1.0, 1.0]],  # class 0 ground truth
                    [[0.0, 0.0], [0.0, 0.0]],  # class 1 ground truth
                    [[0.0, 0.0], [0.0, 0.0]],  # class 2 ground truth
                ]
            ]
        ),
        "include_background": True,
        "num_classes": 3,
    },
    # Expected NaN pattern (True marks NaN):
    #   class 0: 4 gt voxels, 0 predicted      -> dice 0.0, not NaN
    #   class 1: empty gt, nothing predicted   -> NaN expected
    #   class 2: empty gt, 4 voxels predicted  -> NaN expected
    [[False, True, True]],
]

# Mixed layout with background excluded: y holds class indices, y_pred is one-hot, batch of two.
TEST_CASE_MIXED_3 = [
    {
        # y_pred shape (2, 3, 2, 2)
        "y_pred": torch.tensor(
            [
                [  # batch item 0
                    [[0.0, 1.0], [0.0, 0.0]],
                    [[0.0, 0.0], [1.0, 1.0]],
                    [[1.0, 0.0], [0.0, 0.0]],
                ],
                [  # batch item 1
                    [[0.0, 0.0], [0.0, 1.0]],
                    [[1.0, 0.0], [0.0, 0.0]],
                    [[0.0, 1.0], [1.0, 0.0]],
                ],
            ]
        ),
        # y shape (2, 1, 2, 2): both items have a single class-1 voxel, the rest is class 0
        "y": torch.tensor(
            [
                [[[0.0, 0.0], [0.0, 1.0]]],
                [[[0.0, 0.0], [0.0, 1.0]]],
            ]
        ),
        "include_background": False,
    },
    # Expected NaN pattern over the non-background classes (True marks NaN):
    #   item 0, class 1: gt=[[0,0],[0,1]] vs pred=[[0,0],[1,1]] -> dice 2/3
    #   item 0, class 2: empty gt          vs pred=[[1,0],[0,0]] -> NaN expected
    #   item 1, class 1: gt=[[0,0],[0,1]] vs pred=[[1,0],[0,0]] -> dice 0.0
    #   item 1, class 2: empty gt          vs pred=[[0,1],[1,0]] -> NaN expected
    [[False, True], [False, True]],
]


class TestComputeMeanDice(unittest.TestCase):

# A single expand decorator: the list includes TEST_CASE_MIXED_1 so the mixed-layout case is exercised.
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12, TEST_CASE_MIXED_1])
def test_value(self, input_data, expected_value):
    """
    Check that ``compute_dice(**input_data)`` matches ``expected_value`` within an absolute tolerance
    of 1e-4, and that the result stays on the same device as ``input_data["y_pred"]``.
    """
    result = compute_dice(**input_data)
    np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
    np.testing.assert_equal(result.device, input_data["y_pred"].device)

# A single expand decorator: the list includes the mixed-layout NaN cases alongside TEST_CASE_3.
@parameterized.expand([TEST_CASE_3, TEST_CASE_MIXED_2, TEST_CASE_MIXED_3])
def test_nans(self, input_data, expected_value):
    """
    Check that the NaN pattern of ``compute_dice(**input_data)`` matches ``expected_value``,
    where ``True`` marks a position expected to be NaN.
    """
    result = compute_dice(**input_data)
    self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value))
Expand Down
Loading