1
1
from __future__ import annotations
2
2
3
- from builtins import all as builtin_all
4
- from builtins import any as builtin_any
3
+ from builtins import ( all as builtin_all , any as builtin_any , min as
4
+ builtin_min , max as builtin_max )
5
5
from functools import wraps
6
6
from typing import TYPE_CHECKING
7
7
@@ -712,7 +712,7 @@ def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array
712
712
if isdtype (x1 .dtype , 'integral' ) or isdtype (x2 .dtype , 'integral' ):
713
713
if kwargs :
714
714
raise RuntimeError ("vecdot kwargs not supported for integral dtypes" )
715
- ndim = max (x1 .ndim , x2 .ndim )
715
+ ndim = builtin_max (x1 .ndim , x2 .ndim )
716
716
x1_shape = (1 ,)* (ndim - x1 .ndim ) + tuple (x1 .shape )
717
717
x2_shape = (1 ,)* (ndim - x2 .ndim ) + tuple (x2 .shape )
718
718
if x1_shape [axis ] != x2_shape [axis ]:
@@ -733,4 +733,4 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
733
733
# torch.trace doesn't support the offset argument and doesn't support stacking
734
734
def trace (x : array , / , * , offset : int = 0 , dtype : Optional [Dtype ] = None ) -> array :
735
735
# Use our wrapped sum to make sure it does upcasting correctly
736
- return sum (torch .diagonal (x , offset = offset , dim1 = - 2 , dim2 = - 1 ), axis = - 1 , dtype = dtype )
736
+ return sum (torch .diagonal (x , offset = offset , dim1 = - 2 , dim2 = - 1 ), axis = - 1 , dtype = dtype )
0 commit comments