From d1ae627c8d83a8d9665b0118dbbb24c28e166fb2 Mon Sep 17 00:00:00 2001 From: snow_xu Date: Fri, 8 May 2026 10:59:49 +0800 Subject: [PATCH 1/5] fix(patch): remove duplicate @patch_function on _patch_graph_context_manager --- src/torchada/_patch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchada/_patch.py b/src/torchada/_patch.py index 52e184c..560b304 100644 --- a/src/torchada/_patch.py +++ b/src/torchada/_patch.py @@ -389,7 +389,6 @@ def _patch_torch_generator(): _original_graph_class = None -@patch_function def _patch_graph_context_manager(): """ Patch torch.cuda.graph context manager to accept cuda_graph= keyword argument. From 356e1d712fd302430dc961a652cade4c2d70d39f Mon Sep 17 00:00:00 2001 From: snow_xu Date: Fri, 8 May 2026 17:06:24 +0800 Subject: [PATCH 2/5] fix(accelerator): override memory APIs to use torch.musa for MUSA compatibility --- src/torchada/_patch.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/torchada/_patch.py b/src/torchada/_patch.py index 560b304..32b89f1 100644 --- a/src/torchada/_patch.py +++ b/src/torchada/_patch.py @@ -1602,6 +1602,26 @@ def _patch_torch_accelerator(): if not hasattr(_original_torch_accelerator, "stream"): wrapper._set_override("stream", stream_cm) + # Memory APIs that exist on torch.accelerator (PyTorch 2.9+) but internally + # call torch._C._accelerator_* C++ functions which fail on MUSA because the + # MUSA allocator is not a CUDA DeviceAllocator. Override them to delegate to + # torch.musa, following the same pattern as synchronize(). + _ACCELERATOR_MEMORY_APIS = [ + "empty_cache", + "empty_host_cache", + "memory_stats", + "memory_allocated", + "max_memory_allocated", + "memory_reserved", + "max_memory_reserved", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", + "get_memory_info", + ] + for api_name in _ACCELERATOR_MEMORY_APIS: + if hasattr(_original_torch_accelerator, api_name) and hasattr(torch.musa, api_name): + wrapper._set_override(api_name, getattr(torch.musa, api_name)) + sys.modules["torch.accelerator"] = wrapper torch.accelerator = wrapper From a04789d566d0d3c7ba89d78400da4d1ce4fd3741 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Fri, 8 May 2026 17:53:22 +0800 Subject: [PATCH 3/5] Refine Signed-off-by: Xiaodong Ye --- src/torchada/_patch.py | 74 +++++++++++++++++++++++-------------- tests/test_cuda_patching.py | 23 ++++++++++-- 2 files changed, 66 insertions(+), 31 deletions(-) diff --git a/src/torchada/_patch.py b/src/torchada/_patch.py index 32b89f1..645e465 100644 --- a/src/torchada/_patch.py +++ b/src/torchada/_patch.py @@ -1412,7 +1412,8 @@ class _AcceleratorModuleWrapper(ModuleType): Resolution order for attribute access: 1. Explicit overrides installed by torchada (e.g. patched synchronize, - device_index / stream context managers) + device_index / stream context managers, and memory APIs that exist + upstream but are broken on MUSA) 2. The original torch.accelerator module (so existing APIs keep their real implementations) 3. torch.musa as a fallback for APIs that have not yet been added to @@ -1442,12 +1443,39 @@ class _AcceleratorModuleWrapper(ModuleType): "StreamContext": "core.stream.StreamContext", } + # Memory APIs that exist on torch.accelerator (PyTorch 2.9+) but internally + # call torch._C._accelerator_* C++ functions which fail on MUSA because the + # MUSA allocator is not a CUDA DeviceAllocator. These are overridden to + # delegate to torch.musa, following the same pattern as synchronize(). 
+ # When an API in this list exists on the original torch.accelerator AND on + # torch.musa, we install an override that prefers torch.musa over the + # upstream implementation. + _MUSA_OVERRIDES = ( + "empty_cache", + "empty_host_cache", + "memory_stats", + "memory_allocated", + "max_memory_allocated", + "memory_reserved", + "max_memory_reserved", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", + "get_memory_info", + ) + def __init__(self, original_accel, musa_module): super().__init__("torch.accelerator") self._original_accel = original_accel self._musa_module = musa_module self._overrides = {} + # Apply MUSA overrides for memory APIs that exist upstream but are + # broken on MUSA (they route through torch._C._accelerator_* which + # doesn't dispatch to the MUSA allocator). + for name in self._MUSA_OVERRIDES: + if hasattr(original_accel, name) and hasattr(musa_module, name): + self._set_override(name, getattr(musa_module, name)) + def _set_override(self, name, value): """Install an override that takes precedence over the wrapped modules.""" self._overrides[name] = value @@ -1577,14 +1605,26 @@ def _patch_torch_accelerator(): implementation raises. The wrapper installs a patched synchronize that delegates to torch.musa.synchronize(). - 2. Forward compatibility for APIs that PyTorch is expected to add to - torch.accelerator in future releases (empty_cache, memory_stats, - memory_allocated, Stream, Event, manual_seed, get_device_name, ...). - Any attribute missing from the current torch.accelerator module is - looked up on torch.musa instead. + 2. Overrides for memory APIs that exist on torch.accelerator (PyTorch 2.9+) + but are broken on MUSA because they route through torch._C._accelerator_* + C++ functions that don't dispatch to the MUSA allocator. These are + redirected to torch.musa implementations (see _AcceleratorModuleWrapper + ._MUSA_OVERRIDES). + + 3. Forward compatibility for APIs that PyTorch is expected to add to + torch.accelerator in future releases but are not yet present (Stream, + Event, manual_seed, get_device_name, ...). Any attribute missing from + the current torch.accelerator module is looked up on torch.musa instead. - 3. device_index(idx) and stream(s) context managers, which are not yet + 4. device_index(idx) and stream(s) context managers, which are not yet present on torch.accelerator in torch 2.7. + + TODO(torchada): README.md / README_CN.md claim "the wrapper always prefers + the real torch.accelerator implementation and only falls back to torch.musa + when an attribute is missing". That is no longer accurate after adding the + memory API overrides (point 2 above). Update those documents to describe + the actual resolution order: (1) torchada overrides, (2) real torch.accelerator, + (3) fallback to torch.musa. """ global _original_torch_accelerator @@ -1602,26 +1642,6 @@ def _patch_torch_accelerator(): if not hasattr(_original_torch_accelerator, "stream"): wrapper._set_override("stream", stream_cm) - # Memory APIs that exist on torch.accelerator (PyTorch 2.9+) but internally - # call torch._C._accelerator_* C++ functions which fail on MUSA because the - # MUSA allocator is not a CUDA DeviceAllocator. Override them to delegate to - # torch.musa, following the same pattern as synchronize(). 
- _ACCELERATOR_MEMORY_APIS = [ - "empty_cache", - "empty_host_cache", - "memory_stats", - "memory_allocated", - "max_memory_allocated", - "memory_reserved", - "max_memory_reserved", - "reset_accumulated_memory_stats", - "reset_peak_memory_stats", - "get_memory_info", - ] - for api_name in _ACCELERATOR_MEMORY_APIS: - if hasattr(_original_torch_accelerator, api_name) and hasattr(torch.musa, api_name): - wrapper._set_override(api_name, getattr(torch.musa, api_name)) - sys.modules["torch.accelerator"] = wrapper torch.accelerator = wrapper diff --git a/tests/test_cuda_patching.py b/tests/test_cuda_patching.py index 9143ef6..8eaacb3 100644 --- a/tests/test_cuda_patching.py +++ b/tests/test_cuda_patching.py @@ -2487,17 +2487,32 @@ def _make_wrapper(self, accel_attrs=None, musa_attrs=None): return _AcceleratorModuleWrapper(accel, musa), accel, musa def test_original_accelerator_takes_precedence_over_musa(self): - """Official torch.accelerator implementations must win over torch.musa. + """Official torch.accelerator implementations win over torch.musa for non-overridden APIs. This is the forward-compat guarantee: when PyTorch 2.9+ adds an - official implementation of an API (e.g. empty_cache), the wrapper - must return the official one, not the torch.musa fallback. + official implementation of an API (e.g. manual_seed) that is NOT in + _MUSA_OVERRIDES, the wrapper must return the official one, not the + torch.musa fallback. + """ + wrapper, _, _ = self._make_wrapper( + accel_attrs={"manual_seed": "official_impl"}, + musa_attrs={"manual_seed": "musa_fallback"}, + ) + assert wrapper.manual_seed == "official_impl" + + def test_musa_overrides_take_precedence_when_both_exist(self): + """Memory APIs in _MUSA_OVERRIDES use torch.musa even when torch.accelerator has them. + + Starting in PyTorch 2.9+, torch.accelerator.empty_cache() exists but + routes through torch._C._accelerator_* which doesn't work with the MUSA + allocator. The wrapper must override it to use torch.musa.empty_cache(). 
""" wrapper, _, _ = self._make_wrapper( accel_attrs={"empty_cache": "official_impl"}, musa_attrs={"empty_cache": "musa_fallback"}, ) - assert wrapper.empty_cache == "official_impl" + # empty_cache is in _MUSA_OVERRIDES, so torch.musa wins + assert wrapper.empty_cache == "musa_fallback" def test_fallback_to_musa_when_accelerator_missing(self): """Attributes absent from torch.accelerator must fall back to torch.musa.""" From 38a11f38c2f7b2e8a2ced1fa7ae49e818763e288 Mon Sep 17 00:00:00 2001 From: popsiclexu Date: Mon, 11 May 2026 14:03:24 +0800 Subject: [PATCH 4/5] fix(patch): update CUDA matmul backend for MUSA compatibility --- src/torchada/_patch.py | 7 +++++++ tests/test_cuda_patching.py | 5 ++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/torchada/_patch.py b/src/torchada/_patch.py index 645e465..e878775 100644 --- a/src/torchada/_patch.py +++ b/src/torchada/_patch.py @@ -1167,6 +1167,13 @@ def patched_is_built(): torch.backends.cuda.is_built = patched_is_built + if ( + is_musa_platform() + and hasattr(torch.backends, "musa") + and hasattr(torch.backends.musa, "matmul") + ): + torch.backends.cuda.matmul = torch.backends.musa.matmul + # Patch cuBLASModule to support fp32_precision attribute # This attribute is in newer PyTorch but may be missing in torch_musa's version matmul = torch.backends.cuda.matmul diff --git a/tests/test_cuda_patching.py b/tests/test_cuda_patching.py index 8eaacb3..597e0fd 100644 --- a/tests/test_cuda_patching.py +++ b/tests/test_cuda_patching.py @@ -1258,7 +1258,10 @@ def test_torch_backends_cuda_matmul_fp32_precision(self): except (AttributeError, AssertionError): pytest.skip("fp32_precision not available (torchada MUSA-specific attribute)") - if torch.__version__ >= torch.torch_version.TorchVersion("2.9.0"): + if ( + not torchada.is_musa_platform() + and torch.__version__ >= torch.torch_version.TorchVersion("2.9.0") + ): # PyTorch 2.9+: Only use the new API. Do NOT call torch.get_float32_matmul_precision() valid_precisions = ("ieee", "tf32") test_values = ["ieee", "tf32"] From 15d5d576200e7c0a93549c76544014ccfacdb630 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Mon, 11 May 2026 14:20:08 +0800 Subject: [PATCH 5/5] Bump version Signed-off-by: Xiaodong Ye --- README.md | 2 +- README_CN.md | 2 +- benchmarks/benchmark_history.json | 2 +- pyproject.toml | 2 +- src/torchada/__init__.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 502e1b2..ef2b636 100644 --- a/README.md +++ b/README.md @@ -297,7 +297,7 @@ See `src/torchada/_mapping.py` for the complete mapping table (380+ mappings). 
``` # pyproject.toml or requirements.txt -torchada>=0.1.54 +torchada>=0.1.55 ``` ### Step 2: Conditional Import diff --git a/README_CN.md b/README_CN.md index e48ed2d..3207e54 100644 --- a/README_CN.md +++ b/README_CN.md @@ -292,7 +292,7 @@ if torchada.is_gpu_device(device): # 在 CUDA 和 MUSA 上都能工作 ``` # pyproject.toml 或 requirements.txt -torchada>=0.1.54 +torchada>=0.1.55 ``` ### 步骤 2:条件导入 diff --git a/benchmarks/benchmark_history.json b/benchmarks/benchmark_history.json index e024d65..ba4fef7 100644 --- a/benchmarks/benchmark_history.json +++ b/benchmarks/benchmark_history.json @@ -3,7 +3,7 @@ "description": "Historical benchmark results for torchada performance tracking", "results": [ { - "version": "0.1.54", + "version": "0.1.55", "date": "2026-01-29", "platform": "MUSA", "pytorch_version": "2.7.1", diff --git a/pyproject.toml b/pyproject.toml index d7e2744..6e8046a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "torchada" -version = "0.1.54" +version = "0.1.55" description = "Adapter package for torch_musa to act exactly like PyTorch CUDA" readme = "README.md" license = {text = "MIT"} diff --git a/src/torchada/__init__.py b/src/torchada/__init__.py index 0c6fd45..8e55eee 100644 --- a/src/torchada/__init__.py +++ b/src/torchada/__init__.py @@ -24,7 +24,7 @@ from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CUDA_HOME """ -__version__ = "0.1.54" +__version__ = "0.1.55" from . import cuda, utils
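
The attribute-resolution behaviour introduced in patch 3 can be summarised with a small sketch. This is not the torchada source, only a simplified illustration of the three-level lookup (torchada overrides, then the wrapped torch.accelerator, then torch.musa) that _AcceleratorModuleWrapper implements, with the _MUSA_OVERRIDES pre-seeding shown in __init__ and a mock-based check in the same style as the tests in test_cuda_patching.py:

```
from types import ModuleType, SimpleNamespace


class _WrapperSketch(ModuleType):
    # Simplified stand-in for _AcceleratorModuleWrapper; the real class also
    # handles device_index/stream context managers and lazy torch.musa paths.
    _MUSA_OVERRIDES = ("empty_cache", "memory_stats", "memory_allocated")

    def __init__(self, original_accel, musa_module):
        super().__init__("torch.accelerator")
        self._original_accel = original_accel
        self._musa_module = musa_module
        self._overrides = {}
        # Memory APIs that exist upstream but are broken on MUSA are
        # redirected to their torch.musa counterparts up front.
        for name in self._MUSA_OVERRIDES:
            if hasattr(original_accel, name) and hasattr(musa_module, name):
                self._overrides[name] = getattr(musa_module, name)

    def __getattr__(self, name):
        if name in self._overrides:                    # 1. torchada override
            return self._overrides[name]
        if hasattr(self._original_accel, name):        # 2. real torch.accelerator
            return getattr(self._original_accel, name)
        return getattr(self._musa_module, name)        # 3. torch.musa fallback


# "empty_cache" is in _MUSA_OVERRIDES and exists on both modules, so the musa
# stand-in wins; "device_count" resolves to the real accelerator module; an
# attribute missing upstream ("manual_seed") falls back to the musa module.
accel = SimpleNamespace(empty_cache="upstream_empty_cache", device_count="upstream_device_count")
musa = SimpleNamespace(empty_cache="musa_empty_cache", manual_seed="musa_manual_seed")
w = _WrapperSketch(accel, musa)
assert w.empty_cache == "musa_empty_cache"
assert w.device_count == "upstream_device_count"
assert w.manual_seed == "musa_manual_seed"
```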
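
From the caller's side, the combined effect of patches 2-4 is that CUDA-flavoured code keeps working on MUSA without edits. A hedged usage sketch, assuming a MUSA build with torch_musa installed, that importing torchada is what applies the patches (the README's conditional-import step), and that the referenced torch.musa counterparts are present so the hasattr guards in the patches fire:

```
import torchada  # assumed to apply the patches on a MUSA platform
import torch

# Patches 2/3: torch.accelerator delegates synchronize() and the memory APIs
# listed in _MUSA_OVERRIDES to torch.musa instead of torch._C._accelerator_*.
torch.accelerator.synchronize()
torch.accelerator.empty_cache()
stats = torch.accelerator.memory_stats()

# Patch 4: on MUSA, torch.backends.cuda.matmul is rebound to
# torch.backends.musa.matmul, so CUDA-style backend configuration reaches
# the MUSA matmul backend.
assert torch.backends.cuda.matmul is torch.backends.musa.matmul
```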