Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra
# torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it
# will be required to pass the indexing argument."
# Thus always pass it explicitly.
Comment thread
ev-br marked this conversation as resolved.
return torch.meshgrid(*arrays, indexing=indexing)
if indexing not in ("xy", "ij"):
raise ValueError(f'torch.meshgrid: indexing must be one of "xy" or "ij", but received: {indexing}')
return torch.meshgrid(*arrays, indexing=indexing) if arrays else ()
Comment thread
ev-br marked this conversation as resolved.


__all__ = ['asarray', 'result_type', 'can_cast',
Expand Down
4 changes: 3 additions & 1 deletion tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def apply_clip_compat(a):


def test_meshgrid():
"""Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
"""Verify that array_api_compat.torch.meshgrid defaults to indexing='xy', and supports passing no arrays."""

x, y = xp.asarray([1, 2]), xp.asarray([4])

Expand All @@ -142,6 +142,8 @@ def test_meshgrid():
assert Y.shape == Y_ij.shape
assert xp.all(Y == Y_ij)

assert not xp.meshgrid()


def test_argsort_stable():
"""Verify that argsort defaults to a stable sort."""
Expand Down
Loading