From 6948a4acd7ed4538951eb66f15c9a71037500915 Mon Sep 17 00:00:00 2001
From: Aaron Meurer
Date: Thu, 22 Feb 2024 16:26:01 -0700
Subject: [PATCH 1/5] Fix torch.linalg.vecdot

---
 array_api_compat/torch/_aliases.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 23cd5219..691b6ddb 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-from builtins import all as builtin_all
-from builtins import any as builtin_any
+from builtins import (all as builtin_all, any as builtin_any, min as
+                      builtin_min, max as builtin_max)
 from functools import wraps
 from typing import TYPE_CHECKING
 
@@ -712,7 +712,7 @@ def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
     if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
         if kwargs:
             raise RuntimeError("vecdot kwargs not supported for integral dtypes")
-        ndim = max(x1.ndim, x2.ndim)
+        ndim = builtin_max(x1.ndim, x2.ndim)
         x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
         x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
         if x1_shape[axis] != x2_shape[axis]:
@@ -733,4 +733,4 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
 # torch.trace doesn't support the offset argument and doesn't support stacking
 def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
-    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
\ No newline at end of file
+    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
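The bug fixed here is name shadowing: `array_api_compat/torch/_aliases.py` defines its own array-API `max` wrapper at module level, so a bare `max(x1.ndim, x2.ndim)` inside the module resolves to that wrapper rather than Python's builtin. A minimal sketch of the failure mode (illustrative, not from the patch; the `max` below is a simplified stand-in for the real wrapper):

```python
import torch

# Simplified stand-in for the module-level array-API wrapper that
# shadows the builtin in _aliases.py.
def max(x, /, *, axis=None, keepdims=False):
    return torch.amax(x)  # (the real wrapper also handles axis/keepdims)

# Inside such a module, this call now hits the wrapper above with two
# positional arguments and raises a TypeError:
# max(2, 3)  # TypeError: max() takes 1 positional argument but 2 were given

# The fix: bind the builtin under an unambiguous name and use that instead.
from builtins import max as builtin_max
assert builtin_max(2, 3) == 3
```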
From a33d9ff123a18d9ee7b651d352b7a63c379746e3 Mon Sep 17 00:00:00 2001
From: Aaron Meurer
Date: Thu, 22 Feb 2024 16:26:17 -0700
Subject: [PATCH 2/5] Fix torch.cross wrapper

It was not properly imported, and also torch.cross does not broadcast
correctly in all cases.
---
 array_api_compat/torch/_aliases.py | 10 ++++++++--
 array_api_compat/torch/linalg.py   |  2 ++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 691b6ddb..5a9bb210 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -695,12 +695,18 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
         axis = 0
     return torch.index_select(x, axis, indices, **kwargs)
 
-
-
 # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
 # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    if not (-builtin_min(x1.ndim, x2.ndim) <= axis < builtin_max(x1.ndim, x2.ndim)):
+        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
+    if not (x1.shape[axis] == x2.shape[axis] == 3):
+        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
+    x1, x2 = torch.broadcast_tensors(x1, x2)
     return torch.linalg.cross(x1, x2, dim=axis)
 
 def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index 160f074b..d6631cf0 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -11,6 +11,7 @@
 outer = _torch.outer
 
 from ._aliases import ( # noqa: E402
+    cross,
     matrix_transpose,
     solve,
     sum,
@@ -24,6 +25,7 @@
 __all__ += _torch_linalg_all
 
 __all__ += [
+    'cross',
     "matrix_transpose",
     "outer",
     "solve",
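A sketch of what the broadcasting fix enables (illustrative, not part of the patch; the shapes are arbitrary). Per the issue linked above, the underlying torch function can broadcast over existing batch dimensions but not when broadcasting has to prepend new ones, so the wrapper broadcasts explicitly first:

```python
import torch
import array_api_compat.torch as xp

a = torch.ones((3,))       # a single vector
b = torch.ones((2, 4, 3))  # a batch of vectors

# torch.linalg.cross(a, b, dim=-1) fails on the PyTorch versions this
# patch targets, because `a` would need two new leading dimensions.
# The wrapper calls torch.broadcast_tensors first, so this works:
out = xp.linalg.cross(a, b, axis=-1)
print(out.shape)  # torch.Size([2, 4, 3])
```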
From 45a8e27b2b66200aabb17e5830b7bf273adc92d8 Mon Sep 17 00:00:00 2001
From: Aaron Meurer
Date: Mon, 26 Feb 2024 16:09:36 -0700
Subject: [PATCH 3/5] Update the torch vecdot wrapper

- torch.vecdot incorrectly allows broadcasting in the contracted
  dimensions.

- The 2023 version of the spec updates the language to require axis to
  apply before broadcasting, not after. The implementation for integer
  arguments is updated to follow this behavior.

  Note that the spec only actually requires axis to be negative, but we
  allow nonnegative axis too if x1 and x2 have the required number of
  dimensions and those dimensions have the same value, which matches the
  numpy gufunc vecdot implementation and also the base torch.vecdot
  implementation as far as I can tell.
---
 array_api_compat/torch/linalg.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index cbccacf3..87205739 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -40,19 +40,18 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
 
+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    if x1.shape[axis] != x2.shape[axis]:
+        raise ValueError("x1 and x2 must have the same size along the given axis")
+
     # torch.linalg.vecdot doesn't support integer dtypes
     if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
         if kwargs:
             raise RuntimeError("vecdot kwargs not supported for integral dtypes")
-        ndim = max(x1.ndim, x2.ndim)
-        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-        if x1_shape[axis] != x2_shape[axis]:
-            raise ValueError("x1 and x2 must have the same size along the given axis")
-
-        x1_, x2_ = torch.broadcast_tensors(x1, x2)
-        x1_ = torch.moveaxis(x1_, axis, -1)
-        x2_ = torch.moveaxis(x2_, axis, -1)
+
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
 
         res = x1_[..., None, :] @ x2_[..., None]
         return res[..., 0, 0]
 

From 93ce82683eae79480c188332c3ee4e9920180487 Mon Sep 17 00:00:00 2001
From: Aaron Meurer
Date: Tue, 27 Feb 2024 13:20:14 -0700
Subject: [PATCH 4/5] Fix numpy vecdot to apply axis before broadcasting

This is changed in the 2023 version of the spec, and matches the new
np.vecdot gufunc.
---
 array_api_compat/common/_aliases.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 167c7c0f..8792aa2e 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -489,10 +489,7 @@ def tensordot(x1: ndarray,
     return xp.tensordot(x1, x2, axes=axes, **kwargs)
 
 def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
-    ndim = max(x1.ndim, x2.ndim)
-    x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-    x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-    if x1_shape[axis] != x2_shape[axis]:
+    if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")
 
     if hasattr(xp, 'broadcast_tensors'):
@@ -500,9 +497,9 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
     else:
         _broadcast = xp.broadcast_arrays
 
-    x1_, x2_ = _broadcast(x1, x2)
-    x1_ = xp.moveaxis(x1_, axis, -1)
-    x2_ = xp.moveaxis(x2_, axis, -1)
+    x1_ = xp.moveaxis(x1, axis, -1)
+    x2_ = xp.moveaxis(x2, axis, -1)
+    x1_, x2_ = _broadcast(x1_, x2_)
 
     res = x1_[..., None, :] @ x2_[..., None]
     return res[..., 0, 0]
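Patches 3 and 4 apply the same reordering to the torch and NumPy wrappers. A sketch of how applying `axis` before broadcasting differs observably from the old order, using shapes where broadcasting first would fail outright (illustrative; this assumes the common wrapper is exposed as `array_api_compat.numpy.linalg.vecdot`):

```python
import numpy as np
import array_api_compat.numpy as xp

x1 = np.ones((3,))
x2 = np.ones((3, 5))

# Under the 2023 spec, axis=0 selects the size-3 axis of each input
# *before* broadcasting; the leftover dimensions then broadcast,
# giving a result of shape (5,).
res = xp.linalg.vecdot(x1, x2, axis=0)
print(res.shape)  # (5,)

# The old order broadcast first, and (3,) does not broadcast against
# (3, 5), so this case used to raise an error.
```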
From 31c94bb8e88c2e404fba185ee976ec9e7c2f7441 Mon Sep 17 00:00:00 2001
From: Aaron Meurer
Date: Tue, 27 Feb 2024 13:29:06 -0700
Subject: [PATCH 5/5] Fix a test failure with torch.vector_norm

Also clean up the torch.linalg __all__
---
 array_api_compat/torch/linalg.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index 87205739..7e7e2415 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -5,7 +5,8 @@
     import torch
     array = torch.Tensor
     from torch import dtype as Dtype
-    from typing import Optional
+    from typing import Optional, Union, Tuple, Literal
+    inf = float('inf')
 
 from ._aliases import _fix_promotion, sum
 
@@ -66,8 +67,22 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
 
-__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
-                        'vecdot', 'solve']
+def vector_norm(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    ord: Union[int, float, Literal[inf, -inf]] = 2,
+    **kwargs,
+) -> array:
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
+    if axis == ():
+        keepdims = True
+    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
+
+__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
+                        'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
 
 _all_ignore = ['torch_linalg', 'sum']
 
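A sketch of the torch quirk this wrapper works around (illustrative; the exact values assume the default `ord=2` and PyTorch's behavior at the time of the patch):

```python
import torch

x = torch.arange(1., 5.)  # tensor([1., 2., 3., 4.])

# torch treats an empty dim tuple like dim=None, i.e. it reduces over
# every axis instead of over no axes:
print(torch.linalg.vector_norm(x, dim=()))    # tensor(5.4772), full reduction
print(torch.linalg.vector_norm(x, dim=None))  # tensor(5.4772)

# Under the array API spec, axis=() means "reduce over no axes", which
# for a vector norm works out to the elementwise absolute value:
print(torch.abs(x))  # tensor([1., 2., 3., 4.])
```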