diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 167c7c0f..8792aa2e 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -489,10 +489,7 @@ def tensordot(x1: ndarray,
     return xp.tensordot(x1, x2, axes=axes, **kwargs)
 
 def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
-    ndim = max(x1.ndim, x2.ndim)
-    x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-    x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-    if x1_shape[axis] != x2_shape[axis]:
+    if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")
 
     if hasattr(xp, 'broadcast_tensors'):
@@ -500,9 +497,9 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
     else:
         _broadcast = xp.broadcast_arrays
 
-    x1_, x2_ = _broadcast(x1, x2)
-    x1_ = xp.moveaxis(x1_, axis, -1)
-    x2_ = xp.moveaxis(x2_, axis, -1)
+    x1_ = xp.moveaxis(x1, axis, -1)
+    x2_ = xp.moveaxis(x2, axis, -1)
+    x1_, x2_ = _broadcast(x1_, x2_)
 
     res = x1_[..., None, :] @ x2_[..., None]
     return res[..., 0, 0]
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index 63f0135b..7e7e2415 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -5,7 +5,8 @@
     import torch
     array = torch.Tensor
     from torch import dtype as Dtype
-    from typing import Optional
+    from typing import Optional, Union, Tuple, Literal
+    inf = float('inf')
 
 from ._aliases import _fix_promotion, sum
 
@@ -23,8 +24,16 @@
 
 # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
 # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
+        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
+    if not (x1.shape[axis] == x2.shape[axis] == 3):
+        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
+    x1, x2 = torch.broadcast_tensors(x1, x2)
     return torch_linalg.cross(x1, x2, dim=axis)
 
 def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
@@ -32,19 +41,18 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
 
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
 
+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    if x1.shape[axis] != x2.shape[axis]:
+        raise ValueError("x1 and x2 must have the same size along the given axis")
+
     # torch.linalg.vecdot doesn't support integer dtypes
     if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
         if kwargs:
             raise RuntimeError("vecdot kwargs not supported for integral dtypes")
-        ndim = max(x1.ndim, x2.ndim)
-        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-        if x1_shape[axis] != x2_shape[axis]:
-            raise ValueError("x1 and x2 must have the same size along the given axis")
-        x1_, x2_ = torch.broadcast_tensors(x1, x2)
-        x1_ = torch.moveaxis(x1_, axis, -1)
-        x2_ = torch.moveaxis(x2_, axis, -1)
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
 
         res = x1_[..., None, :] @ x2_[..., None]
         return res[..., 0, 0]
     return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
@@ -59,8 +67,22 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> arr
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
 
-__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
-                        'vecdot', 'solve']
+def vector_norm(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    ord: Union[int, float, Literal[inf, -inf]] = 2,
+    **kwargs,
+) -> array:
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
+    if axis == ():
+        keepdims = True
+    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
+
+__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
+                        'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
 
 _all_ignore = ['torch_linalg', 'sum']
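A note on the vecdot reordering above: moving the contracted axis to the end of each input before broadcasting lets axis refer to each array's own dimensions, which is what the simplified shape check now assumes. A minimal sketch of the corrected order, using standalone NumPy rather than the wrapped namespace (shapes chosen purely for illustration):

import numpy as np

x1 = np.ones((3,))     # contracted axis 0 has size 3
x2 = np.ones((3, 2))   # contracted axis 0 has size 3

# Corrected order: move the contracted axis of each input to the end first,
# then broadcast the remaining (batch) dimensions against each other.
x1_ = np.moveaxis(x1, 0, -1)              # (3,)
x2_ = np.moveaxis(x2, 0, -1)              # (2, 3)
x1_, x2_ = np.broadcast_arrays(x1_, x2_)  # both (2, 3)
res = x1_[..., None, :] @ x2_[..., None]  # (2, 1, 1)
print(res[..., 0, 0].shape)               # (2,)

# The old order broadcast the raw inputs first; (3,) and (3, 2) are not
# broadcast-compatible, so this valid call was rejected with a ValueError.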
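The explicit torch.broadcast_tensors call added to cross works around the limitation noted in the comment (pytorch/pytorch#39656): in the affected torch versions, torch.linalg.cross will not broadcast when a new dimension would have to be added to one input. A hedged sketch of what the wrapper now does (shapes are illustrative):

import torch

a = torch.randn(3)     # needs a new leading dimension to match b
b = torch.randn(4, 3)

# Raw torch.linalg.cross(a, b, dim=-1) can fail here in affected versions.
# The wrapper validates the axis and then broadcasts explicitly:
a_, b_ = torch.broadcast_tensors(a, b)    # both (4, 3)
out = torch.linalg.cross(a_, b_, dim=-1)  # (4, 3)
print(out.shape)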
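On the vector_norm special case: under the array API spec, axis=() reduces over no axes, so the result should equal abs(x) elementwise, whereas the torch versions this targets treat dim=() the same as dim=None and reduce over everything. A small sketch of the distinction (the exact torch behavior with dim=() varies by version):

import torch

x = torch.tensor([[3.0, 4.0]])

spec_expected = torch.abs(x)                  # axis=(): reduce nothing -> shape (1, 2)
full_reduction = torch.linalg.vector_norm(x)  # dim=None: reduce all -> tensor(5.)
# In affected versions, dim=() behaves like the full reduction above, which
# is why the wrapper special-cases axis=().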