diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 5969e4db..e27c3ca2 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -936,7 +936,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra # torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it # will be required to pass the indexing argument." # Thus always pass it explicitly. - return torch.meshgrid(*arrays, indexing=indexing) + if indexing not in ("xy", "ij"): + raise ValueError(f'torch.meshgrid: indexing must be one of "xy" or "ij", but received: {indexing}') + return torch.meshgrid(*arrays, indexing=indexing) if arrays else () __all__ = ['asarray', 'result_type', 'can_cast', diff --git a/tests/test_torch.py b/tests/test_torch.py index b064a46d..3d6ebc46 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -115,7 +115,7 @@ def apply_clip_compat(a): def test_meshgrid(): - """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'.""" + """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy', and supports passing no arrays.""" x, y = xp.asarray([1, 2]), xp.asarray([4]) @@ -142,6 +142,8 @@ def test_meshgrid(): assert Y.shape == Y_ij.shape assert xp.all(Y == Y_ij) + assert not xp.meshgrid() + def test_argsort_stable(): """Verify that argsort defaults to a stable sort."""