Skip to content

Commit 6948a4a

Browse files
committed
Fix torch.linalg.vecdot
1 parent ab74e4a commit 6948a4a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from builtins import all as builtin_all
4-
from builtins import any as builtin_any
3+
from builtins import (all as builtin_all, any as builtin_any, min as
4+
builtin_min, max as builtin_max)
55
from functools import wraps
66
from typing import TYPE_CHECKING
77

@@ -712,7 +712,7 @@ def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array
712712
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
713713
if kwargs:
714714
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
715-
ndim = max(x1.ndim, x2.ndim)
715+
ndim = builtin_max(x1.ndim, x2.ndim)
716716
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
717717
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
718718
if x1_shape[axis] != x2_shape[axis]:
@@ -733,4 +733,4 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
733733
# torch.trace doesn't support the offset argument and doesn't support stacking
734734
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
735735
# Use our wrapped sum to make sure it does upcasting correctly
736-
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
736+
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

0 commit comments

Comments
 (0)