Skip to content

Commit c54d5e0

Browse files
authored
Merge pull request #35 from asmeurer/2023.12
Preliminary 2023.12 support
2 parents f489d51 + 6f8c07f commit c54d5e0

27 files changed

+1394
-232
lines changed

.github/workflows/array-api-tests.yml

+6-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ on: [push, pull_request]
44

55
env:
66
PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200"
7+
API_VERSIONS: "2022.12 2023.12"
78

89
jobs:
910
array-api-tests:
@@ -45,9 +46,9 @@ jobs:
4546
- name: Run the array API testsuite
4647
env:
4748
ARRAY_API_TESTS_MODULE: array_api_strict
48-
# This enables the NEP 50 type promotion behavior (without it a lot of
49-
# tests fail in numpy 1.26 on bad scalar type promotion behavior)
50-
NPY_PROMOTION_STATE: weak
5149
run: |
52-
cd ${GITHUB_WORKSPACE}/array-api-tests
53-
pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS}
50+
# Parameterizing this in the CI matrix is wasteful. Just do a loop here.
51+
for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do
52+
cd ${GITHUB_WORKSPACE}/array-api-tests
53+
pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS}
54+
done

array_api_strict/__init__.py

+47-13
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
1717
"""
1818

19+
__all__ = []
20+
1921
# Warning: __array_api_version__ could change globally with
2022
# set_array_api_strict_flags(). This should always be accessed as an
2123
# attribute, like xp.__array_api_version__, or using
2224
# array_api_strict.get_array_api_strict_flags()['api_version'].
2325
from ._flags import API_VERSION as __array_api_version__
2426

25-
__all__ = ["__array_api_version__"]
27+
__all__ += ["__array_api_version__"]
2628

2729
from ._constants import e, inf, nan, pi, newaxis
2830

@@ -137,7 +139,9 @@
137139
bitwise_right_shift,
138140
bitwise_xor,
139141
ceil,
142+
clip,
140143
conj,
144+
copysign,
141145
cos,
142146
cosh,
143147
divide,
@@ -148,6 +152,7 @@
148152
floor_divide,
149153
greater,
150154
greater_equal,
155+
hypot,
151156
imag,
152157
isfinite,
153158
isinf,
@@ -163,6 +168,8 @@
163168
logical_not,
164169
logical_or,
165170
logical_xor,
171+
maximum,
172+
minimum,
166173
multiply,
167174
negative,
168175
not_equal,
@@ -172,6 +179,7 @@
172179
remainder,
173180
round,
174181
sign,
182+
signbit,
175183
sin,
176184
sinh,
177185
square,
@@ -199,7 +207,9 @@
199207
"bitwise_right_shift",
200208
"bitwise_xor",
201209
"ceil",
210+
"clip",
202211
"conj",
212+
"copysign",
203213
"cos",
204214
"cosh",
205215
"divide",
@@ -210,6 +220,7 @@
210220
"floor_divide",
211221
"greater",
212222
"greater_equal",
223+
"hypot",
213224
"imag",
214225
"isfinite",
215226
"isinf",
@@ -225,6 +236,8 @@
225236
"logical_not",
226237
"logical_or",
227238
"logical_xor",
239+
"maximum",
240+
"minimum",
228241
"multiply",
229242
"negative",
230243
"not_equal",
@@ -234,6 +247,7 @@
234247
"remainder",
235248
"round",
236249
"sign",
250+
"signbit",
237251
"sin",
238252
"sinh",
239253
"square",
@@ -248,35 +262,36 @@
248262

249263
__all__ += ["take"]
250264

251-
# linalg is an extension in the array API spec, which is a sub-namespace. Only
252-
# a subset of functions in it are imported into the top-level namespace.
253-
from . import linalg
265+
from ._info import __array_namespace_info__
254266

255-
__all__ += ["linalg"]
267+
__all__ += [
268+
"__array_namespace_info__",
269+
]
256270

257271
from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
258272

259273
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
260274

261-
from . import fft
262-
__all__ += ["fft"]
263-
264275
from ._manipulation_functions import (
265276
concat,
266277
expand_dims,
267278
flip,
279+
moveaxis,
268280
permute_dims,
281+
repeat,
269282
reshape,
270283
roll,
271284
squeeze,
272285
stack,
286+
tile,
287+
unstack,
273288
)
274289

275-
__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]
290+
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"]
276291

277-
from ._searching_functions import argmax, argmin, nonzero, where
292+
from ._searching_functions import argmax, argmin, nonzero, searchsorted, where
278293

279-
__all__ += ["argmax", "argmin", "nonzero", "where"]
294+
__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"]
280295

281296
from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
282297

@@ -286,9 +301,9 @@
286301

287302
__all__ += ["argsort", "sort"]
288303

289-
from ._statistical_functions import max, mean, min, prod, std, sum, var
304+
from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
290305

291-
__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
306+
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
292307

293308
from ._utility_functions import all, any
294309

@@ -308,3 +323,22 @@
308323
from . import _version
309324
__version__ = _version.get_versions()['version']
310325
del _version
326+
327+
328+
# Extensions can be enabled or disabled dynamically. In order to make
329+
# "array_api_strict.linalg" give an AttributeError when it is disabled, we
330+
# use __getattr__. Note that linalg and fft are dynamically added and removed
331+
# from __all__ in set_array_api_strict_flags.
332+
333+
def __getattr__(name):
334+
if name in ['linalg', 'fft']:
335+
if name in get_array_api_strict_flags()['enabled_extensions']:
336+
if name == 'linalg':
337+
from . import _linalg
338+
return _linalg
339+
elif name == 'fft':
340+
from . import _fft
341+
return _fft
342+
else:
343+
raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict")
344+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

array_api_strict/_array_object.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __repr__(self):
5151

5252
CPU_DEVICE = _cpu_device()
5353

54+
_default = object()
55+
5456
class Array:
5557
"""
5658
n-d array object for the array API namespace.
@@ -437,7 +439,7 @@ def _validate_index(self, key):
437439
"Array API when the array is the sole index."
438440
)
439441
if not get_array_api_strict_flags()['boolean_indexing']:
440-
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict")
442+
raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict")
441443

442444
elif i.dtype in _integer_dtypes and i.ndim != 0:
443445
raise IndexError(
@@ -525,10 +527,34 @@ def __complex__(self: Array, /) -> complex:
525527
res = self._array.__complex__()
526528
return res
527529

528-
def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
530+
def __dlpack__(
531+
self: Array,
532+
/,
533+
*,
534+
stream: Optional[Union[int, Any]] = None,
535+
max_version: Optional[tuple[int, int]] = _default,
536+
dl_device: Optional[tuple[IntEnum, int]] = _default,
537+
copy: Optional[bool] = _default,
538+
) -> PyCapsule:
529539
"""
530540
Performs the operation __dlpack__.
531541
"""
542+
if get_array_api_strict_flags()['api_version'] < '2023.12':
543+
if max_version is not _default:
544+
raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API")
545+
if dl_device is not _default:
546+
raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API")
547+
if copy is not _default:
548+
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")
549+
550+
# Going to wait for upstream numpy support
551+
if max_version not in [_default, None]:
552+
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
553+
if dl_device not in [_default, None]:
554+
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
555+
if copy not in [_default, None]:
556+
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
557+
532558
return self._array.__dlpack__(stream=stream)
533559

534560
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
@@ -1142,7 +1168,7 @@ def device(self) -> Device:
11421168
# Note: mT is new in array API spec (see matrix_transpose)
11431169
@property
11441170
def mT(self) -> Array:
1145-
from .linalg import matrix_transpose
1171+
from ._linear_algebra_functions import matrix_transpose
11461172
return matrix_transpose(self)
11471173

11481174
@property

0 commit comments

Comments
 (0)