Skip to content

Commit c907e8b

Browse files
committed
Fix parsing of function annotations for linalg extension functions
1 parent cc6ac97 commit c907e8b

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

array_api_tests/function_stubs/linalg.py

+21-19
Original file line numberDiff line numberDiff line change
@@ -10,77 +10,79 @@
1010

1111
from __future__ import annotations
1212

13+
from ._types import Literal, Optional, Tuple, Union, array
14+
from .constants import inf
1315

14-
def cholesky(x, /, *, upper=False):
16+
def cholesky(x: array, /, *, upper: bool = False) -> array:
1517
pass
1618

17-
def cross(x1, x2, /, *, axis=-1):
19+
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
1820
pass
1921

20-
def det(x, /):
22+
def det(x: array, /) -> array:
2123
pass
2224

23-
def diagonal(x, /, *, axis1=0, axis2=1, offset=0):
25+
def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
2426
pass
2527

2628
def eig():
2729
pass
2830

29-
def eigh(x, /, *, upper=False):
31+
def eigh(x: array, /, *, upper: bool = False) -> Tuple[array]:
3032
pass
3133

3234
def eigvals():
3335
pass
3436

35-
def eigvalsh(x, /, *, upper=False):
37+
def eigvalsh(x: array, /, *, upper: bool = False) -> array:
3638
pass
3739

3840
def einsum():
3941
pass
4042

41-
def inv(x, /):
43+
def inv(x: array, /) -> array:
4244
pass
4345

44-
def lstsq(x1, x2, /, *, rtol=None):
46+
def lstsq(x1: array, x2: array, /, *, rtol: Optional[Union[float, array]] = None) -> Tuple[array, array, array, array]:
4547
pass
4648

4749
def matmul(x1, x2, /):
4850
pass
4951

50-
def matrix_power(x, n, /):
52+
def matrix_power(x: array, n: int, /) -> array:
5153
pass
5254

53-
def matrix_rank(x, /, *, rtol=None):
55+
def matrix_rank(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
5456
pass
5557

56-
def norm(x, /, *, axis=None, keepdims=False, ord=None):
58+
def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[inf, -inf, 'fro', 'nuc']]] = None) -> array:
5759
pass
5860

59-
def outer(x1, x2, /):
61+
def outer(x1: array, x2: array, /) -> array:
6062
pass
6163

62-
def pinv(x, /, *, rtol=None):
64+
def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
6365
pass
6466

65-
def qr(x, /, *, mode='reduced'):
67+
def qr(x: array, /, *, mode: str = 'reduced') -> Tuple[array, array]:
6668
pass
6769

68-
def slogdet(x, /):
70+
def slogdet(x: array, /) -> Tuple[array, array]:
6971
pass
7072

71-
def solve(x1, x2, /):
73+
def solve(x1: array, x2: array, /) -> array:
7274
pass
7375

74-
def svd(x, /, *, full_matrices=True):
76+
def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]:
7577
pass
7678

7779
def tensordot(x1, x2, /, *, axes=2):
7880
pass
7981

80-
def svdvals(x, /):
82+
def svdvals(x: array, /) -> Union[array, Tuple[array, ...]]:
8183
pass
8284

83-
def trace(x, /, *, axis1=0, axis2=1, offset=0):
85+
def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
8486
pass
8587

8688
def transpose(x, /, *, axes=None):

generate_stubs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from removestar.removestar import fix_code
2121

2222
FUNCTION_HEADER_RE = regex.compile(r'\(function-(.*?)\)')
23-
HEADER_RE = regex.compile(r'\((?:function|method|constant|attribute)-(.*?)\)')
23+
HEADER_RE = regex.compile(r'\((?:function-linalg|function|method|constant|attribute)-(.*?)\)')
2424
FUNCTION_RE = regex.compile(r'\(function-.*\)=\n#+ ?(.*\(.*\))')
2525
METHOD_RE = regex.compile(r'\(method-.*\)=\n#+ ?(.*\(.*\))')
2626
CONSTANT_RE = regex.compile(r'\(constant-.*\)=\n#+ ?(.*)')

0 commit comments

Comments
 (0)