Skip to content

Commit a33d9ff

Browse files
committed
Fix torch.cross wrapper
The wrapper was not properly imported, and torch.cross also does not broadcast correctly in all cases.
1 parent 6948a4a commit a33d9ff

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,12 +695,18 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
695695
axis = 0
696696
return torch.index_select(x, axis, indices, **kwargs)
697697

698-
699-
700698
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743

# torch.cross also does not support broadcasting when it would add new
# dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
    """Return the cross product of ``x1`` and ``x2`` along ``axis``.

    Wraps ``torch.linalg.cross``, validating ``axis`` and the size-3
    requirement up front and broadcasting the inputs explicitly (see the
    pytorch issues referenced above for why torch alone is not enough).

    Raises ``ValueError`` if ``axis`` is out of bounds for either input or
    if the selected axis does not have size 3 in both inputs.
    """
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    # Validate axis against both operands before broadcasting changes ndim.
    smaller_ndim = builtin_min(x1.ndim, x2.ndim)
    larger_ndim = builtin_max(x1.ndim, x2.ndim)
    if not (-smaller_ndim <= axis < larger_ndim):
        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
    if x1.shape[axis] != 3 or x2.shape[axis] != 3:
        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
    # torch.linalg.cross does not add broadcast dimensions itself, so make
    # the operands the same shape first.
    x1, x2 = torch.broadcast_tensors(x1, x2)
    return torch.linalg.cross(x1, x2, dim=axis)
705711

706712
def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:

array_api_compat/torch/linalg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
outer = _torch.outer
1212

1313
from ._aliases import ( # noqa: E402
14+
cross,
1415
matrix_transpose,
1516
solve,
1617
sum,
@@ -24,6 +25,7 @@
2425
__all__ += _torch_linalg_all
2526

2627
__all__ += [
28+
'cross',
2729
"matrix_transpose",
2830
"outer",
2931
"solve",

0 commit comments

Comments (0)