diff --git a/README.md b/README.md index 502e1b2..ef2b636 100644 --- a/README.md +++ b/README.md @@ -297,7 +297,7 @@ See `src/torchada/_mapping.py` for the complete mapping table (380+ mappings). ``` # pyproject.toml or requirements.txt -torchada>=0.1.54 +torchada>=0.1.55 ``` ### Step 2: Conditional Import diff --git a/README_CN.md b/README_CN.md index e48ed2d..3207e54 100644 --- a/README_CN.md +++ b/README_CN.md @@ -292,7 +292,7 @@ if torchada.is_gpu_device(device): # 在 CUDA 和 MUSA 上都能工作 ``` # pyproject.toml 或 requirements.txt -torchada>=0.1.54 +torchada>=0.1.55 ``` ### 步骤 2:条件导入 diff --git a/benchmarks/benchmark_history.json b/benchmarks/benchmark_history.json index e024d65..ba4fef7 100644 --- a/benchmarks/benchmark_history.json +++ b/benchmarks/benchmark_history.json @@ -3,7 +3,7 @@ "description": "Historical benchmark results for torchada performance tracking", "results": [ { - "version": "0.1.54", + "version": "0.1.55", "date": "2026-01-29", "platform": "MUSA", "pytorch_version": "2.7.1", diff --git a/pyproject.toml b/pyproject.toml index d7e2744..6e8046a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "torchada" -version = "0.1.54" +version = "0.1.55" description = "Adapter package for torch_musa to act exactly like PyTorch CUDA" readme = "README.md" license = {text = "MIT"} diff --git a/src/torchada/__init__.py b/src/torchada/__init__.py index 0c6fd45..8e55eee 100644 --- a/src/torchada/__init__.py +++ b/src/torchada/__init__.py @@ -24,7 +24,7 @@ from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CUDA_HOME """ -__version__ = "0.1.54" +__version__ = "0.1.55" from . import cuda, utils diff --git a/src/torchada/_patch.py b/src/torchada/_patch.py index 52e184c..e878775 100644 --- a/src/torchada/_patch.py +++ b/src/torchada/_patch.py @@ -389,7 +389,6 @@ def _patch_torch_generator(): _original_graph_class = None -@patch_function def _patch_graph_context_manager(): """ Patch torch.cuda.graph context manager to accept cuda_graph= keyword argument. @@ -1168,6 +1167,13 @@ def patched_is_built(): torch.backends.cuda.is_built = patched_is_built + if ( + is_musa_platform() + and hasattr(torch.backends, "musa") + and hasattr(torch.backends.musa, "matmul") + ): + torch.backends.cuda.matmul = torch.backends.musa.matmul + # Patch cuBLASModule to support fp32_precision attribute # This attribute is in newer PyTorch but may be missing in torch_musa's version matmul = torch.backends.cuda.matmul @@ -1413,7 +1419,8 @@ class _AcceleratorModuleWrapper(ModuleType): Resolution order for attribute access: 1. Explicit overrides installed by torchada (e.g. patched synchronize, - device_index / stream context managers) + device_index / stream context managers, and memory APIs that exist + upstream but are broken on MUSA) 2. The original torch.accelerator module (so existing APIs keep their real implementations) 3. torch.musa as a fallback for APIs that have not yet been added to @@ -1443,12 +1450,39 @@ class _AcceleratorModuleWrapper(ModuleType): "StreamContext": "core.stream.StreamContext", } + # Memory APIs that exist on torch.accelerator (PyTorch 2.9+) but internally + # call torch._C._accelerator_* C++ functions which fail on MUSA because the + # MUSA allocator is not a CUDA DeviceAllocator. These are overridden to + # delegate to torch.musa, following the same pattern as synchronize(). + # When an API in this list exists on the original torch.accelerator AND on + # torch.musa, we install an override that prefers torch.musa over the + # upstream implementation. + _MUSA_OVERRIDES = ( + "empty_cache", + "empty_host_cache", + "memory_stats", + "memory_allocated", + "max_memory_allocated", + "memory_reserved", + "max_memory_reserved", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", + "get_memory_info", + ) + def __init__(self, original_accel, musa_module): super().__init__("torch.accelerator") self._original_accel = original_accel self._musa_module = musa_module self._overrides = {} + # Apply MUSA overrides for memory APIs that exist upstream but are + # broken on MUSA (they route through torch._C._accelerator_* which + # doesn't dispatch to the MUSA allocator). + for name in self._MUSA_OVERRIDES: + if hasattr(original_accel, name) and hasattr(musa_module, name): + self._set_override(name, getattr(musa_module, name)) + def _set_override(self, name, value): """Install an override that takes precedence over the wrapped modules.""" self._overrides[name] = value @@ -1578,14 +1612,26 @@ def _patch_torch_accelerator(): implementation raises. The wrapper installs a patched synchronize that delegates to torch.musa.synchronize(). - 2. Forward compatibility for APIs that PyTorch is expected to add to - torch.accelerator in future releases (empty_cache, memory_stats, - memory_allocated, Stream, Event, manual_seed, get_device_name, ...). - Any attribute missing from the current torch.accelerator module is - looked up on torch.musa instead. + 2. Overrides for memory APIs that exist on torch.accelerator (PyTorch 2.9+) + but are broken on MUSA because they route through torch._C._accelerator_* + C++ functions that don't dispatch to the MUSA allocator. These are + redirected to torch.musa implementations (see _AcceleratorModuleWrapper + ._MUSA_OVERRIDES). - 3. device_index(idx) and stream(s) context managers, which are not yet + 3. Forward compatibility for APIs that PyTorch is expected to add to + torch.accelerator in future releases but are not yet present (Stream, + Event, manual_seed, get_device_name, ...). Any attribute missing from + the current torch.accelerator module is looked up on torch.musa instead. + + 4. device_index(idx) and stream(s) context managers, which are not yet present on torch.accelerator in torch 2.7. + + TODO(torchada): README.md / README_CN.md claim "the wrapper always prefers + the real torch.accelerator implementation and only falls back to torch.musa + when an attribute is missing". That is no longer accurate after adding the + memory API overrides (point 2 above). Update those documents to describe + the actual resolution order: (1) torchada overrides, (2) real torch.accelerator, + (3) fallback to torch.musa. """ global _original_torch_accelerator diff --git a/tests/test_cuda_patching.py b/tests/test_cuda_patching.py index 9143ef6..597e0fd 100644 --- a/tests/test_cuda_patching.py +++ b/tests/test_cuda_patching.py @@ -1258,7 +1258,10 @@ def test_torch_backends_cuda_matmul_fp32_precision(self): except (AttributeError, AssertionError): pytest.skip("fp32_precision not available (torchada MUSA-specific attribute)") - if torch.__version__ >= torch.torch_version.TorchVersion("2.9.0"): + if ( + not torchada.is_musa_platform() + and torch.__version__ >= torch.torch_version.TorchVersion("2.9.0") + ): # PyTorch 2.9+: Only use the new API. Do NOT call torch.get_float32_matmul_precision() valid_precisions = ("ieee", "tf32") test_values = ["ieee", "tf32"] @@ -2487,17 +2490,32 @@ def _make_wrapper(self, accel_attrs=None, musa_attrs=None): return _AcceleratorModuleWrapper(accel, musa), accel, musa def test_original_accelerator_takes_precedence_over_musa(self): - """Official torch.accelerator implementations must win over torch.musa. + """Official torch.accelerator implementations win over torch.musa for non-overridden APIs. This is the forward-compat guarantee: when PyTorch 2.9+ adds an - official implementation of an API (e.g. empty_cache), the wrapper - must return the official one, not the torch.musa fallback. + official implementation of an API (e.g. manual_seed) that is NOT in + _MUSA_OVERRIDES, the wrapper must return the official one, not the + torch.musa fallback. + """ + wrapper, _, _ = self._make_wrapper( + accel_attrs={"manual_seed": "official_impl"}, + musa_attrs={"manual_seed": "musa_fallback"}, + ) + assert wrapper.manual_seed == "official_impl" + + def test_musa_overrides_take_precedence_when_both_exist(self): + """Memory APIs in _MUSA_OVERRIDES use torch.musa even when torch.accelerator has them. + + Starting in PyTorch 2.9+, torch.accelerator.empty_cache() exists but + routes through torch._C._accelerator_* which doesn't work with the MUSA + allocator. The wrapper must override it to use torch.musa.empty_cache(). """ wrapper, _, _ = self._make_wrapper( accel_attrs={"empty_cache": "official_impl"}, musa_attrs={"empty_cache": "musa_fallback"}, ) - assert wrapper.empty_cache == "official_impl" + # empty_cache is in _MUSA_OVERRIDES, so torch.musa wins + assert wrapper.empty_cache == "musa_fallback" def test_fallback_to_musa_when_accelerator_missing(self): """Attributes absent from torch.accelerator must fall back to torch.musa."""