Skip to content

Commit 93ce826

Browse files
committed
Fix numpy vecdot to apply axis before broadcasting
This is changed in the 2023 version of the spec, and matches the new np.vecdot gufunc.
1 parent 45a8e27 commit 93ce826

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

array_api_compat/common/_aliases.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -489,20 +489,17 @@ def tensordot(x1: ndarray,
489489
return xp.tensordot(x1, x2, axes=axes, **kwargs)
490490

491491
def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
492-
ndim = max(x1.ndim, x2.ndim)
493-
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
494-
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
495-
if x1_shape[axis] != x2_shape[axis]:
492+
if x1.shape[axis] != x2.shape[axis]:
496493
raise ValueError("x1 and x2 must have the same size along the given axis")
497494

498495
if hasattr(xp, 'broadcast_tensors'):
499496
_broadcast = xp.broadcast_tensors
500497
else:
501498
_broadcast = xp.broadcast_arrays
502499

503-
x1_, x2_ = _broadcast(x1, x2)
504-
x1_ = xp.moveaxis(x1_, axis, -1)
505-
x2_ = xp.moveaxis(x2_, axis, -1)
500+
x1_ = xp.moveaxis(x1, axis, -1)
501+
x2_ = xp.moveaxis(x2, axis, -1)
502+
x1_, x2_ = _broadcast(x1_, x2_)
506503

507504
res = x1_[..., None, :] @ x2_[..., None]
508505
return res[..., 0, 0]

0 commit comments

Comments
 (0)