import torch
array = torch.Tensor
from torch import dtype as Dtype
- from typing import Optional
+ from typing import Optional, Union, Tuple, Literal
+ inf = float('inf')

from ._aliases import _fix_promotion, sum

# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+
+ # torch.cross also does not support broadcasting when it would add new
+ # dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+     if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
+         raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
+     if not (x1.shape[axis] == x2.shape[axis] == 3):
+         raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
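+     # Broadcast explicitly before calling torch, since torch.cross itself will
+     # not broadcast when that would add new dimensions (see the note above).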
+     x1, x2 = torch.broadcast_tensors(x1, x2)
    return torch_linalg.cross(x1, x2, dim=axis)

def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
    from ._aliases import isdtype

    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

+     # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+     if x1.shape[axis] != x2.shape[axis]:
+         raise ValueError("x1 and x2 must have the same size along the given axis")
+
    # torch.linalg.vecdot doesn't support integer dtypes
    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
        if kwargs:
            raise RuntimeError("vecdot kwargs not supported for integral dtypes")
-         ndim = max(x1.ndim, x2.ndim)
-         x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-         x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-         if x1_shape[axis] != x2_shape[axis]:
-             raise ValueError("x1 and x2 must have the same size along the given axis")

-         x1_, x2_ = torch.broadcast_tensors(x1, x2)
-         x1_ = torch.moveaxis(x1_, axis, -1)
-         x2_ = torch.moveaxis(x2_, axis, -1)
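+         # Move the contracted axis to the end of each original array first,
+         # then broadcast the remaining (batch) dimensions.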
+         x1_ = torch.moveaxis(x1, axis, -1)
+         x2_ = torch.moveaxis(x2, axis, -1)
+         x1_, x2_ = torch.broadcast_tensors(x1_, x2_)

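        # With the contracted axis last, a batched (..., 1, n) @ (..., n, 1) matmul
        # yields the dot products in res[..., 0, 0].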
        res = x1_[..., None, :] @ x2_[..., None]
        return res[..., 0, 0]
@@ -59,8 +67,22 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> arr
    # Use our wrapped sum to make sure it does upcasting correctly
    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

- __all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
-                         'vecdot', 'solve']
+ def vector_norm(
+     x: array,
+     /,
+     *,
+     axis: Optional[Union[int, Tuple[int, ...]]] = None,
+     keepdims: bool = False,
+     ord: Union[int, float, Literal[inf, -inf]] = 2,
+     **kwargs,
+ ) -> array:
+     # torch.vector_norm incorrectly treats axis=() the same as axis=None
+     if axis == ():
+         keepdims = True
+     return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
+
+ __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
+                         'cross', 'vecdot', 'solve', 'trace', 'vector_norm']

_all_ignore = ['torch_linalg', 'sum']