From b0f9d81953c83bc7eefac18770e13a679bf8c493 Mon Sep 17 00:00:00 2001 From: Jun Hyeok Lee Date: Thu, 2 Apr 2026 23:53:36 +0900 Subject: [PATCH 1/3] fix(auto3dseg): handle precomputed crops and safe no-grad cleanup Signed-off-by: Jun Hyeok Lee --- monai/auto3dseg/analyzer.py | 93 ++++++++++++++++-------------------- tests/apps/test_auto3dseg.py | 47 +++++++++++++++++- 2 files changed, 86 insertions(+), 54 deletions(-) diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index a731546a9e..cbc221d627 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -216,25 +216,10 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS) super().__init__(stats_name, report_format) self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations()) + @torch.no_grad() def __call__(self, data): - # Input Validation Addition - if not isinstance(data, dict): - raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.") - if self.image_key not in data: - raise KeyError(f"Key '{self.image_key}' not found in input data.") - image = data[self.image_key] - if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)): - raise TypeError( - f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, " - f"but got {type(image).__name__}." - ) - if image.ndim < 3: - raise ValueError( - f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}." - ) - # --- End of validation --- """ - Callable to execute the pre-defined functions + Callable to execute the pre-defined functions. Returns: A dictionary. The dict has the key in self.report_format. The value of @@ -242,7 +227,12 @@ def __call__(self, data): has stats pre-defined by SampleOperations (max, min, ....). Raises: - RuntimeError if the stats report generated is not consistent with the pre- + KeyError: if ``self.image_key`` is not present in the input data. + TypeError: if the input data is not a dictionary, or if the image value is + not a numpy array, torch.Tensor, or MetaTensor. + ValueError: if the image has fewer than 3 dimensions, or if pre-computed + ``nda_croppeds`` is not a list/tuple with one entry per image channel. + RuntimeError: if the stats report generated is not consistent with the pre- defined report_format. Note: @@ -250,16 +240,34 @@ def __call__(self, data): functions. If the input has nan/inf, the stats results will be nan/inf. """ + if not isinstance(data, dict): + raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.") + if self.image_key not in data: + raise KeyError(f"Key '{self.image_key}' not found in input data.") + image = data[self.image_key] + if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)): + raise TypeError( + f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, " + f"but got {type(image).__name__}." + ) + if image.ndim < 3: + raise ValueError( + f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}." + ) + d = dict(data) start = time.time() - restore_grad_state = torch.is_grad_enabled() - torch.set_grad_enabled(False) - ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] - if "nda_croppeds" not in d: + if "nda_croppeds" in d: + nda_croppeds = d["nda_croppeds"] + if not isinstance(nda_croppeds, (list, tuple)) or len(nda_croppeds) != len(ndas): + raise ValueError( + "Pre-computed 'nda_croppeds' must be a list or tuple with one entry per image channel " + f"(expected {len(ndas)})." + ) + else: nda_croppeds = [get_foreground_image(nda) for nda in ndas] - # perform calculation report = deepcopy(self.get_report_format()) report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas] @@ -275,16 +283,13 @@ def __call__(self, data): a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING]) ] - report[ImageStatsKeys.INTENSITY] = [ - self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds - ] + report[ImageStatsKeys.INTENSITY] = [self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds] if not verify_report_format(report, self.get_report_format()): raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") d[self.stats_name] = report - torch.set_grad_enabled(restore_grad_state) logger.debug(f"Get image stats spent {time.time() - start}") return d @@ -321,6 +326,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKe super().__init__(stats_name, report_format) self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations()) + @torch.no_grad() def __call__(self, data: Mapping) -> dict: """ Callable to execute the pre-defined functions @@ -341,9 +347,6 @@ def __call__(self, data: Mapping) -> dict: d = dict(data) start = time.time() - restore_grad_state = torch.is_grad_enabled() - torch.set_grad_enabled(False) - ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] ndas_label = d[self.label_key] # (H,W,D) @@ -353,19 +356,15 @@ def __call__(self, data: Mapping) -> dict: nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas] nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] - # perform calculation report = deepcopy(self.get_report_format()) - report[ImageStatsKeys.INTENSITY] = [ - self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds - ] + report[ImageStatsKeys.INTENSITY] = [self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds] if not verify_report_format(report, self.get_report_format()): raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") d[self.stats_name] = report - torch.set_grad_enabled(restore_grad_state) logger.debug(f"Get foreground image stats spent {time.time() - start}") return d @@ -418,6 +417,7 @@ def __init__( id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.IMAGE_INTST]) self.update_ops_nested_label(id_seq, SampleOperations()) + @torch.no_grad() def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor | dict]: """ Callable to execute the pre-defined functions. @@ -470,19 +470,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe start = time.time() image_tensor = d[self.image_key] label_tensor = d[self.label_key] - # Check if either tensor is on CUDA to determine if we should move both to CUDA for processing using_cuda = any( isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor) ) - restore_grad_state = torch.is_grad_enabled() - torch.set_grad_enabled(False) - if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance( - label_tensor, (MetaTensor, torch.Tensor) - ): + if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(label_tensor, (MetaTensor, torch.Tensor)): if label_tensor.device != image_tensor.device: if using_cuda: - # Move both tensors to CUDA when mixing devices cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device image_tensor = cast(MetaTensor, image_tensor.to(cuda_device)) label_tensor = cast(MetaTensor, label_tensor.to(cuda_device)) @@ -504,7 +498,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe unique_label = unique_label.astype(np.int16).tolist() - label_substats = [] # each element is one label + label_substats = [] pixel_sum = 0 pixel_arr = [] for index in unique_label: @@ -513,17 +507,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe mask_index = ndas_label == index nda_masks = [nda[mask_index] for nda in ndas] - label_dict[LabelStatsKeys.IMAGE_INTST] = [ - self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks - ] + label_dict[LabelStatsKeys.IMAGE_INTST] = [self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks] pixel_count = sum(mask_index) pixel_arr.append(pixel_count) pixel_sum += pixel_count - if self.do_ccp: # apply connected component + if self.do_ccp: if using_cuda: - # The back end of get_label_ccp is CuPy - # which is unable to automatically release CUDA GPU memory held by PyTorch del nda_masks torch.cuda.empty_cache() shape_list, ncomponents = get_label_ccp(mask_index) @@ -538,9 +528,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe report = deepcopy(self.get_report_format()) report[LabelStatsKeys.LABEL_UID] = unique_label - report[LabelStatsKeys.IMAGE_INTST] = [ - self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds - ] + report[LabelStatsKeys.IMAGE_INTST] = [self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds] report[LabelStatsKeys.LABEL] = label_substats if not verify_report_format(report, self.get_report_format()): @@ -548,7 +536,6 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe d[self.stats_name] = report # type: ignore[assignment] - torch.set_grad_enabled(restore_grad_state) logger.debug(f"Get label stats spent {time.time() - start}") return d # type: ignore[return-value] diff --git a/tests/apps/test_auto3dseg.py b/tests/apps/test_auto3dseg.py index 6c840e944d..8b0cc855b9 100644 --- a/tests/apps/test_auto3dseg.py +++ b/tests/apps/test_auto3dseg.py @@ -53,7 +53,7 @@ SqueezeDimd, ToDeviced, ) -from monai.utils.enums import DataStatsKeys, LabelStatsKeys +from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys from tests.test_utils import skip_if_no_cuda device = "cpu" @@ -322,6 +322,18 @@ def test_image_stats_case_analyzer(self): report_format = analyzer.get_report_format() assert verify_report_format(d["image_stats"], report_format) + def test_image_stats_uses_precomputed_nda_croppeds(self): + analyzer = ImageStats(image_key="image") + image = torch.arange(64.0, dtype=torch.float32).reshape(1, 4, 4, 4) + nda_croppeds = [torch.ones((2, 2, 2), dtype=torch.float32)] + + result = analyzer({"image": image, "nda_croppeds": nda_croppeds}) + report = result["image_stats"] + + assert verify_report_format(report, analyzer.get_report_format()) + assert report[ImageStatsKeys.CROPPED_SHAPE] == [[2, 2, 2]] + self.assertAlmostEqual(report[ImageStatsKeys.INTENSITY][0]["mean"], 1.0) + def test_foreground_image_stats_cases_analyzer(self): analyzer = FgImageStats(image_key="image", label_key="label") transform_list = [ @@ -412,6 +424,39 @@ def test_label_stats_mixed_device_analyzer(self, input_params): self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75) self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75) + def test_case_analyzers_restore_grad_state_on_exception(self): + cases = [ + ( + "image_stats", + ImageStats(image_key="image"), + {"image": torch.randn(1, 4, 4, 4), "nda_croppeds": [None]}, + AttributeError, + ), + ( + "fg_image_stats", + FgImageStats(image_key="image", label_key="label"), + {"image": torch.randn(1, 4, 4, 4), "label": torch.ones(3, 4, 4)}, + ValueError, + ), + ( + "label_stats", + LabelStats(image_key="image", label_key="label"), + {"image": MetaTensor(torch.randn(1, 4, 4, 4)), "label": MetaTensor(torch.ones(3, 4, 4))}, + ValueError, + ), + ] + + original_grad_state = torch.is_grad_enabled() + try: + for name, analyzer, data, error in cases: + with self.subTest(analyzer=name): + torch.set_grad_enabled(True) + with self.assertRaises(error): + analyzer(data) + self.assertTrue(torch.is_grad_enabled()) + finally: + torch.set_grad_enabled(original_grad_state) + def test_filename_case_analyzer(self): analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH) analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH) From a975397956fffed064cbac65f8b06b75918e3afe Mon Sep 17 00:00:00 2001 From: Jun Hyeok Lee Date: Wed, 15 Apr 2026 21:26:31 +0900 Subject: [PATCH 2/3] test(auto3dseg): strengthen crop and grad-state coverage Signed-off-by: Jun Hyeok Lee --- monai/auto3dseg/analyzer.py | 6 +++-- tests/apps/test_auto3dseg.py | 45 ++++++++++++++++++++++++++++++------ 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index cbc221d627..94b724bfca 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -498,7 +498,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe unique_label = unique_label.astype(np.int16).tolist() - label_substats = [] + label_substats = [] # each element is one label pixel_sum = 0 pixel_arr = [] for index in unique_label: @@ -512,8 +512,10 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe pixel_count = sum(mask_index) pixel_arr.append(pixel_count) pixel_sum += pixel_count - if self.do_ccp: + if self.do_ccp: # apply connected component if using_cuda: + # The back end of get_label_ccp is CuPy + # which is unable to automatically release CUDA GPU memory held by PyTorch del nda_masks torch.cuda.empty_cache() shape_list, ncomponents = get_label_ccp(mask_index) diff --git a/tests/apps/test_auto3dseg.py b/tests/apps/test_auto3dseg.py index 8b0cc855b9..57e05d1ee6 100644 --- a/tests/apps/test_auto3dseg.py +++ b/tests/apps/test_auto3dseg.py @@ -323,6 +323,7 @@ def test_image_stats_case_analyzer(self): assert verify_report_format(d["image_stats"], report_format) def test_image_stats_uses_precomputed_nda_croppeds(self): + """Verify ImageStats uses valid pre-computed foreground crops.""" analyzer = ImageStats(image_key="image") image = torch.arange(64.0, dtype=torch.float32).reshape(1, 4, 4, 4) nda_croppeds = [torch.ones((2, 2, 2), dtype=torch.float32)] @@ -334,6 +335,34 @@ def test_image_stats_uses_precomputed_nda_croppeds(self): assert report[ImageStatsKeys.CROPPED_SHAPE] == [[2, 2, 2]] self.assertAlmostEqual(report[ImageStatsKeys.INTENSITY][0]["mean"], 1.0) + def test_image_stats_validates_precomputed_nda_croppeds(self): + """Verify ImageStats rejects malformed pre-computed foreground crops.""" + analyzer = ImageStats(image_key="image") + image = torch.ones((2, 4, 4, 4), dtype=torch.float32) + invalid_cases = [ + ("wrong_type", torch.ones((2, 2, 2), dtype=torch.float32)), + ("wrong_length", [torch.ones((2, 2, 2), dtype=torch.float32)]), + ] + + for name, nda_croppeds in invalid_cases: + with self.subTest(case=name): + with self.assertRaisesRegex(ValueError, "one entry per image channel"): + analyzer({"image": image, "nda_croppeds": nda_croppeds}) + + def test_image_stats_preserves_grad_state_after_call(self): + """Verify ImageStats preserves caller grad state on successful execution.""" + analyzer = ImageStats(image_key="image") + data = {"image": MetaTensor(torch.rand(1, 10, 10, 10))} + original_grad_state = torch.is_grad_enabled() + try: + for grad_enabled in (True, False): + with self.subTest(grad_enabled=grad_enabled): + torch.set_grad_enabled(grad_enabled) + analyzer(data) + self.assertEqual(torch.is_grad_enabled(), grad_enabled) + finally: + torch.set_grad_enabled(original_grad_state) + def test_foreground_image_stats_cases_analyzer(self): analyzer = FgImageStats(image_key="image", label_key="label") transform_list = [ @@ -425,12 +454,13 @@ def test_label_stats_mixed_device_analyzer(self, input_params): self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75) def test_case_analyzers_restore_grad_state_on_exception(self): + """Verify analyzer calls restore caller grad state after exceptions.""" cases = [ ( "image_stats", ImageStats(image_key="image"), - {"image": torch.randn(1, 4, 4, 4), "nda_croppeds": [None]}, - AttributeError, + {"image": torch.randn(2, 4, 4, 4), "nda_croppeds": [torch.ones((2, 2, 2))]}, + ValueError, ), ( "fg_image_stats", @@ -449,11 +479,12 @@ def test_case_analyzers_restore_grad_state_on_exception(self): original_grad_state = torch.is_grad_enabled() try: for name, analyzer, data, error in cases: - with self.subTest(analyzer=name): - torch.set_grad_enabled(True) - with self.assertRaises(error): - analyzer(data) - self.assertTrue(torch.is_grad_enabled()) + for grad_enabled in (True, False): + with self.subTest(analyzer=name, grad_enabled=grad_enabled): + torch.set_grad_enabled(grad_enabled) + with self.assertRaises(error): + analyzer(data) + self.assertEqual(torch.is_grad_enabled(), grad_enabled) finally: torch.set_grad_enabled(original_grad_state) From 128452914da1611de12b6501094ada2c9b5d3c7d Mon Sep 17 00:00:00 2001 From: Jun Hyeok Lee Date: Wed, 15 Apr 2026 22:20:29 +0900 Subject: [PATCH 3/3] style(auto3dseg): format analyzer with black Signed-off-by: Jun Hyeok Lee --- monai/auto3dseg/analyzer.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index 94b724bfca..9766d86997 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -283,7 +283,9 @@ def __call__(self, data): a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING]) ] - report[ImageStatsKeys.INTENSITY] = [self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds] + report[ImageStatsKeys.INTENSITY] = [ + self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds + ] if not verify_report_format(report, self.get_report_format()): raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") @@ -358,7 +360,9 @@ def __call__(self, data: Mapping) -> dict: report = deepcopy(self.get_report_format()) - report[ImageStatsKeys.INTENSITY] = [self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds] + report[ImageStatsKeys.INTENSITY] = [ + self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds + ] if not verify_report_format(report, self.get_report_format()): raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") @@ -474,7 +478,9 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor) ) - if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(label_tensor, (MetaTensor, torch.Tensor)): + if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance( + label_tensor, (MetaTensor, torch.Tensor) + ): if label_tensor.device != image_tensor.device: if using_cuda: cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device @@ -507,7 +513,9 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe mask_index = ndas_label == index nda_masks = [nda[mask_index] for nda in ndas] - label_dict[LabelStatsKeys.IMAGE_INTST] = [self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks] + label_dict[LabelStatsKeys.IMAGE_INTST] = [ + self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks + ] pixel_count = sum(mask_index) pixel_arr.append(pixel_count) @@ -530,7 +538,9 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe report = deepcopy(self.get_report_format()) report[LabelStatsKeys.LABEL_UID] = unique_label - report[LabelStatsKeys.IMAGE_INTST] = [self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds] + report[LabelStatsKeys.IMAGE_INTST] = [ + self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds + ] report[LabelStatsKeys.LABEL] = label_substats if not verify_report_format(report, self.get_report_format()):