We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 45a8e27 commit 93ce826Copy full SHA for 93ce826
array_api_compat/common/_aliases.py
@@ -489,20 +489,17 @@ def tensordot(x1: ndarray,
489
return xp.tensordot(x1, x2, axes=axes, **kwargs)
490
491
def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
492
- ndim = max(x1.ndim, x2.ndim)
493
- x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
494
- x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
495
- if x1_shape[axis] != x2_shape[axis]:
+ if x1.shape[axis] != x2.shape[axis]:
496
raise ValueError("x1 and x2 must have the same size along the given axis")
497
498
if hasattr(xp, 'broadcast_tensors'):
499
_broadcast = xp.broadcast_tensors
500
else:
501
_broadcast = xp.broadcast_arrays
502
503
- x1_, x2_ = _broadcast(x1, x2)
504
- x1_ = xp.moveaxis(x1_, axis, -1)
505
- x2_ = xp.moveaxis(x2_, axis, -1)
+ x1_ = xp.moveaxis(x1, axis, -1)
+ x2_ = xp.moveaxis(x2, axis, -1)
+ x1_, x2_ = _broadcast(x1_, x2_)
506
507
res = x1_[..., None, :] @ x2_[..., None]
508
return res[..., 0, 0]
0 commit comments