Skip to content

Commit d9c3646

Browse files
committed
BUG: torch/meshgrid: stop ignoring the "indexing" argument
1 parent 6c708d1 commit d9c3646

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ def sign(x: Array, /) -> Array:
826826
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]:
827827
# enforce the default of 'xy'
828828
# TODO: is the return type a list or a tuple
829-
return list(torch.meshgrid(*arrays, indexing='xy'))
829+
return list(torch.meshgrid(*arrays, indexing=indexing))
830830

831831

832832
__all__ = ['asarray', 'result_type', 'can_cast',

tests/test_torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,16 @@ def test_meshgrid():
117117

118118
assert Y.shape == Y_xy.shape
119119
assert xp.all(Y == Y_xy)
120+
121+
# repeat with an explicit indexing
122+
X, Y = xp.meshgrid(x, y, indexing='ij')
123+
124+
# output of torch.meshgrid(x, y, indexing='ij')
125+
X_ij, Y_ij = xp.asarray([[1], [2]]), xp.asarray([[4], [4]])
126+
127+
assert X.shape == X_ij.shape
128+
assert xp.all(X == X_ij)
129+
130+
assert Y.shape == Y_ij.shape
131+
assert xp.all(Y == Y_ij)
132+

0 commit comments

Comments
 (0)