@@ -695,12 +695,18 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
695
695
axis = 0
696
696
return torch .index_select (x , axis , indices , ** kwargs )
697
697
698
-
699
-
700
698
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
701
699
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
700
+
701
+ # torch.cross also does not support broadcasting when it would add new
702
+ # dimensions https://github.com/pytorch/pytorch/issues/39656
702
703
def cross (x1 : array , x2 : array , / , * , axis : int = - 1 ) -> array :
703
704
x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
705
+ if not (- builtin_min (x1 .ndim , x2 .ndim ) <= axis < builtin_max (x1 .ndim , x2 .ndim )):
706
+ raise ValueError (f"axis { axis } out of bounds for cross product of arrays with shapes { x1 .shape } and { x2 .shape } " )
707
+ if not (x1 .shape [axis ] == x2 .shape [axis ] == 3 ):
708
+ raise ValueError (f"cross product axis must have size 3, got { x1 .shape [axis ]} and { x2 .shape [axis ]} " )
709
+ x1 , x2 = torch .broadcast_tensors (x1 , x2 )
704
710
return torch .linalg .cross (x1 , x2 , dim = axis )
705
711
706
712
def vecdot_linalg (x1 : array , x2 : array , / , * , axis : int = - 1 , ** kwargs ) -> array :
0 commit comments