Skip to content

Commit 6e3684b

Browse files
committed
Fix function annotation parsing for aliased linear algebra functions
1 parent c907e8b commit 6e3684b

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

array_api_tests/function_stubs/linalg.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from ._types import Literal, Optional, Tuple, Union, array
1414
from .constants import inf
15+
from collections.abc import Sequence
1516

1617
def cholesky(x: array, /, *, upper: bool = False) -> array:
1718
pass
@@ -46,7 +47,7 @@ def inv(x: array, /) -> array:
4647
def lstsq(x1: array, x2: array, /, *, rtol: Optional[Union[float, array]] = None) -> Tuple[array, array, array, array]:
4748
pass
4849

49-
def matmul(x1, x2, /):
50+
def matmul(x1: array, x2: array, /) -> array:
5051
pass
5152

5253
def matrix_power(x: array, n: int, /) -> array:
@@ -76,7 +77,7 @@ def solve(x1: array, x2: array, /) -> array:
7677
def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]:
7778
pass
7879

79-
def tensordot(x1, x2, /, *, axes=2):
80+
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
8081
pass
8182

8283
def svdvals(x: array, /) -> Union[array, Tuple[array, ...]]:
@@ -85,10 +86,10 @@ def svdvals(x: array, /) -> Union[array, Tuple[array, ...]]:
8586
def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
8687
pass
8788

88-
def transpose(x, /, *, axes=None):
89+
def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
8990
pass
9091

91-
def vecdot(x1, x2, /, *, axis=None):
92+
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
9293
pass
9394

9495
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'tensordot', 'svdvals', 'trace', 'transpose', 'vecdot']

generate_stubs.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ATTRIBUTE_RE = regex.compile(r'\(attribute-.*\)=\n#+ ?(.*)')
2828
IN_PLACE_OPERATOR_RE = regex.compile(r'- `.*`. May be implemented via `__i(.*)__`.')
2929
REFLECTED_OPERATOR_RE = regex.compile(r'- `__r(.*)__`')
30+
ALIAS_RE = regex.compile(r'Alias for {ref}`function-(.*)`.')
3031

3132
NAME_RE = regex.compile(r'(.*)\(.*\)')
3233

@@ -120,6 +121,7 @@ def main():
120121
files = sorted([os.path.join(spec_dir, f) for f in os.listdir(spec_dir)]
121122
+ [os.path.join(extensions_dir, f) for f in os.listdir(extensions_dir)])
122123
modules = {}
124+
all_annotations = {}
123125
for file in files:
124126
filename = os.path.basename(file)
125127
with open(file) as f:
@@ -151,7 +153,8 @@ def main():
151153
if not args.quiet:
152154
print(f"Writing {py_path}")
153155

154-
annotations = parse_annotations(text, verbose=not args.quiet)
156+
annotations = parse_annotations(text, all_annotations, verbose=not args.quiet)
157+
all_annotations.update(annotations)
155158

156159
if filename == 'array_object.md':
157160
in_place_operators = IN_PLACE_OPERATOR_RE.findall(text)
@@ -654,7 +657,7 @@ def parse_special_cases(spec_text, verbose=False):
654657
return special_cases
655658

656659
PARAMETER_RE = regex.compile(r"- +\*\*(.*)\*\*: _(.*)_")
657-
def parse_annotations(spec_text, verbose=False):
660+
def parse_annotations(spec_text, all_annotations, verbose=False):
658661
annotations = defaultdict(dict)
659662
in_block = False
660663
is_returns = False
@@ -663,6 +666,14 @@ def parse_annotations(spec_text, verbose=False):
663666
if m:
664667
name = m.group(1)
665668
continue
669+
m = ALIAS_RE.match(line)
670+
if m:
671+
alias_name = m.group(1)
672+
if alias_name not in all_annotations:
673+
print(f"Warning: No annotations for aliased function {name}")
674+
else:
675+
annotations[name] = all_annotations[m.group(1)]
676+
continue
666677
if line == '#### Parameters':
667678
in_block = True
668679
continue

0 commit comments

Comments
 (0)