55 changes: 47 additions & 8 deletions src/torchada/_patch.py
@@ -389,7 +389,6 @@ def _patch_torch_generator():
_original_graph_class = None


-@patch_function
def _patch_graph_context_manager():
"""
Patch torch.cuda.graph context manager to accept cuda_graph= keyword argument.
@@ -1413,7 +1412,8 @@ class _AcceleratorModuleWrapper(ModuleType):

Resolution order for attribute access:
1. Explicit overrides installed by torchada (e.g. patched synchronize,
-       device_index / stream context managers)
+       device_index / stream context managers, and memory APIs that exist
+       upstream but are broken on MUSA)
2. The original torch.accelerator module (so existing APIs keep their
real implementations)
3. torch.musa as a fallback for APIs that have not yet been added to
@@ -1443,12 +1443,39 @@ class _AcceleratorModuleWrapper(ModuleType):
"StreamContext": "core.stream.StreamContext",
}

+    # Memory APIs that exist on torch.accelerator (PyTorch 2.9+) but internally
+    # call torch._C._accelerator_* C++ functions which fail on MUSA because the
+    # MUSA allocator is not a CUDA DeviceAllocator. These are overridden to
+    # delegate to torch.musa, following the same pattern as synchronize().
+    # When an API in this list exists on the original torch.accelerator AND on
+    # torch.musa, we install an override that prefers torch.musa over the
+    # upstream implementation.
+    _MUSA_OVERRIDES = (
+        "empty_cache",
+        "empty_host_cache",
+        "memory_stats",
+        "memory_allocated",
+        "max_memory_allocated",
+        "memory_reserved",
+        "max_memory_reserved",
+        "reset_accumulated_memory_stats",
+        "reset_peak_memory_stats",
+        "get_memory_info",
+    )
+
def __init__(self, original_accel, musa_module):
super().__init__("torch.accelerator")
self._original_accel = original_accel
self._musa_module = musa_module
self._overrides = {}

+        # Apply MUSA overrides for memory APIs that exist upstream but are
+        # broken on MUSA (they route through torch._C._accelerator_* which
+        # doesn't dispatch to the MUSA allocator).
+        for name in self._MUSA_OVERRIDES:
+            if hasattr(original_accel, name) and hasattr(musa_module, name):
+                self._set_override(name, getattr(musa_module, name))
+
def _set_override(self, name, value):
"""Install an override that takes precedence over the wrapped modules."""
self._overrides[name] = value
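
For orientation, the wrapper's __getattr__ is not part of this diff, but the three-step resolution order described in the class docstring suggests it looks roughly like the following sketch (a hypothetical simplification, using only the attribute names visible above):

    def __getattr__(self, name):
        # 1. Explicit torchada overrides (includes the _MUSA_OVERRIDES entries
        #    installed in __init__ via _set_override).
        if name in self._overrides:
            return self._overrides[name]
        # 2. The original torch.accelerator module keeps its real implementations.
        if hasattr(self._original_accel, name):
            return getattr(self._original_accel, name)
        # 3. Fall back to torch.musa for APIs not yet present upstream.
        if hasattr(self._musa_module, name):
            return getattr(self._musa_module, name)
        raise AttributeError(f"module 'torch.accelerator' has no attribute {name!r}")
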
@@ -1578,14 +1605,26 @@ def _patch_torch_accelerator():
implementation raises. The wrapper installs a patched synchronize that
delegates to torch.musa.synchronize().

-    2. Forward compatibility for APIs that PyTorch is expected to add to
-       torch.accelerator in future releases (empty_cache, memory_stats,
-       memory_allocated, Stream, Event, manual_seed, get_device_name, ...).
-       Any attribute missing from the current torch.accelerator module is
-       looked up on torch.musa instead.
+    2. Overrides for memory APIs that exist on torch.accelerator (PyTorch 2.9+)
+       but are broken on MUSA because they route through torch._C._accelerator_*
+       C++ functions that don't dispatch to the MUSA allocator. These are
+       redirected to torch.musa implementations (see _AcceleratorModuleWrapper
+       ._MUSA_OVERRIDES).
+
+    3. Forward compatibility for APIs that PyTorch is expected to add to
+       torch.accelerator in future releases but are not yet present (Stream,
+       Event, manual_seed, get_device_name, ...). Any attribute missing from
+       the current torch.accelerator module is looked up on torch.musa instead.
+
-    3. device_index(idx) and stream(s) context managers, which are not yet
+    4. device_index(idx) and stream(s) context managers, which are not yet
present on torch.accelerator in torch 2.7.

+    TODO(torchada): README.md / README_CN.md claim "the wrapper always prefers
+    the real torch.accelerator implementation and only falls back to torch.musa
+    when an attribute is missing". That is no longer accurate after adding the
+    memory API overrides (point 2 above). Update those documents to describe
+    the actual resolution order: (1) torchada overrides, (2) real torch.accelerator,
+    (3) fallback to torch.musa.
"""
global _original_torch_accelerator

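The body of _patch_torch_accelerator is collapsed above. Based on the wrapper's constructor signature and the global declaration, the installation step is presumably along these lines (a sketch only, not the verbatim implementation; the torch_musa import and the sys.modules update are assumptions):

    import sys
    import torch
    import torch_musa  # assumption: registers the torch.musa namespace

    def _patch_torch_accelerator_sketch():
        global _original_torch_accelerator
        _original_torch_accelerator = torch.accelerator
        wrapper = _AcceleratorModuleWrapper(_original_torch_accelerator, torch.musa)
        # Replace the attribute so torch.accelerator.* resolves via the wrapper.
        torch.accelerator = wrapper
        # Keep `import torch.accelerator` consistent with attribute access.
        sys.modules["torch.accelerator"] = wrapper
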
23 changes: 19 additions & 4 deletions tests/test_cuda_patching.py
@@ -2487,17 +2487,32 @@ def _make_wrapper(self, accel_attrs=None, musa_attrs=None):
return _AcceleratorModuleWrapper(accel, musa), accel, musa

def test_original_accelerator_takes_precedence_over_musa(self):
"""Official torch.accelerator implementations must win over torch.musa.
"""Official torch.accelerator implementations win over torch.musa for non-overridden APIs.
This is the forward-compat guarantee: when PyTorch 2.9+ adds an
-        official implementation of an API (e.g. empty_cache), the wrapper
-        must return the official one, not the torch.musa fallback.
+        official implementation of an API (e.g. manual_seed) that is NOT in
+        _MUSA_OVERRIDES, the wrapper must return the official one, not the
+        torch.musa fallback.
"""
+        wrapper, _, _ = self._make_wrapper(
+            accel_attrs={"manual_seed": "official_impl"},
+            musa_attrs={"manual_seed": "musa_fallback"},
+        )
+        assert wrapper.manual_seed == "official_impl"
+
+    def test_musa_overrides_take_precedence_when_both_exist(self):
+        """Memory APIs in _MUSA_OVERRIDES use torch.musa even when torch.accelerator has them.
+        Starting in PyTorch 2.9+, torch.accelerator.empty_cache() exists but
+        routes through torch._C._accelerator_* which doesn't work with the MUSA
+        allocator. The wrapper must override it to use torch.musa.empty_cache().
+        """
wrapper, _, _ = self._make_wrapper(
accel_attrs={"empty_cache": "official_impl"},
musa_attrs={"empty_cache": "musa_fallback"},
)
-        assert wrapper.empty_cache == "official_impl"
+        # empty_cache is in _MUSA_OVERRIDES, so torch.musa wins
+        assert wrapper.empty_cache == "musa_fallback"

def test_fallback_to_musa_when_accelerator_missing(self):
"""Attributes absent from torch.accelerator must fall back to torch.musa."""
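
The same assertion pattern extends to every other name in _MUSA_OVERRIDES. A hypothetical companion test (not part of this diff) for memory_allocated, reusing the _make_wrapper helper shown above, would look like:

    def test_memory_allocated_prefers_musa(self):
        """Hypothetical example: memory_allocated is in _MUSA_OVERRIDES, so torch.musa wins."""
        wrapper, _, _ = self._make_wrapper(
            accel_attrs={"memory_allocated": "official_impl"},
            musa_attrs={"memory_allocated": "musa_fallback"},
        )
        assert wrapper.memory_allocated == "musa_fallback"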