From 728c69ad274dfc263f8464bcc320af4e1a7895cb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:23:35 -0700 Subject: [PATCH 01/11] Support setting the version to the draft version of the standard --- array_api_strict/_flags.py | 11 ++++++++--- array_api_strict/tests/test_array_object.py | 6 +++++- array_api_strict/tests/test_flags.py | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index c393ad9..b998f43 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -24,6 +24,8 @@ "2023.12", ) +draft_version = "2024.12" + API_VERSION = default_version = "2023.12" BOOLEAN_INDEXING = True @@ -70,8 +72,8 @@ def set_array_api_strict_flags( ---------- api_version : str, optional The version of the standard to use. Supported versions are: - ``{supported_versions}``. The default version number is - ``{default_version!r}``. + ``{supported_versions}``, plus the draft version ``{draft_version}``. + The default version number is ``{default_version!r}``. Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). @@ -134,10 +136,12 @@ def set_array_api_strict_flags( global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS if api_version is not None: - if api_version not in supported_versions: + if api_version not in [*supported_versions, draft_version]: raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) + if api_version == draft_version: + warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, and behaviors are subject to change before the final standard release.") API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION @@ -169,6 +173,7 @@ def set_array_api_strict_flags( supported_versions=supported_versions, default_version=default_version, default_extensions=default_extensions, + draft_version=draft_version, ) def get_array_api_strict_flags(): diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a9ea26d..5b8dbff 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -440,8 +440,12 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" + with pytest.warns(UserWarning): + assert a.__array_namespace__(api_version="2024.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2024.12" + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2025.12")) def test_iter(): pytest.raises(TypeError, lambda: iter(asarray(3))) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 2603f35..712e464 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -99,6 +99,21 @@ def test_flags_api_version_2023_12(): 'enabled_extensions': ('linalg', 'fft'), } +def test_flags_api_version_2024_12(): + # Make sure setting the version to 2024.12 issues a warning. + with pytest.warns(UserWarning) as record: + set_array_api_strict_flags(api_version='2024.12') + assert len(record) == 1 + assert '2024.12' in str(record[0].message) + assert 'draft' in str(record[0].message) + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2024.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + def test_setting_flags_invalid(): # Test setting flags with invalid values pytest.raises(ValueError, lambda: From 31ceaae44189480ba8d404da7b3c958911d19552 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:47:08 -0700 Subject: [PATCH 02/11] Add preliminary diff() function for 2024.12 --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_utility_functions.py | 23 +++++++++++++++++++++++ array_api_strict/tests/test_flags.py | 26 ++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index ff43660..025133c 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -305,9 +305,9 @@ __all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] -from ._utility_functions import all, any +from ._utility_functions import all, any, diff -__all__ += ["all", "any"] +__all__ += ["all", "any", "diff"] from ._array_object import Device __all__ += ["Device"] diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index 0d44ecb..4cbea68 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._array_object import Array +from ._flags import requires_api_version from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -37,3 +38,25 @@ def any( See its docstring for more information. """ return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)), device=x.device) + +@requires_api_version('2024.12') +def diff( + x: Array, + /, + *, + axis: int = -1, + n: int = 1, + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: + # NumPy does not support prepend=None or append=None + kwargs = dict(axis=axis, n=n) + if prepend is not None: + if prepend.device != x.device: + raise ValueError(f"Arrays from two different devices ({prepend.device} and {x.device}) can not be combined.") + kwargs['prepend'] = prepend._array + if append is not None: + if append.device != x.device: + raise ValueError(f"Arrays from two different devices ({append.device} and {x.device}) can not be combined.") + kwargs['append'] = append._array + return Array._new(np.diff(x._array, **kwargs), device=x.device) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 712e464..7fa6828 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -282,6 +282,10 @@ def test_fft(func_name): 'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0), } +api_version_2024_12_examples = { + 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), +} + @pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) def test_api_version_2023_12(func_name): func = api_version_2023_12_examples[func_name] @@ -300,6 +304,28 @@ def test_api_version_2023_12(func_name): set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) +@pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys()) +def test_api_version_2024_12(func_name): + func = api_version_2024_12_examples[func_name] + + # By default, these functions should error + pytest.raises(RuntimeError, func) + + # In 2022.12 and 2023.12, these functions should error + set_array_api_strict_flags(api_version='2022.12') + pytest.raises(RuntimeError, func) + set_array_api_strict_flags(api_version='2023.12') + pytest.raises(RuntimeError, func) + + # They should not error in 2024.12 + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2024.12') + func() + + # Test the behavior gets updated properly + set_array_api_strict_flags(api_version='2023.12') + pytest.raises(RuntimeError, func) + def test_disabled_extensions(): # Test that xp.extension errors when an extension is disabled, and that # xp.__all__ is updated properly. From 729175f85619e1102a2dd372bf6662ab09df3778 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:47:52 -0700 Subject: [PATCH 03/11] Add warning that functions may not be fully tested --- array_api_strict/_flags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index b998f43..2863e5f 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -141,7 +141,7 @@ def set_array_api_strict_flags( if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) if api_version == draft_version: - warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, and behaviors are subject to change before the final standard release.") + warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, some functions may not be fully tested, and behaviors are subject to change before the final standard release.") API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION From b2e3ecc4eaa62f496f7eeda2ca1a72c20bbb065c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:52:07 -0700 Subject: [PATCH 04/11] Require numeric types in diff --- array_api_strict/_utility_functions.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index 4cbea68..f75f36f 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._flags import requires_api_version +from ._dtypes import _numeric_dtypes from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -49,6 +50,12 @@ def diff( prepend: Optional[Array] = None, append: Optional[Array] = None, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in diff") + + # TODO: The type promotion behavior for prepend and append is not + # currently specified. + # NumPy does not support prepend=None or append=None kwargs = dict(axis=axis, n=n) if prepend is not None: From 1d111b301164023f5f580be4139d1f8047f90d57 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:56:18 -0700 Subject: [PATCH 05/11] Add draft implementation for nextafter --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 14 ++++++++++++++ .../tests/test_elementwise_functions.py | 9 +++++++-- array_api_strict/tests/test_flags.py | 1 + 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 025133c..8e6f9d7 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -172,6 +172,7 @@ minimum, multiply, negative, + nextafter, not_equal, positive, pow, @@ -240,6 +241,7 @@ "minimum", "multiply", "negative", + "nextafter", "not_equal", "positive", "pow", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 8cec86a..8daab5f 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -805,6 +805,20 @@ def negative(x: Array, /) -> Array: return Array._new(np.negative(x._array), device=x.device) +@requires_api_version('2024.12') +def nextafter(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.nextafter `. + + See its docstring for more information. + """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in nextafter") + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.nextafter(x1._array, x2._array), device=x1.device) + def not_equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.not_equal `. diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index de11edf..7aa51b6 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -2,6 +2,8 @@ from numpy.testing import assert_raises +import pytest + from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._dtypes import ( @@ -79,6 +81,7 @@ def nargs(func): "minimum": "real numeric", "multiply": "numeric", "negative": "numeric", + "nextafter": "real floating-point", "not_equal": "all", "positive": "numeric", "pow": "numeric", @@ -126,7 +129,8 @@ def _array_vals(dtypes): yield asarray(1., dtype=d) # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2023.12") + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2024.12") for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] @@ -162,7 +166,8 @@ def _array_vals(): yield asarray(1.0, dtype=d) # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2023.12") + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2024.12") for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 7fa6828..31d9ecd 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -284,6 +284,7 @@ def test_fft(func_name): api_version_2024_12_examples = { 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), + 'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)), } @pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) From d9f43f4fa2160fc67c0e91b177e19ae4580a5420 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:35:55 -0700 Subject: [PATCH 06/11] Add 'max dimensions' to capabilities() for 2024.12 --- array_api_strict/_info.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 3ed7fb2..4927e97 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +import numpy as np + if TYPE_CHECKING: from typing import Optional, Union, Tuple, List from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info @@ -18,9 +20,23 @@ def __array_namespace_info__() -> Info: @requires_api_version('2023.12') def capabilities() -> Capabilities: flags = get_array_api_strict_flags() - return {"boolean indexing": flags['boolean_indexing'], + res = {"boolean indexing": flags['boolean_indexing'], "data-dependent shapes": flags['data_dependent_shapes'], } + if flags['api_version'] >= '2024.12': + # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will + # drop support for NumPy 1 but for now, just compute the number + # directly + for i in range(1, 100): + try: + np.zeros((1,)*i) + except ValueError: + maxdims = i - 1 + break + else: + raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") + res['max dimensions'] = maxdims + return res @requires_api_version('2023.12') def default_device() -> device: From 632e895af7b95c686e32ac9731b3f4bff7bd573c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:38:28 -0700 Subject: [PATCH 07/11] Add max dimensions to the Capabilities typing dict I don't know how to make this depend on API version so for now it's just there always. --- array_api_strict/_typing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index 05a479c..8fdfeda 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -54,7 +54,8 @@ class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... Capabilities = TypedDict( - "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool} + "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool, + "max dimensions": int} ) DefaultDataTypes = TypedDict( From 548f07174a26eef5b5b8501a4cba49412e90dc06 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:43:31 -0700 Subject: [PATCH 08/11] Make __array_namespace_info__ a class This makes it so that it doesn't have a bunch of extra names on it, which it did as a module. --- array_api_strict/_info.py | 247 +++++++++++++-------------- array_api_strict/_typing.py | 3 +- array_api_strict/tests/test_flags.py | 16 +- 3 files changed, 130 insertions(+), 136 deletions(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 4927e97..f288d2e 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -6,143 +6,134 @@ if TYPE_CHECKING: from typing import Optional, Union, Tuple, List - from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info + from ._typing import device, DefaultDataTypes, DataTypes, Capabilities from ._array_object import ALL_DEVICES, CPU_DEVICE from ._flags import get_array_api_strict_flags, requires_api_version from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @requires_api_version('2023.12') -def __array_namespace_info__() -> Info: - import array_api_strict._info - return array_api_strict._info - -@requires_api_version('2023.12') -def capabilities() -> Capabilities: - flags = get_array_api_strict_flags() - res = {"boolean indexing": flags['boolean_indexing'], - "data-dependent shapes": flags['data_dependent_shapes'], - } - if flags['api_version'] >= '2024.12': - # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will - # drop support for NumPy 1 but for now, just compute the number - # directly - for i in range(1, 100): - try: - np.zeros((1,)*i) - except ValueError: - maxdims = i - 1 - break - else: - raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") - res['max dimensions'] = maxdims - return res - -@requires_api_version('2023.12') -def default_device() -> device: - return CPU_DEVICE +class __array_namespace_info__: + @requires_api_version('2023.12') + def capabilities(self) -> Capabilities: + flags = get_array_api_strict_flags() + res = {"boolean indexing": flags['boolean_indexing'], + "data-dependent shapes": flags['data_dependent_shapes'], + } + if flags['api_version'] >= '2024.12': + # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will + # drop support for NumPy 1 but for now, just compute the number + # directly + for i in range(1, 100): + try: + np.zeros((1,)*i) + except ValueError: + maxdims = i - 1 + break + else: + raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") + res['max dimensions'] = maxdims + return res -@requires_api_version('2023.12') -def default_dtypes( - *, - device: Optional[device] = None, -) -> DefaultDataTypes: - return { - "real floating": float64, - "complex floating": complex128, - "integral": int64, - "indexing": int64, - } + @requires_api_version('2023.12') + def default_device(self) -> device: + return CPU_DEVICE -@requires_api_version('2023.12') -def dtypes( - *, - device: Optional[device] = None, - kind: Optional[Union[str, Tuple[str, ...]]] = None, -) -> DataTypes: - if kind is None: + @requires_api_version('2023.12') + def default_dtypes( + self, + *, + device: Optional[device] = None, + ) -> DefaultDataTypes: return { - "bool": bool, - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, + "real floating": float64, + "complex floating": complex128, + "integral": int64, + "indexing": int64, } - if kind == "bool": - return {"bool": bool} - if kind == "signed integer": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - } - if kind == "unsigned integer": - return { - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - } - if kind == "integral": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - } - if kind == "real floating": - return { - "float32": float32, - "float64": float64, - } - if kind == "complex floating": - return { - "complex64": complex64, - "complex128": complex128, - } - if kind == "numeric": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, - } - if isinstance(kind, tuple): - res = {} - for k in kind: - res.update(dtypes(kind=k)) - return res - raise ValueError(f"unsupported kind: {kind!r}") -@requires_api_version('2023.12') -def devices() -> List[device]: - return list(ALL_DEVICES) + @requires_api_version('2023.12') + def dtypes( + self, + *, + device: Optional[device] = None, + kind: Optional[Union[str, Tuple[str, ...]]] = None, + ) -> DataTypes: + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") -__all__ = [ - "capabilities", - "default_device", - "default_dtypes", - "devices", - "dtypes", -] + @requires_api_version('2023.12') + def devices(self) -> List[device]: + return list(ALL_DEVICES) diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index 8fdfeda..f13fdcf 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -29,6 +29,7 @@ from ._array_object import Array, _device from ._dtypes import _DType +from ._info import __array_namespace_info__ _T_co = TypeVar("_T_co", covariant=True) @@ -41,7 +42,7 @@ def __len__(self, /) -> int: ... Dtype = _DType -Info = ModuleType +Info = __array_namespace_info__ if sys.version_info >= (3, 12): from collections.abc import Buffer as SupportsBufferProtocol diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 31d9ecd..b6e544e 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -3,8 +3,7 @@ from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) -from .._info import (capabilities, default_device, default_dtypes, devices, - dtypes) +from .._info import __array_namespace_info__ from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft, ihfft, fftfreq, rfftfreq, fftshift, ifftshift) from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv, @@ -260,14 +259,17 @@ def test_fft(func_name): set_array_api_strict_flags(enabled_extensions=('fft',)) func() +# Test functionality even if the info object is already created +_info = xp.__array_namespace_info__() + api_version_2023_12_examples = { '__array_namespace_info__': lambda: xp.__array_namespace_info__(), # Test these functions directly to ensure they are properly decorated - 'capabilities': capabilities, - 'default_device': default_device, - 'default_dtypes': default_dtypes, - 'devices': devices, - 'dtypes': dtypes, + 'capabilities': _info.capabilities, + 'default_device': _info.default_device, + 'default_dtypes': _info.default_dtypes, + 'devices': _info.devices, + 'dtypes': _info.dtypes, 'clip': lambda: xp.clip(xp.asarray([1, 2, 3]), 1, 2), 'copysign': lambda: xp.copysign(xp.asarray([1., 2., 3.]), xp.asarray([-1., -1., -1.])), 'cumulative_sum': lambda: xp.cumulative_sum(xp.asarray([1, 2, 3])), From 9726bc096abc04c63aa804634025ad24d0aeba75 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:48:10 -0700 Subject: [PATCH 09/11] Add a draft reciprocal function for 2024.12 --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 11 +++++++++++ array_api_strict/tests/test_elementwise_functions.py | 1 + array_api_strict/tests/test_flags.py | 1 + 4 files changed, 15 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 8e6f9d7..c8c2fa6 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -177,6 +177,7 @@ positive, pow, real, + reciprocal, remainder, round, sign, @@ -246,6 +247,7 @@ "positive", "pow", "real", + "reciprocal", "remainder", "round", "sign", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 8daab5f..7c64f67 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -872,6 +872,17 @@ def real(x: Array, /) -> Array: return Array._new(np.real(x._array), device=x.device) +@requires_api_version('2024.12') +def reciprocal(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.reciprocal `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in reciprocal") + return Array._new(np.reciprocal(x._array), device=x.device) + def remainder(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.remainder `. diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 7aa51b6..4e1b9cc 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -86,6 +86,7 @@ def nargs(func): "positive": "numeric", "pow": "numeric", "real": "complex floating-point", + "reciprocal": "floating-point", "remainder": "real numeric", "round": "numeric", "sign": "numeric", diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index b6e544e..43139d1 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -287,6 +287,7 @@ def test_fft(func_name): api_version_2024_12_examples = { 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), 'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)), + 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), } @pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) From 43d60b520da6235933ef568b4c5e27ccc1dedd59 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:56:43 -0700 Subject: [PATCH 10/11] Add draft implementation of take_along_axis for 2024.12 As far as I can tell, NumPy matches the standard specification, except for the fact that NumPy does not set a default value for axis. --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_indexing_functions.py | 12 ++++++++++++ array_api_strict/tests/test_flags.py | 14 ++++++++------ 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index c8c2fa6..98b0e95 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -262,9 +262,9 @@ "trunc", ] -from ._indexing_functions import take +from ._indexing_functions import take, take_along_axis -__all__ += ["take"] +__all__ += ["take", "take_along_axis"] from ._info import __array_namespace_info__ diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index c0f8e26..d7a400e 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._dtypes import _integer_dtypes +from ._flags import requires_api_version from typing import TYPE_CHECKING @@ -25,3 +26,14 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: if x.device != indices.device: raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device) + +@requires_api_version('2024.12') +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + """ + Array API compatible wrapper for :py:func:`np.take_along_axis `. + + See its docstring for more information. + """ + if x.device != indices.device: + raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") + return Array._new(np.take_along_axis(x._array, indices._array, axis), device=x.device) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 43139d1..a69a1ed 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -284,12 +284,6 @@ def test_fft(func_name): 'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0), } -api_version_2024_12_examples = { - 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), - 'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)), - 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), -} - @pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) def test_api_version_2023_12(func_name): func = api_version_2023_12_examples[func_name] @@ -308,6 +302,14 @@ def test_api_version_2023_12(func_name): set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) +api_version_2024_12_examples = { + 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), + 'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)), + 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), + 'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)), + xp.zeros((1, 4), dtype=xp.int64)), +} + @pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys()) def test_api_version_2024_12(func_name): func = api_version_2024_12_examples[func_name] From 61b3c90c8885e1fa6a8ec807022ec2ba357e2e72 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Nov 2024 14:22:40 -0700 Subject: [PATCH 11/11] Fix ruff issues --- array_api_strict/_info.py | 2 +- array_api_strict/_typing.py | 1 - array_api_strict/tests/test_flags.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index f288d2e..a9dbebf 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -130,7 +130,7 @@ def dtypes( if isinstance(kind, tuple): res = {} for k in kind: - res.update(dtypes(kind=k)) + res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index f13fdcf..94c4975 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -21,7 +21,6 @@ from typing import ( Any, - ModuleType, TypedDict, TypeVar, Protocol, diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index a69a1ed..e0b004b 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -3,7 +3,6 @@ from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) -from .._info import __array_namespace_info__ from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft, ihfft, fftfreq, rfftfreq, fftshift, ifftshift) from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv,