diff --git a/tests/pytorch/test_checkpoint.py b/tests/pytorch/test_checkpoint.py index 0427886b84..62a68487f8 100644 --- a/tests/pytorch/test_checkpoint.py +++ b/tests/pytorch/test_checkpoint.py @@ -39,6 +39,18 @@ ) +def should_skip(name): + # Skip if quantization is not supported + quantization = None + if "." in name: + quantization = name.split(".")[1] + if quantization == "fp8" and not fp8_available: + return reason_for_no_fp8 + if quantization == "mxfp8" and not mxfp8_available: + return reason_for_no_mxfp8 + return None + + class TestLoadCheckpoint: """Tests for loading checkpoint files @@ -98,6 +110,10 @@ def _checkpoint_dir() -> pathlib.Path: def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) -> None: """Save a module's checkpoint file""" + skip_reason = should_skip(name) + if skip_reason is not None: + pytest.skip(skip_reason) + # Path to save checkpoint if checkpoint_dir is None: checkpoint_dir = TestLoadCheckpoint._checkpoint_dir() @@ -113,14 +129,9 @@ def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) - def test_module(self, name: str) -> None: """Test for loading a module's checkpoint file""" - # Skip if quantization is not supported - quantization = None - if "." in name: - quantization = name.split(".")[1] - if quantization == "fp8" and not fp8_available: - pytest.skip(reason_for_no_fp8) - if quantization == "mxfp8" and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) + skip_reason = should_skip(name) + if skip_reason is not None: + pytest.skip(skip_reason) # Construct module module = self._make_module(name)