Skip to content

Commit 7267090

Browse files
committed
Update the latest version of the spec
Some linear algebra function stubs are removed because they are now in extensions, and we need to update the generate_stubs script to be able to parse those.
1 parent 6d16833 commit 7267090

File tree

4 files changed

+52
-103
lines changed

4 files changed

+52
-103
lines changed

array_api_tests/function_stubs/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929

3030
__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
3131

32-
from .linear_algebra_functions import cholesky, cross, det, diagonal, dot, eig, eigvalsh, einsum, inv, lstsq, matmul, matrix_power, matrix_rank, norm, outer, pinv, qr, slogdet, solve, svd, trace, transpose
32+
from .linear_algebra_functions import einsum, matmul, tensordot, transpose, vecdot
3333

34-
__all__ += ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']
34+
__all__ += ['einsum', 'matmul', 'tensordot', 'transpose', 'vecdot']
3535

3636
from .manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack
3737

array_api_tests/function_stubs/array_object.py

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ def __abs__(self: array, /) -> array:
1919
"""
2020
pass
2121

22-
def __add__(self: array, other: array, /) -> array:
22+
def __add__(self: array, other: Union[int, float, array], /) -> array:
2323
"""
2424
Note: __add__ is a method of the array object.
2525
"""
2626
pass
2727

28-
def __and__(self: array, other: array, /) -> array:
28+
def __and__(self: array, other: Union[int, bool, array], /) -> array:
2929
"""
3030
Note: __and__ is a method of the array object.
3131
"""
@@ -55,7 +55,7 @@ def __dlpack_device__(self: array, /) -> Tuple[IntEnum, int]:
5555
"""
5656
pass
5757

58-
def __eq__(self: array, other: array, /) -> array:
58+
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
5959
"""
6060
Note: __eq__ is a method of the array object.
6161
"""
@@ -67,13 +67,13 @@ def __float__(self: array, /) -> float:
6767
"""
6868
pass
6969

70-
def __floordiv__(self: array, other: array, /) -> array:
70+
def __floordiv__(self: array, other: Union[int, float, array], /) -> array:
7171
"""
7272
Note: __floordiv__ is a method of the array object.
7373
"""
7474
pass
7575

76-
def __ge__(self: array, other: array, /) -> array:
76+
def __ge__(self: array, other: Union[int, float, array], /) -> array:
7777
"""
7878
Note: __ge__ is a method of the array object.
7979
"""
@@ -85,7 +85,7 @@ def __getitem__(self: array, key: Union[int, slice, ellipsis, Tuple[Union[int, s
8585
"""
8686
pass
8787

88-
def __gt__(self: array, other: array, /) -> array:
88+
def __gt__(self: array, other: Union[int, float, array], /) -> array:
8989
"""
9090
Note: __gt__ is a method of the array object.
9191
"""
@@ -103,7 +103,7 @@ def __invert__(self: array, /) -> array:
103103
"""
104104
pass
105105

106-
def __le__(self: array, other: array, /) -> array:
106+
def __le__(self: array, other: Union[int, float, array], /) -> array:
107107
"""
108108
Note: __le__ is a method of the array object.
109109
"""
@@ -115,13 +115,13 @@ def __len__(self, /):
115115
"""
116116
pass
117117

118-
def __lshift__(self: array, other: array, /) -> array:
118+
def __lshift__(self: array, other: Union[int, array], /) -> array:
119119
"""
120120
Note: __lshift__ is a method of the array object.
121121
"""
122122
pass
123123

124-
def __lt__(self: array, other: array, /) -> array:
124+
def __lt__(self: array, other: Union[int, float, array], /) -> array:
125125
"""
126126
Note: __lt__ is a method of the array object.
127127
"""
@@ -133,19 +133,19 @@ def __matmul__(self: array, other: array, /) -> array:
133133
"""
134134
pass
135135

136-
def __mod__(self: array, other: array, /) -> array:
136+
def __mod__(self: array, other: Union[int, float, array], /) -> array:
137137
"""
138138
Note: __mod__ is a method of the array object.
139139
"""
140140
pass
141141

142-
def __mul__(self: array, other: array, /) -> array:
142+
def __mul__(self: array, other: Union[int, float, array], /) -> array:
143143
"""
144144
Note: __mul__ is a method of the array object.
145145
"""
146146
pass
147147

148-
def __ne__(self: array, other: array, /) -> array:
148+
def __ne__(self: array, other: Union[int, float, bool, array], /) -> array:
149149
"""
150150
Note: __ne__ is a method of the array object.
151151
"""
@@ -157,7 +157,7 @@ def __neg__(self: array, /) -> array:
157157
"""
158158
pass
159159

160-
def __or__(self: array, other: array, /) -> array:
160+
def __or__(self: array, other: Union[int, bool, array], /) -> array:
161161
"""
162162
Note: __or__ is a method of the array object.
163163
"""
@@ -169,13 +169,13 @@ def __pos__(self: array, /) -> array:
169169
"""
170170
pass
171171

172-
def __pow__(self: array, other: array, /) -> array:
172+
def __pow__(self: array, other: Union[int, float, array], /) -> array:
173173
"""
174174
Note: __pow__ is a method of the array object.
175175
"""
176176
pass
177177

178-
def __rshift__(self: array, other: array, /) -> array:
178+
def __rshift__(self: array, other: Union[int, array], /) -> array:
179179
"""
180180
Note: __rshift__ is a method of the array object.
181181
"""
@@ -187,67 +187,67 @@ def __setitem__(self, key, value, /):
187187
"""
188188
pass
189189

190-
def __sub__(self: array, other: array, /) -> array:
190+
def __sub__(self: array, other: Union[int, float, array], /) -> array:
191191
"""
192192
Note: __sub__ is a method of the array object.
193193
"""
194194
pass
195195

196-
def __truediv__(self: array, other: array, /) -> array:
196+
def __truediv__(self: array, other: Union[int, float, array], /) -> array:
197197
"""
198198
Note: __truediv__ is a method of the array object.
199199
"""
200200
pass
201201

202-
def __xor__(self: array, other: array, /) -> array:
202+
def __xor__(self: array, other: Union[int, bool, array], /) -> array:
203203
"""
204204
Note: __xor__ is a method of the array object.
205205
"""
206206
pass
207207

208-
def __iadd__(self: array, other: array, /) -> array:
208+
def __iadd__(self: array, other: Union[int, float, array], /) -> array:
209209
"""
210210
Note: __iadd__ is a method of the array object.
211211
"""
212212
pass
213213

214-
def __radd__(self: array, other: array, /) -> array:
214+
def __radd__(self: array, other: Union[int, float, array], /) -> array:
215215
"""
216216
Note: __radd__ is a method of the array object.
217217
"""
218218
pass
219219

220-
def __iand__(self: array, other: array, /) -> array:
220+
def __iand__(self: array, other: Union[int, bool, array], /) -> array:
221221
"""
222222
Note: __iand__ is a method of the array object.
223223
"""
224224
pass
225225

226-
def __rand__(self: array, other: array, /) -> array:
226+
def __rand__(self: array, other: Union[int, bool, array], /) -> array:
227227
"""
228228
Note: __rand__ is a method of the array object.
229229
"""
230230
pass
231231

232-
def __ifloordiv__(self: array, other: array, /) -> array:
232+
def __ifloordiv__(self: array, other: Union[int, float, array], /) -> array:
233233
"""
234234
Note: __ifloordiv__ is a method of the array object.
235235
"""
236236
pass
237237

238-
def __rfloordiv__(self: array, other: array, /) -> array:
238+
def __rfloordiv__(self: array, other: Union[int, float, array], /) -> array:
239239
"""
240240
Note: __rfloordiv__ is a method of the array object.
241241
"""
242242
pass
243243

244-
def __ilshift__(self: array, other: array, /) -> array:
244+
def __ilshift__(self: array, other: Union[int, array], /) -> array:
245245
"""
246246
Note: __ilshift__ is a method of the array object.
247247
"""
248248
pass
249249

250-
def __rlshift__(self: array, other: array, /) -> array:
250+
def __rlshift__(self: array, other: Union[int, array], /) -> array:
251251
"""
252252
Note: __rlshift__ is a method of the array object.
253253
"""
@@ -265,97 +265,97 @@ def __rmatmul__(self: array, other: array, /) -> array:
265265
"""
266266
pass
267267

268-
def __imod__(self: array, other: array, /) -> array:
268+
def __imod__(self: array, other: Union[int, float, array], /) -> array:
269269
"""
270270
Note: __imod__ is a method of the array object.
271271
"""
272272
pass
273273

274-
def __rmod__(self: array, other: array, /) -> array:
274+
def __rmod__(self: array, other: Union[int, float, array], /) -> array:
275275
"""
276276
Note: __rmod__ is a method of the array object.
277277
"""
278278
pass
279279

280-
def __imul__(self: array, other: array, /) -> array:
280+
def __imul__(self: array, other: Union[int, float, array], /) -> array:
281281
"""
282282
Note: __imul__ is a method of the array object.
283283
"""
284284
pass
285285

286-
def __rmul__(self: array, other: array, /) -> array:
286+
def __rmul__(self: array, other: Union[int, float, array], /) -> array:
287287
"""
288288
Note: __rmul__ is a method of the array object.
289289
"""
290290
pass
291291

292-
def __ior__(self: array, other: array, /) -> array:
292+
def __ior__(self: array, other: Union[int, bool, array], /) -> array:
293293
"""
294294
Note: __ior__ is a method of the array object.
295295
"""
296296
pass
297297

298-
def __ror__(self: array, other: array, /) -> array:
298+
def __ror__(self: array, other: Union[int, bool, array], /) -> array:
299299
"""
300300
Note: __ror__ is a method of the array object.
301301
"""
302302
pass
303303

304-
def __ipow__(self: array, other: array, /) -> array:
304+
def __ipow__(self: array, other: Union[int, float, array], /) -> array:
305305
"""
306306
Note: __ipow__ is a method of the array object.
307307
"""
308308
pass
309309

310-
def __rpow__(self: array, other: array, /) -> array:
310+
def __rpow__(self: array, other: Union[int, float, array], /) -> array:
311311
"""
312312
Note: __rpow__ is a method of the array object.
313313
"""
314314
pass
315315

316-
def __irshift__(self: array, other: array, /) -> array:
316+
def __irshift__(self: array, other: Union[int, array], /) -> array:
317317
"""
318318
Note: __irshift__ is a method of the array object.
319319
"""
320320
pass
321321

322-
def __rrshift__(self: array, other: array, /) -> array:
322+
def __rrshift__(self: array, other: Union[int, array], /) -> array:
323323
"""
324324
Note: __rrshift__ is a method of the array object.
325325
"""
326326
pass
327327

328-
def __isub__(self: array, other: array, /) -> array:
328+
def __isub__(self: array, other: Union[int, float, array], /) -> array:
329329
"""
330330
Note: __isub__ is a method of the array object.
331331
"""
332332
pass
333333

334-
def __rsub__(self: array, other: array, /) -> array:
334+
def __rsub__(self: array, other: Union[int, float, array], /) -> array:
335335
"""
336336
Note: __rsub__ is a method of the array object.
337337
"""
338338
pass
339339

340-
def __itruediv__(self: array, other: array, /) -> array:
340+
def __itruediv__(self: array, other: Union[int, float, array], /) -> array:
341341
"""
342342
Note: __itruediv__ is a method of the array object.
343343
"""
344344
pass
345345

346-
def __rtruediv__(self: array, other: array, /) -> array:
346+
def __rtruediv__(self: array, other: Union[int, float, array], /) -> array:
347347
"""
348348
Note: __rtruediv__ is a method of the array object.
349349
"""
350350
pass
351351

352-
def __ixor__(self: array, other: array, /) -> array:
352+
def __ixor__(self: array, other: Union[int, bool, array], /) -> array:
353353
"""
354354
Note: __ixor__ is a method of the array object.
355355
"""
356356
pass
357357

358-
def __rxor__(self: array, other: array, /) -> array:
358+
def __rxor__(self: array, other: Union[int, bool, array], /) -> array:
359359
"""
360360
Note: __rxor__ is a method of the array object.
361361
"""

0 commit comments

Comments
 (0)