Skip to content

Commit 072d45b

Browse files
committed
Add matmul wrapper to torch linalg
1 parent f4657a8 commit 072d45b

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

array_api_compat/torch/linalg.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
# outer is implemented in torch but aren't in the linalg namespace
2020
from torch import outer
21-
from ._aliases import matrix_transpose, tensordot
21+
# These functions are in both the main and linalg namespaces
22+
from ._aliases import matmul, matrix_transpose, tensordot
2223

2324
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
2425
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
@@ -58,7 +59,7 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> arr
5859
# Use our wrapped sum to make sure it does upcasting correctly
5960
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
6061

61-
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
62+
__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
6263
'vecdot', 'solve']
6364

6465
_all_ignore = ['torch_linalg', 'sum']

0 commit comments

Comments
 (0)