Commit 74b7b79

Merge pull request #94 from asmeurer/torch-linalg-fixes
Some fixes for torch.linalg wrappers
2 parents 8657396 + 31c94bb commit 74b7b79

File tree

2 files changed: +37 −18 lines changed


array_api_compat/common/_aliases.py

+4 −7
@@ -489,20 +489,17 @@ def tensordot(x1: ndarray,
     return xp.tensordot(x1, x2, axes=axes, **kwargs)

 def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
-    ndim = max(x1.ndim, x2.ndim)
-    x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-    x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-    if x1_shape[axis] != x2_shape[axis]:
+    if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")

     if hasattr(xp, 'broadcast_tensors'):
         _broadcast = xp.broadcast_tensors
     else:
         _broadcast = xp.broadcast_arrays

-    x1_, x2_ = _broadcast(x1, x2)
-    x1_ = xp.moveaxis(x1_, axis, -1)
-    x2_ = xp.moveaxis(x2_, axis, -1)
+    x1_ = xp.moveaxis(x1, axis, -1)
+    x2_ = xp.moveaxis(x2, axis, -1)
+    x1_, x2_ = _broadcast(x1_, x2_)

     res = x1_[..., None, :] @ x2_[..., None]
     return res[..., 0, 0]
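For reference, the reordered logic reads as a small standalone sketch, here using NumPy to stand in for the wrapped namespace `xp` (the sample shapes are illustrative, not taken from the PR): the contracted axis is validated and moved to the end of each input before broadcasting, so only the batch dimensions are ever broadcast.

import numpy as np

def vecdot_sketch(x1, x2, axis=-1):
    # Same check as the new wrapper: applied to the un-broadcast inputs.
    if x1.shape[axis] != x2.shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")
    # Move the contracted axis last first, then broadcast only the batch dimensions.
    x1_ = np.moveaxis(x1, axis, -1)
    x2_ = np.moveaxis(x2, axis, -1)
    x1_, x2_ = np.broadcast_arrays(x1_, x2_)
    res = x1_[..., None, :] @ x2_[..., None]
    return res[..., 0, 0]

print(vecdot_sketch(np.ones((4, 3)), np.ones((3,))).shape)  # (4,): batch dims still broadcast
# vecdot_sketch(np.ones((1,)), np.ones((3,)))  # ValueError: the contracted axis is never broadcast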

array_api_compat/torch/linalg.py

+33 −11
@@ -5,7 +5,8 @@
 import torch
 array = torch.Tensor
 from torch import dtype as Dtype
-from typing import Optional
+from typing import Optional, Union, Tuple, Literal
+inf = float('inf')

 from ._aliases import _fix_promotion, sum

@@ -23,28 +24,35 @@

 # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
 # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
+        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
+    if not (x1.shape[axis] == x2.shape[axis] == 3):
+        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
+    x1, x2 = torch.broadcast_tensors(x1, x2)
     return torch_linalg.cross(x1, x2, dim=axis)

 def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
     from ._aliases import isdtype

     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    if x1.shape[axis] != x2.shape[axis]:
+        raise ValueError("x1 and x2 must have the same size along the given axis")
+
     # torch.linalg.vecdot doesn't support integer dtypes
     if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
         if kwargs:
             raise RuntimeError("vecdot kwargs not supported for integral dtypes")
-        ndim = max(x1.ndim, x2.ndim)
-        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-        if x1_shape[axis] != x2_shape[axis]:
-            raise ValueError("x1 and x2 must have the same size along the given axis")

-        x1_, x2_ = torch.broadcast_tensors(x1, x2)
-        x1_ = torch.moveaxis(x1_, axis, -1)
-        x2_ = torch.moveaxis(x2_, axis, -1)
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)

         res = x1_[..., None, :] @ x2_[..., None]
         return res[..., 0, 0]
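A rough usage sketch of what the new validation in `cross` enforces, assuming this commit's version of the wrapper (the tensors and shapes below are illustrative, not taken from the PR's tests): broadcasting that would add new dimensions is done explicitly with `torch.broadcast_tensors` before calling `torch.linalg.cross`, and a non-size-3 axis is rejected up front.

import torch
from array_api_compat.torch import linalg

a = torch.ones((2, 3))
b = torch.ones((3,))
# The wrapper broadcasts b up to (2, 3) itself before calling torch.linalg.cross:
print(linalg.cross(a, b, axis=-1).shape)  # torch.Size([2, 3])

# An axis whose size is not 3 raises ValueError in the wrapper instead of deferring to torch:
# linalg.cross(torch.ones((2, 4)), torch.ones((2, 4)), axis=-1)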
@@ -59,8 +67,22 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

-__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
-                        'vecdot', 'solve']
+def vector_norm(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    ord: Union[int, float, Literal[inf, -inf]] = 2,
+    **kwargs,
+) -> array:
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
+    if axis == ():
+        keepdims = True
+    return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
+
+__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
+                        'cross', 'vecdot', 'solve', 'trace', 'vector_norm']

 _all_ignore = ['torch_linalg', 'sum']
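For context on the `axis=()` special case in `vector_norm`: under the standard's reduction semantics an empty axis tuple reduces over no axes, so the result is the norm of each element on its own (its absolute value), whereas `axis=None` reduces over every axis; this is the distinction the comment says `torch.vector_norm` gets wrong. A minimal NumPy sketch of that intended distinction (illustrative only, not the code path the wrapper takes):

import numpy as np

x = np.array([[3.0, -4.0], [0.0, 12.0]])

# axis=None: a single scalar norm over every element.
print(np.sqrt(np.sum(x**2)))            # 13.0

# axis=(): reduce over no axes, i.e. each element's own 2-norm.
print(np.sqrt(np.sum(x**2, axis=())))   # identical to np.abs(x)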
