Skip to content

Commit 838f7f4

Browse files
authored
Merge pull request #82 from asmeurer/2024-draft
Add preliminary support for the draft 2024.12 version of the standard
2 parents 6afcfe1 + 61b3c90 commit 838f7f4

10 files changed

+274
-134
lines changed

array_api_strict/__init__.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,12 @@
172172
minimum,
173173
multiply,
174174
negative,
175+
nextafter,
175176
not_equal,
176177
positive,
177178
pow,
178179
real,
180+
reciprocal,
179181
remainder,
180182
round,
181183
sign,
@@ -240,10 +242,12 @@
240242
"minimum",
241243
"multiply",
242244
"negative",
245+
"nextafter",
243246
"not_equal",
244247
"positive",
245248
"pow",
246249
"real",
250+
"reciprocal",
247251
"remainder",
248252
"round",
249253
"sign",
@@ -258,9 +262,9 @@
258262
"trunc",
259263
]
260264

261-
from ._indexing_functions import take
265+
from ._indexing_functions import take, take_along_axis
262266

263-
__all__ += ["take"]
267+
__all__ += ["take", "take_along_axis"]
264268

265269
from ._info import __array_namespace_info__
266270

@@ -305,9 +309,9 @@
305309

306310
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
307311

308-
from ._utility_functions import all, any
312+
from ._utility_functions import all, any, diff
309313

310-
__all__ += ["all", "any"]
314+
__all__ += ["all", "any", "diff"]
311315

312316
from ._array_object import Device
313317
__all__ += ["Device"]

array_api_strict/_elementwise_functions.py

+25
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,20 @@ def negative(x: Array, /) -> Array:
805805
return Array._new(np.negative(x._array), device=x.device)
806806

807807

808+
@requires_api_version('2024.12')
809+
def nextafter(x1: Array, x2: Array, /) -> Array:
810+
"""
811+
Array API compatible wrapper for :py:func:`np.nextafter <numpy.nextafter>`.
812+
813+
See its docstring for more information.
814+
"""
815+
if x1.device != x2.device:
816+
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
817+
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
818+
raise TypeError("Only real floating-point dtypes are allowed in nextafter")
819+
x1, x2 = Array._normalize_two_args(x1, x2)
820+
return Array._new(np.nextafter(x1._array, x2._array), device=x1.device)
821+
808822
def not_equal(x1: Array, x2: Array, /) -> Array:
809823
"""
810824
Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.
@@ -858,6 +872,17 @@ def real(x: Array, /) -> Array:
858872
return Array._new(np.real(x._array), device=x.device)
859873

860874

875+
@requires_api_version('2024.12')
876+
def reciprocal(x: Array, /) -> Array:
877+
"""
878+
Array API compatible wrapper for :py:func:`np.reciprocal <numpy.reciprocal>`.
879+
880+
See its docstring for more information.
881+
"""
882+
if x.dtype not in _floating_dtypes:
883+
raise TypeError("Only floating-point dtypes are allowed in reciprocal")
884+
return Array._new(np.reciprocal(x._array), device=x.device)
885+
861886
def remainder(x1: Array, x2: Array, /) -> Array:
862887
"""
863888
Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.

array_api_strict/_flags.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
"2023.12",
2525
)
2626

27+
draft_version = "2024.12"
28+
2729
API_VERSION = default_version = "2023.12"
2830

2931
BOOLEAN_INDEXING = True
@@ -70,8 +72,8 @@ def set_array_api_strict_flags(
7072
----------
7173
api_version : str, optional
7274
The version of the standard to use. Supported versions are:
73-
``{supported_versions}``. The default version number is
74-
``{default_version!r}``.
75+
``{supported_versions}``, plus the draft version ``{draft_version}``.
76+
The default version number is ``{default_version!r}``.
7577
7678
Note that 2021.12 is supported, but currently gives the same thing as
7779
2022.12 (except that the fft extension will be disabled).
@@ -134,10 +136,12 @@ def set_array_api_strict_flags(
134136
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
135137

136138
if api_version is not None:
137-
if api_version not in supported_versions:
139+
if api_version not in [*supported_versions, draft_version]:
138140
raise ValueError(f"Unsupported standard version {api_version!r}")
139141
if api_version == "2021.12":
140142
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2)
143+
if api_version == draft_version:
144+
warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, some functions may not be fully tested, and behaviors are subject to change before the final standard release.")
141145
API_VERSION = api_version
142146
array_api_strict.__array_api_version__ = API_VERSION
143147

@@ -169,6 +173,7 @@ def set_array_api_strict_flags(
169173
supported_versions=supported_versions,
170174
default_version=default_version,
171175
default_extensions=default_extensions,
176+
draft_version=draft_version,
172177
)
173178

174179
def get_array_api_strict_flags():

array_api_strict/_indexing_functions.py

+12
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._array_object import Array
44
from ._dtypes import _integer_dtypes
5+
from ._flags import requires_api_version
56

67
from typing import TYPE_CHECKING
78

@@ -25,3 +26,14 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
2526
if x.device != indices.device:
2627
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
2728
return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device)
29+
30+
@requires_api_version('2024.12')
31+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
32+
"""
33+
Array API compatible wrapper for :py:func:`np.take_along_axis <numpy.take_along_axis>`.
34+
35+
See its docstring for more information.
36+
"""
37+
if x.device != indices.device:
38+
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
39+
return Array._new(np.take_along_axis(x._array, indices._array, axis), device=x.device)

0 commit comments

Comments
 (0)