Some fixes for torch.linalg wrappers #94

Merged: 6 commits, Feb 29, 2024

Changes from all commits
array_api_compat/common/_aliases.py (4 additions, 7 deletions)
@@ -489,20 +489,17 @@ def tensordot(x1: ndarray,
     return xp.tensordot(x1, x2, axes=axes, **kwargs)

 def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
-    ndim = max(x1.ndim, x2.ndim)
-    x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-    x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-    if x1_shape[axis] != x2_shape[axis]:
+    if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")

     if hasattr(xp, 'broadcast_tensors'):
         _broadcast = xp.broadcast_tensors
     else:
         _broadcast = xp.broadcast_arrays

-    x1_, x2_ = _broadcast(x1, x2)
-    x1_ = xp.moveaxis(x1_, axis, -1)
-    x2_ = xp.moveaxis(x2_, axis, -1)
+    x1_ = xp.moveaxis(x1, axis, -1)
+    x2_ = xp.moveaxis(x2, axis, -1)
+    x1_, x2_ = _broadcast(x1_, x2_)

     res = x1_[..., None, :] @ x2_[..., None]
     return res[..., 0, 0]
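To see what the reworked `vecdot` helper in `_aliases.py` does, here is a minimal standalone sketch of the same logic, using NumPy as a stand-in for the generic `xp` namespace (the name `vecdot_sketch` and the example shapes are illustrative, not part of the PR): the size check and `moveaxis` now operate on the original arrays, and broadcasting happens last, so batch dimensions still broadcast while a mismatched contracted axis is rejected.

```python
# Minimal sketch of the reworked vecdot logic, with NumPy standing in for
# the generic `xp` namespace; names and examples here are illustrative only.
import numpy as np

def vecdot_sketch(x1, x2, *, axis=-1):
    # The size check now compares the original arrays' contracted axis...
    if x1.shape[axis] != x2.shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")
    # ...and the axis is moved to the end before broadcasting, not after.
    x1_ = np.moveaxis(x1, axis, -1)
    x2_ = np.moveaxis(x2, axis, -1)
    x1_, x2_ = np.broadcast_arrays(x1_, x2_)
    res = x1_[..., None, :] @ x2_[..., None]
    return res[..., 0, 0]

print(vecdot_sketch(np.ones(3), np.ones((5, 3))).shape)  # (5,): batch dims broadcast

try:
    vecdot_sketch(np.ones((5, 1)), np.ones((5, 3)))       # contracted sizes 1 vs 3
except ValueError as exc:
    print(exc)
```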
array_api_compat/torch/linalg.py (33 additions, 11 deletions)
@@ -5,7 +5,8 @@
 import torch
 array = torch.Tensor
 from torch import dtype as Dtype
-from typing import Optional
+from typing import Optional, Union, Tuple, Literal
+inf = float('inf')

 from ._aliases import _fix_promotion, sum

@@ -23,28 +24,35 @@

 # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
 # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
+        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
+    if not (x1.shape[axis] == x2.shape[axis] == 3):
+        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
+    x1, x2 = torch.broadcast_tensors(x1, x2)
     return torch_linalg.cross(x1, x2, dim=axis)

 def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
     from ._aliases import isdtype

     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    if x1.shape[axis] != x2.shape[axis]:
+        raise ValueError("x1 and x2 must have the same size along the given axis")
+
     # torch.linalg.vecdot doesn't support integer dtypes
     if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
         if kwargs:
             raise RuntimeError("vecdot kwargs not supported for integral dtypes")
-        ndim = max(x1.ndim, x2.ndim)
-        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-        if x1_shape[axis] != x2_shape[axis]:
-            raise ValueError("x1 and x2 must have the same size along the given axis")

-        x1_, x2_ = torch.broadcast_tensors(x1, x2)
-        x1_ = torch.moveaxis(x1_, axis, -1)
-        x2_ = torch.moveaxis(x2_, axis, -1)
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)

         res = x1_[..., None, :] @ x2_[..., None]
         return res[..., 0, 0]
@@ -59,8 +67,22 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

-__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
-                        'vecdot', 'solve']
+def vector_norm(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    ord: Union[int, float, Literal[inf, -inf]] = 2,
+    **kwargs,
+) -> array:
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
+    if axis == ():
+        keepdims = True
+    return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
+
+__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
+                        'cross', 'vecdot', 'solve', 'trace', 'vector_norm']

 _all_ignore = ['torch_linalg', 'sum']

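The new checks in the torch `cross` wrapper can be exercised on their own. The sketch below (the name `checked_cross` and the example shapes are mine, not from the PR) mirrors the added bounds and size-3 validation and the explicit `broadcast_tensors` call before deferring to `torch.linalg.cross`.

```python
# Sketch mirroring the validation added to the torch cross wrapper; the
# function name and example shapes are illustrative, not part of the PR.
import torch

def checked_cross(x1: torch.Tensor, x2: torch.Tensor, *, axis: int = -1) -> torch.Tensor:
    # The axis must be valid for both operands and have length 3 on each.
    if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
        raise ValueError(f"axis {axis} out of bounds for shapes {tuple(x1.shape)} and {tuple(x2.shape)}")
    if not (x1.shape[axis] == x2.shape[axis] == 3):
        raise ValueError("cross product axis must have size 3")
    # Broadcast explicitly, since torch.cross does not broadcast when new
    # dimensions would have to be added (pytorch issue 39656, cited in the diff).
    x1, x2 = torch.broadcast_tensors(x1, x2)
    return torch.linalg.cross(x1, x2, dim=axis)

print(checked_cross(torch.ones(3), torch.ones(4, 3)).shape)  # torch.Size([4, 3])
```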
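Similarly, the guard added to the torch `vecdot` wrapper can be illustrated in isolation. Per the comment in the diff, `torch.linalg.vecdot` would otherwise broadcast a size-1 contracted dimension, so the wrapper now rejects mismatched sizes up front; the helper below (`checked_vecdot`, illustrative only) shows the intended effect for floating-point inputs.

```python
# Sketch of the size guard added to the torch vecdot wrapper; the helper
# name and the example tensors are illustrative, not part of the PR.
import torch

def checked_vecdot(x1: torch.Tensor, x2: torch.Tensor, *, axis: int = -1) -> torch.Tensor:
    # Reject mismatched sizes along the contracted axis, even when
    # broadcasting could make the shapes compatible.
    if x1.shape[axis] != x2.shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")
    return torch.linalg.vecdot(x1, x2, dim=axis)

print(checked_vecdot(torch.ones(5, 3), torch.ones(1, 3)).shape)  # torch.Size([5]): batch dim broadcasts

try:
    checked_vecdot(torch.ones(5, 1), torch.ones(5, 3))            # contracted sizes 1 vs 3
except ValueError as exc:
    print(exc)
```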