Skip to content

Commit 45a8e27

Browse files
committed
Update the torch vecdot wrapper
- torch.vecdot incorrectly allows broadcasting in the contracted dimensions - The 2023 version of the spec updates the language to require axis to apply before broadcasting, not after. The implementation for integer arguments is updated to follow this behavior. Note that the spec only actually requires axis to be negative, but we allow nonnegative axis too if x1 and x2 have the required number of dimensions and those dimensions have same value, which matches the numpy gufunc vecdot implementation and also the base torch.vecdot implementation as far as I can tell.
1 parent e4ec81d commit 45a8e27

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

array_api_compat/torch/linalg.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,18 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
4040

4141
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
4242

43+
# torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
44+
if x1.shape[axis] != x2.shape[axis]:
45+
raise ValueError("x1 and x2 must have the same size along the given axis")
46+
4347
# torch.linalg.vecdot doesn't support integer dtypes
4448
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
4549
if kwargs:
4650
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
47-
ndim = max(x1.ndim, x2.ndim)
48-
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
49-
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
50-
if x1_shape[axis] != x2_shape[axis]:
51-
raise ValueError("x1 and x2 must have the same size along the given axis")
52-
53-
x1_, x2_ = torch.broadcast_tensors(x1, x2)
54-
x1_ = torch.moveaxis(x1_, axis, -1)
55-
x2_ = torch.moveaxis(x2_, axis, -1)
51+
52+
x1_ = torch.moveaxis(x1, axis, -1)
53+
x2_ = torch.moveaxis(x2, axis, -1)
54+
x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
5655

5756
res = x1_[..., None, :] @ x2_[..., None]
5857
return res[..., 0, 0]

0 commit comments

Comments
 (0)