diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index bfb7dcf..ab7dbb8 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -4,6 +4,7 @@ on: [push, pull_request] env: PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200" + API_VERSIONS: "2022.12 2023.12" jobs: array-api-tests: @@ -45,9 +46,9 @@ jobs: - name: Run the array API testsuite env: ARRAY_API_TESTS_MODULE: array_api_strict - # This enables the NEP 50 type promotion behavior (without it a lot of - # tests fail in numpy 1.26 on bad scalar type promotion behavior) - NPY_PROMOTION_STATE: weak run: | - cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} + # Parameterizing this in the CI matrix is wasteful. Just do a loop here. + for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do + cd ${GITHUB_WORKSPACE}/array-api-tests + pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} + done diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index f4f2b39..8dfa09f 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,13 +16,15 @@ """ +__all__ = [] + # Warning: __array_api_version__ could change globally with # set_array_api_strict_flags(). This should always be accessed as an # attribute, like xp.__array_api_version__, or using # array_api_strict.get_array_api_strict_flags()['api_version']. from ._flags import API_VERSION as __array_api_version__ -__all__ = ["__array_api_version__"] +__all__ += ["__array_api_version__"] from ._constants import e, inf, nan, pi, newaxis @@ -137,7 +139,9 @@ bitwise_right_shift, bitwise_xor, ceil, + clip, conj, + copysign, cos, cosh, divide, @@ -148,6 +152,7 @@ floor_divide, greater, greater_equal, + hypot, imag, isfinite, isinf, @@ -163,6 +168,8 @@ logical_not, logical_or, logical_xor, + maximum, + minimum, multiply, negative, not_equal, @@ -172,6 +179,7 @@ remainder, round, sign, + signbit, sin, sinh, square, @@ -199,7 +207,9 @@ "bitwise_right_shift", "bitwise_xor", "ceil", + "clip", "conj", + "copysign", "cos", "cosh", "divide", @@ -210,6 +220,7 @@ "floor_divide", "greater", "greater_equal", + "hypot", "imag", "isfinite", "isinf", @@ -225,6 +236,8 @@ "logical_not", "logical_or", "logical_xor", + "maximum", + "minimum", "multiply", "negative", "not_equal", @@ -234,6 +247,7 @@ "remainder", "round", "sign", + "signbit", "sin", "sinh", "square", @@ -248,35 +262,36 @@ __all__ += ["take"] -# linalg is an extension in the array API spec, which is a sub-namespace. Only -# a subset of functions in it are imported into the top-level namespace. -from . import linalg +from ._info import __array_namespace_info__ -__all__ += ["linalg"] +__all__ += [ + "__array_namespace_info__", +] from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] -from . import fft -__all__ += ["fft"] - from ._manipulation_functions import ( concat, expand_dims, flip, + moveaxis, permute_dims, + repeat, reshape, roll, squeeze, stack, + tile, + unstack, ) -__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"] +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"] -from ._searching_functions import argmax, argmin, nonzero, where +from ._searching_functions import argmax, argmin, nonzero, searchsorted, where -__all__ += ["argmax", "argmin", "nonzero", "where"] +__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"] from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values @@ -286,9 +301,9 @@ __all__ += ["argsort", "sort"] -from ._statistical_functions import max, mean, min, prod, std, sum, var +from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var -__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"] +__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] from ._utility_functions import all, any @@ -308,3 +323,22 @@ from . import _version __version__ = _version.get_versions()['version'] del _version + + +# Extensions can be enabled or disabled dynamically. In order to make +# "array_api_strict.linalg" give an AttributeError when it is disabled, we +# use __getattr__. Note that linalg and fft are dynamically added and removed +# from __all__ in set_array_api_strict_flags. + +def __getattr__(name): + if name in ['linalg', 'fft']: + if name in get_array_api_strict_flags()['enabled_extensions']: + if name == 'linalg': + from . import _linalg + return _linalg + elif name == 'fft': + from . import _fft + return _fft + else: + raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 18ed327..cc6bd1a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -51,6 +51,8 @@ def __repr__(self): CPU_DEVICE = _cpu_device() +_default = object() + class Array: """ n-d array object for the array API namespace. @@ -437,7 +439,7 @@ def _validate_index(self, key): "Array API when the array is the sole index." ) if not get_array_api_strict_flags()['boolean_indexing']: - raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict") + raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict") elif i.dtype in _integer_dtypes and i.ndim != 0: raise IndexError( @@ -525,10 +527,34 @@ def __complex__(self: Array, /) -> complex: res = self._array.__complex__() return res - def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule: + def __dlpack__( + self: Array, + /, + *, + stream: Optional[Union[int, Any]] = None, + max_version: Optional[tuple[int, int]] = _default, + dl_device: Optional[tuple[IntEnum, int]] = _default, + copy: Optional[bool] = _default, + ) -> PyCapsule: """ Performs the operation __dlpack__. """ + if get_array_api_strict_flags()['api_version'] < '2023.12': + if max_version is not _default: + raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API") + if dl_device is not _default: + raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API") + if copy is not _default: + raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") + + # Going to wait for upstream numpy support + if max_version not in [_default, None]: + raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") + if dl_device not in [_default, None]: + raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") + return self._array.__dlpack__(stream=stream) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: @@ -1142,7 +1168,7 @@ def device(self) -> Device: # Note: mT is new in array API spec (see matrix_transpose) @property def mT(self) -> Array: - from .linalg import matrix_transpose + from ._linear_algebra_functions import matrix_transpose return matrix_transpose(self) @property diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index ad7ec82..67ba67c 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -12,6 +12,7 @@ SupportsBufferProtocol, ) from ._dtypes import _DType, _all_dtypes +from ._flags import get_array_api_strict_flags import numpy as np @@ -19,7 +20,7 @@ def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. if dtype not in (None,) + _all_dtypes: - raise ValueError("dtype must be one of the supported dtypes") + raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}") def _supports_buffer_protocol(obj): try: @@ -28,6 +29,14 @@ def _supports_buffer_protocol(obj): return False return True +def _check_device(device): + # _array_object imports in this file are inside the functions to avoid + # circular imports + from ._array_object import CPU_DEVICE + + if device not in [CPU_DEVICE, None]: + raise ValueError(f"Unsupported device {device!r}") + def asarray( obj: Union[ Array, @@ -48,16 +57,13 @@ def asarray( See its docstring for more information. """ - # _array_object imports in this file are inside the functions to avoid - # circular imports - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) _np_dtype = None if dtype is not None: _np_dtype = dtype._np_dtype - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) if np.__version__[0] < '2': if copy is False: @@ -106,11 +112,11 @@ def arange( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype)) @@ -127,11 +133,11 @@ def empty( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.empty(shape, dtype=dtype)) @@ -145,11 +151,11 @@ def empty_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.empty_like(x._array, dtype=dtype)) @@ -169,19 +175,39 @@ def eye( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) -def from_dlpack(x: object, /) -> Array: +_default = object() + +def from_dlpack( + x: object, + /, + *, + device: Optional[Device] = _default, + copy: Optional[bool] = _default, +) -> Array: from ._array_object import Array + if get_array_api_strict_flags()['api_version'] < '2023.12': + if device is not _default: + raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API") + if copy is not _default: + raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API") + + # Going to wait for upstream numpy support + if device is not _default: + _check_device(device) + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") + return Array._new(np.from_dlpack(x)) @@ -197,11 +223,11 @@ def full( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if isinstance(fill_value, Array) and fill_value.ndim == 0: fill_value = fill_value._array if dtype is not None: @@ -227,11 +253,11 @@ def full_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype res = np.full_like(x._array, fill_value, dtype=dtype) @@ -257,11 +283,11 @@ def linspace( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) @@ -298,11 +324,11 @@ def ones( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.ones(shape, dtype=dtype)) @@ -316,11 +342,11 @@ def ones_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.ones_like(x._array, dtype=dtype)) @@ -365,11 +391,11 @@ def zeros( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.zeros(shape, dtype=dtype)) @@ -383,11 +409,11 @@ def zeros_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.zeros_like(x._array, dtype=dtype)) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 41f70c5..3405710 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._array_object import Array +from ._creation_functions import _check_device from ._dtypes import ( _DType, _all_dtypes, @@ -13,19 +14,30 @@ _numeric_dtypes, _result_type, ) +from ._flags import get_array_api_strict_flags from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import List, Tuple, Union - from ._typing import Dtype + from typing import List, Tuple, Union, Optional + from ._typing import Dtype, Device import numpy as np +# Use to emulate the asarray(device) argument not existing in 2022.12 +_default = object() # Note: astype is a function, not an array method as in NumPy. -def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array: +def astype( + x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = _default +) -> Array: + if device is not _default: + if get_array_api_strict_flags()['api_version'] >= '2023.12': + _check_device(device) + else: + raise TypeError("The device argument to astype requires at least version 2023.12 of the array API") + if not copy and dtype == x.dtype: return x return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy)) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 8b69677..b39bd86 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -12,6 +12,10 @@ _result_type, ) from ._array_object import Array +from ._flags import requires_api_version +from ._creation_functions import asarray + +from typing import Optional, Union import numpy as np @@ -240,6 +244,70 @@ def ceil(x: Array, /) -> Array: return x return Array._new(np.ceil(x._array)) +# WARNING: This function is not yet tested by the array-api-tests test suite. + +# Note: min and max argument names are different and not optional in numpy. +@requires_api_version('2023.12') +def clip( + x: Array, + /, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.clip `. + + See its docstring for more information. + """ + if (x.dtype not in _real_numeric_dtypes + or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes + or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes): + raise TypeError("Only real numeric dtypes are allowed in clip") + if not isinstance(min, (int, float, Array, type(None))): + raise TypeError("min must be an None, int, float, or an array") + if not isinstance(max, (int, float, Array, type(None))): + raise TypeError("max must be an None, int, float, or an array") + + # Mixed dtype kinds is implementation defined + if (x.dtype in _integer_dtypes + and (isinstance(min, float) or + isinstance(min, Array) and min.dtype in _real_floating_dtypes)): + raise TypeError("min must be integral when x is integral") + if (x.dtype in _integer_dtypes + and (isinstance(max, float) or + isinstance(max, Array) and max.dtype in _real_floating_dtypes)): + raise TypeError("max must be integral when x is integral") + if (x.dtype in _real_floating_dtypes + and (isinstance(min, int) or + isinstance(min, Array) and min.dtype in _integer_dtypes)): + raise TypeError("min must be floating-point when x is floating-point") + if (x.dtype in _real_floating_dtypes + and (isinstance(max, int) or + isinstance(max, Array) and max.dtype in _integer_dtypes)): + raise TypeError("max must be floating-point when x is floating-point") + + if min is max is None: + # Note: NumPy disallows min = max = None + return x + + # Normalize to make the below logic simpler + if min is not None: + min = asarray(min)._array + if max is not None: + max = asarray(max)._array + + # min > max is implementation defined + if min is not None and max is not None and np.any(min > max): + raise ValueError("min must be less than or equal to max") + + result = np.clip(x._array, min, max) + # Note: NumPy applies type promotion, but the standard specifies the + # return dtype should be the same as x + if result.dtype != x.dtype._np_dtype: + # TODO: I'm not completely sure this always gives the correct thing + # for integer dtypes. See https://github.com/numpy/numpy/issues/24976 + result = result.astype(x.dtype._np_dtype) + return Array._new(result) def conj(x: Array, /) -> Array: """ @@ -251,6 +319,19 @@ def conj(x: Array, /) -> Array: raise TypeError("Only complex floating-point dtypes are allowed in conj") return Array._new(np.conj(x)) +@requires_api_version('2023.12') +def copysign(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.copysign `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in copysign") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.copysign(x1._array, x2._array)) def cos(x: Array, /) -> Array: """ @@ -377,6 +458,19 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) +@requires_api_version('2023.12') +def hypot(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.hypot `. + + See its docstring for more information. + """ + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in hypot") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.hypot(x1._array, x2._array)) def imag(x: Array, /) -> Array: """ @@ -560,6 +654,35 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_xor(x1._array, x2._array)) +@requires_api_version('2023.12') +def maximum(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.maximum `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in maximum") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + # TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error + # in that case? + return Array._new(np.maximum(x1._array, x2._array)) + +@requires_api_version('2023.12') +def minimum(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.minimum `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in minimum") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.minimum(x1._array, x2._array)) def multiply(x1: Array, x2: Array, /) -> Array: """ @@ -671,6 +794,18 @@ def sign(x: Array, /) -> Array: return Array._new(np.sign(x._array)) +@requires_api_version('2023.12') +def signbit(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.signbit `. + + See its docstring for more information. + """ + if x.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in signbit") + return Array._new(np.signbit(x._array)) + + def sin(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sin `. diff --git a/array_api_strict/fft.py b/array_api_strict/_fft.py similarity index 100% rename from array_api_strict/fft.py rename to array_api_strict/_fft.py diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 3bf5664..f6cef29 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -21,6 +21,7 @@ supported_versions = ( "2021.12", "2022.12", + "2023.12", ) API_VERSION = default_version = "2022.12" @@ -70,6 +71,8 @@ def set_array_api_strict_flags( Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). + 2023.12 support is experimental. Some features in 2023.12 may still be + missing, and it hasn't been fully tested. - `boolean_indexing`: Whether indexing by a boolean array is supported. Note that although boolean array indexing does result in data-dependent @@ -86,9 +89,9 @@ def set_array_api_strict_flags( The functions that make use of data-dependent shapes, and are therefore disabled by setting this flag to False are - - `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`. - - `nonzero` - - `repeat` when the `repeats` argument is an array (requires 2023.12 + - `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`. + - `nonzero()` + - `repeat()` when the `repeats` argument is an array (requires 2023.12 version of the standard) Note that while boolean indexing is also data-dependent, it is @@ -133,7 +136,9 @@ def set_array_api_strict_flags( if api_version not in supported_versions: raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") + warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) + if api_version == "2023.12": + warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.", stacklevel=2) API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION @@ -154,7 +159,11 @@ def set_array_api_strict_flags( ) ENABLED_EXTENSIONS = tuple(enabled_extensions) else: - ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= API_VERSION]) + ENABLED_EXTENSIONS = tuple([ext for ext in ENABLED_EXTENSIONS if extension_versions[ext] <= API_VERSION]) + + array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) | + set(array_api_strict.__all__) - + set(default_extensions)) # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( @@ -172,6 +181,14 @@ def get_array_api_strict_flags(): This function is **not** part of the array API standard. It only exists in array-api-strict. + .. note:: + + The `inspection API + `__ + provides a portable way to access most of this information. However, it + is only present in standard versions starting with 2023.12. The array + API version can be accessed portably using `xp.__array_api_version__`. + Returns ------- dict @@ -280,29 +297,51 @@ def __exit__(self, exc_type, exc_value, traceback): # Private functions +ENVIRONMENT_VARIABLES = [ + "ARRAY_API_STRICT_API_VERSION", + "ARRAY_API_STRICT_BOOLEAN_INDEXING", + "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES", + "ARRAY_API_STRICT_ENABLED_EXTENSIONS", +] + def set_flags_from_environment(): + kwargs = {} if "ARRAY_API_STRICT_API_VERSION" in os.environ: - set_array_api_strict_flags( - api_version=os.environ["ARRAY_API_STRICT_API_VERSION"] - ) + kwargs["api_version"] = os.environ["ARRAY_API_STRICT_API_VERSION"] if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os.environ: - set_array_api_strict_flags( - boolean_indexing=os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true" - ) + kwargs["boolean_indexing"] = os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true" if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: - set_array_api_strict_flags( - data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" - ) + kwargs["data_dependent_shapes"] = os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: - set_array_api_strict_flags( - enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") - ) + enabled_extensions = os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") + if enabled_extensions == [""]: + enabled_extensions = [] + kwargs["enabled_extensions"] = enabled_extensions + + # Called unconditionally because it is needed at first import to add + # linalg and fft to __all__ + set_array_api_strict_flags(**kwargs) set_flags_from_environment() +# Decorators + +def requires_api_version(version): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if version > API_VERSION: + raise RuntimeError( + f"The function {func.__name__} requires API version {version} or later, " + f"but the current API version for array-api-strict is {API_VERSION}" + ) + return func(*args, **kwargs) + return wrapper + return decorator + def requires_data_dependent_shapes(func): @functools.wraps(func) def wrapper(*args, **kwargs): diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py new file mode 100644 index 0000000..ab5447a --- /dev/null +++ b/array_api_strict/_info.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Optional, Union, Tuple, List + from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info + +from ._array_object import CPU_DEVICE +from ._flags import get_array_api_strict_flags, requires_api_version +from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 + +@requires_api_version('2023.12') +def __array_namespace_info__() -> Info: + import array_api_strict._info + return array_api_strict._info + +@requires_api_version('2023.12') +def capabilities() -> Capabilities: + flags = get_array_api_strict_flags() + return {"boolean indexing": flags['boolean_indexing'], + "data-dependent shapes": flags['data_dependent_shapes'], + } + +@requires_api_version('2023.12') +def default_device() -> device: + return CPU_DEVICE + +@requires_api_version('2023.12') +def default_dtypes( + *, + device: Optional[device] = None, +) -> DefaultDataTypes: + return { + "real floating": float64, + "complex floating": complex128, + "integral": int64, + "indexing": int64, + } + +@requires_api_version('2023.12') +def dtypes( + *, + device: Optional[device] = None, + kind: Optional[Union[str, Tuple[str, ...]]] = None, +) -> DataTypes: + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + +@requires_api_version('2023.12') +def devices() -> List[device]: + return [CPU_DEVICE] + +__all__ = [ + "capabilities", + "default_device", + "default_dtypes", + "devices", + "dtypes", +] diff --git a/array_api_strict/linalg.py b/array_api_strict/_linalg.py similarity index 95% rename from array_api_strict/linalg.py rename to array_api_strict/_linalg.py index 1f548f0..bd11aa4 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/_linalg.py @@ -11,7 +11,7 @@ from ._manipulation_functions import reshape from ._elementwise_functions import conj from ._array_object import Array -from ._flags import requires_extension +from ._flags import requires_extension, get_array_api_strict_flags try: from numpy._core.numeric import normalize_axis_tuple @@ -80,6 +80,17 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: # Note: this is different from np.cross(), which allows dimension 2 if x1.shape[axis] != 3: raise ValueError('cross() dimension must equal 3') + + if get_array_api_strict_flags()['api_version'] >= '2023.12': + if axis >= 0: + raise ValueError("axis must be negative in cross") + elif axis < min(-1, -x1.ndim, -x2.ndim): + raise ValueError("axis is out of bounds for x1 and x2") + + # Prior to 2023.12, there was ambiguity in the standard about whether + # positive axis applied before or after broadcasting. NumPy applies + # the axis before broadcasting. Since that behavior is what has always + # been implemented here, we keep it for backwards compatibility. return Array._new(np.cross(x1._array, x2._array, axis=axis)) @requires_extension('linalg') @@ -377,10 +388,11 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr # Note: trace() works the same as sum() and prod() (see # _statistical_functions.py) if dtype is None: - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 + if get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + dtype = np.float64 + elif x.dtype == complex64: + dtype = np.complex128 else: dtype = dtype._np_dtype # Note: trace always operates on the last two axes, whereas np.trace diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index 1ff08d4..dcb654d 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -8,8 +8,8 @@ from __future__ import annotations from ._dtypes import _numeric_dtypes - from ._array_object import Array +from ._flags import get_array_api_strict_flags from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -54,6 +54,19 @@ def matrix_transpose(x: Array, /) -> Array: def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in vecdot') + + if get_array_api_strict_flags()['api_version'] >= '2023.12': + if axis >= 0: + raise ValueError("axis must be negative in vecdot") + elif axis < min(-1, -x1.ndim, -x2.ndim): + raise ValueError("axis is out of bounds for x1 and x2") + + # In versions of the standard prior to 2023.12, vecdot applied axis after + # broadcasting. This is different from applying it before broadcasting + # when axis is nonnegative. The below code keeps this behavior for + # 2022.12, primarily for backwards compatibility. Note that the behavior + # is unambiguous when axis is negative, so the below code should work + # correctly in that case regardless of which version is used. ndim = max(x1.ndim, x2.ndim) x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index af9a3dd..7652028 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -1,7 +1,10 @@ from __future__ import annotations from ._array_object import Array +from ._creation_functions import asarray from ._data_type_functions import result_type +from ._dtypes import _integer_dtypes +from ._flags import requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING @@ -43,6 +46,19 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> """ return Array._new(np.flip(x._array, axis=axis)) +@requires_api_version('2023.12') +def moveaxis( + x: Array, + source: Union[int, Tuple[int, ...]], + destination: Union[int, Tuple[int, ...]], + /, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.moveaxis `. + + See its docstring for more information. + """ + return Array._new(np.moveaxis(x._array, source, destination)) # Note: The function name is different here (see also matrix_transpose). # Unlike transpose(), the axes argument is required. @@ -54,6 +70,31 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: """ return Array._new(np.transpose(x._array, axes)) +@requires_api_version('2023.12') +def repeat( + x: Array, + repeats: Union[int, Array], + /, + *, + axis: Optional[int] = None, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.repeat `. + + See its docstring for more information. + """ + if isinstance(repeats, Array): + data_dependent_shapes = get_array_api_strict_flags()['data_dependent_shapes'] + if not data_dependent_shapes: + raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + if repeats.dtype not in _integer_dtypes: + raise TypeError("The repeats array must have an integer dtype") + elif isinstance(repeats, int): + repeats = asarray(repeats) + else: + raise TypeError("repeats must be an int or array") + + return Array._new(np.repeat(x._array, repeats, axis=axis)) # Note: the optional argument is called 'shape', not 'newshape' def reshape(x: Array, @@ -113,3 +154,28 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> result_type(*arrays) arrays = tuple(a._array for a in arrays) return Array._new(np.stack(arrays, axis=axis)) + + +@requires_api_version('2023.12') +def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.tile `. + + See its docstring for more information. + """ + # Note: NumPy allows repetitions to be an int or array + if not isinstance(repetitions, tuple): + raise TypeError("repetitions must be a tuple") + return Array._new(np.tile(x._array, repetitions)) + +# Note: this function is new +@requires_api_version('2023.12') +def unstack(x: Array, /, *, axis: int = 0) -> Tuple[Array, ...]: + if not (-x.ndim <= axis < x.ndim): + raise ValueError("axis out of range") + + if axis < 0: + axis += x.ndim + + slices = (slice(None),) * axis + return tuple(x[slices + (i, ...)] for i in range(x.shape[axis])) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 1ef2556..7314895 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -2,11 +2,11 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes -from ._flags import requires_data_dependent_shapes +from ._flags import requires_data_dependent_shapes, requires_api_version from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Tuple + from typing import Literal, Optional, Tuple import numpy as np @@ -45,6 +45,28 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: raise ValueError("nonzero is not allowed on 0-dimensional arrays") return tuple(Array._new(i) for i in np.nonzero(x._array)) +@requires_api_version('2023.12') +def searchsorted( + x1: Array, + x2: Array, + /, + *, + side: Literal["left", "right"] = "left", + sorter: Optional[Array] = None, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.searchsorted `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in searchsorted") + sorter = sorter._array if sorter is not None else None + # TODO: The sort order of nans and signed zeros is implementation + # dependent. Should we error/warn if they are present? + + # x1 must be 1-D, but NumPy already requires this. + return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter)) def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index cbe9d0d..39e3736 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -7,6 +7,9 @@ ) from ._array_object import Array from ._dtypes import float32, complex64 +from ._flags import requires_api_version, get_array_api_strict_flags +from ._creation_functions import zeros +from ._manipulation_functions import concat from typing import TYPE_CHECKING @@ -16,6 +19,32 @@ import numpy as np +@requires_api_version('2023.12') +def cumulative_sum( + x: Array, + /, + *, + axis: Optional[int] = None, + dtype: Optional[Dtype] = None, + include_initial: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in cumulative_sum") + dt = x.dtype if dtype is None else dtype + if dtype is not None: + dtype = dtype._np_dtype + + # TODO: The standard is not clear about what should happen when x.ndim == 0. + if axis is None: + if x.ndim > 1: + raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + axis = 0 + # np.cumsum does not support include_initial + if include_initial: + if axis < 0: + axis += x.ndim + x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) + return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype)) def max( x: Array, @@ -63,14 +92,16 @@ def prod( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in prod") - # Note: sum() and prod() always upcast for dtype=None. `np.prod` does that - # for integers, but not for float32 or complex64, so we need to - # special-case it here + if dtype is None: - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 + # Note: In versions prior to 2023.12, sum() and prod() upcast for all + # dtypes when dtype=None. For 2023.12, the behavior is the same as in + # NumPy (only upcast for integral dtypes). + if get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + dtype = np.float64 + elif x.dtype == complex64: + dtype = np.complex128 else: dtype = dtype._np_dtype return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) @@ -100,14 +131,16 @@ def sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sum") - # Note: sum() and prod() always upcast for dtype=None. `np.sum` does that - # for integers, but not for float32 or complex64, so we need to - # special-case it here + if dtype is None: - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 + # Note: In versions prior to 2023.12, sum() and prod() upcast for all + # dtypes when dtype=None. For 2023.12, the behavior is the same as in + # NumPy (only upcast for integral dtypes). + if get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + dtype = np.float64 + elif x.dtype == complex64: + dtype = np.complex128 else: dtype = dtype._np_dtype return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index ce25d4c..eb1b834 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -21,6 +21,8 @@ from typing import ( Any, + ModuleType, + TypedDict, TypeVar, Protocol, ) @@ -39,6 +41,8 @@ def __len__(self, /) -> int: ... Dtype = _DType +Info = ModuleType + if sys.version_info >= (3, 12): from collections.abc import Buffer as SupportsBufferProtocol else: @@ -48,3 +52,37 @@ def __len__(self, /) -> int: ... class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... + +Capabilities = TypedDict( + "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool} +) + +DefaultDataTypes = TypedDict( + "DefaultDataTypes", + { + "real floating": Dtype, + "complex floating": Dtype, + "integral": Dtype, + "indexing": Dtype, + }, +) + +DataTypes = TypedDict( + "DataTypes", + { + "bool": Dtype, + "float32": Dtype, + "float64": Dtype, + "complex64": Dtype, + "complex128": Dtype, + "int8": Dtype, + "int16": Dtype, + "int32": Dtype, + "int64": Dtype, + "uint8": Dtype, + "uint16": Dtype, + "uint32": Dtype, + "uint64": Dtype, + }, + total=False, +) diff --git a/array_api_strict/tests/conftest.py b/array_api_strict/tests/conftest.py index 5000d5d..1a9d507 100644 --- a/array_api_strict/tests/conftest.py +++ b/array_api_strict/tests/conftest.py @@ -1,7 +1,14 @@ -from .._flags import reset_array_api_strict_flags +import os + +from .._flags import reset_array_api_strict_flags, ENVIRONMENT_VARIABLES import pytest +def pytest_configure(config): + for env_var in ENVIRONMENT_VARIABLES: + if env_var in os.environ: + pytest.exit(f"ERROR: {env_var} is set. array-api-strict environment variables must not be set when the tests are run.") + @pytest.fixture(autouse=True) def reset_flags(): reset_array_api_strict_flags() diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 407bff2..b28c747 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -23,6 +23,8 @@ uint64, bool as bool_, ) +from .._flags import set_array_api_strict_flags + import array_api_strict def test_validate_index(): @@ -410,13 +412,46 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2022.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2022.12" + with pytest.warns(UserWarning): + assert a.__array_namespace__(api_version="2023.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2023.12" + with pytest.warns(UserWarning): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) def test_no_iter(): pytest.raises(TypeError, lambda: iter(ones(3))) pytest.raises(TypeError, lambda: iter(ones((3, 3)))) + +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) +def dlpack_2023_12(api_version): + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = asarray([1, 2, 3], dtype=int8) + # Never an error + a.__dlpack__() + + + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=CPU_DEVICE)) + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=None)) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=(1, 0))) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=None)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=False)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=True)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=None)) diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 78d4c80..819afad 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -3,6 +3,8 @@ from numpy.testing import assert_raises import numpy as np +import pytest + from .. import all from .._creation_functions import ( asarray, @@ -10,6 +12,7 @@ empty, empty_like, eye, + from_dlpack, full, full_like, linspace, @@ -21,7 +24,7 @@ ) from .._dtypes import float32, float64 from .._array_object import Array, CPU_DEVICE - +from .._flags import set_array_api_strict_flags def test_asarray_errors(): # Test various protections against incorrect usage @@ -188,3 +191,24 @@ def test_meshgrid_dtype_errors(): meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32)) assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64))) + + +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) +def from_dlpack_2023_12(api_version): + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = asarray([1., 2., 3.], dtype=float64) + # Never an error + capsule = a.__dlpack__() + from_dlpack(capsule) + + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: from_dlpack(capsule, device=CPU_DEVICE)) + pytest.raises(exception, lambda: from_dlpack(capsule, device=None)) + pytest.raises(exception, lambda: from_dlpack(capsule, copy=False)) + pytest.raises(exception, lambda: from_dlpack(capsule, copy=True)) + pytest.raises(exception, lambda: from_dlpack(capsule, copy=None)) diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 60a7f29..40cab55 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -3,38 +3,68 @@ import pytest from numpy.testing import assert_raises -import array_api_strict as xp import numpy as np +from .._creation_functions import asarray +from .._data_type_functions import astype, can_cast, isdtype +from .._dtypes import ( + bool, int8, int16, uint8, float64, +) +from .._flags import set_array_api_strict_flags + + @pytest.mark.parametrize( "from_, to, expected", [ - (xp.int8, xp.int16, True), - (xp.int16, xp.int8, False), - (xp.bool, xp.int8, False), - (xp.asarray(0, dtype=xp.uint8), xp.int8, False), + (int8, int16, True), + (int16, int8, False), + (bool, int8, False), + (asarray(0, dtype=uint8), int8, False), ], ) def test_can_cast(from_, to, expected): """ can_cast() returns correct result """ - assert xp.can_cast(from_, to) == expected + assert can_cast(from_, to) == expected def test_isdtype_strictness(): - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, 64)) - assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8')) + assert_raises(TypeError, lambda: isdtype(float64, 64)) + assert_raises(ValueError, lambda: isdtype(float64, 'f8')) - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),))) + assert_raises(TypeError, lambda: isdtype(float64, (('integral',),))) with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - xp.isdtype(xp.float64, np.object_) + isdtype(float64, np.object_) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None)) + assert_raises(TypeError, lambda: isdtype(float64, None)) with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - xp.isdtype(xp.float64, np.float64) + isdtype(float64, np.float64) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) + + +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) +def astype_device(api_version): + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = asarray([1, 2, 3], dtype=int8) + # Never an error + astype(a, int16) + + # Always an error + astype(a, int16, device="cpu") + + if api_version >= '2023.12': + astype(a, int8, device=None) + astype(a, int8, device=a.device) + else: + pytest.raises(TypeError, lambda: astype(a, int8, device=None)) + pytest.raises(TypeError, lambda: astype(a, int8, device=a.device)) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 1228d0a..90994f3 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,4 +1,4 @@ -from inspect import getfullargspec +from inspect import getfullargspec, getmodule from numpy.testing import assert_raises @@ -10,79 +10,93 @@ _floating_dtypes, _integer_dtypes, ) +from .._flags import set_array_api_strict_flags +import pytest def nargs(func): return len(getfullargspec(func).args) +elementwise_function_input_types = { + "abs": "numeric", + "acos": "floating-point", + "acosh": "floating-point", + "add": "numeric", + "asin": "floating-point", + "asinh": "floating-point", + "atan": "floating-point", + "atan2": "real floating-point", + "atanh": "floating-point", + "bitwise_and": "integer or boolean", + "bitwise_invert": "integer or boolean", + "bitwise_left_shift": "integer", + "bitwise_or": "integer or boolean", + "bitwise_right_shift": "integer", + "bitwise_xor": "integer or boolean", + "ceil": "real numeric", + "clip": "real numeric", + "conj": "complex floating-point", + "copysign": "real floating-point", + "cos": "floating-point", + "cosh": "floating-point", + "divide": "floating-point", + "equal": "all", + "exp": "floating-point", + "expm1": "floating-point", + "floor": "real numeric", + "floor_divide": "real numeric", + "greater": "real numeric", + "greater_equal": "real numeric", + "hypot": "real floating-point", + "imag": "complex floating-point", + "isfinite": "numeric", + "isinf": "numeric", + "isnan": "numeric", + "less": "real numeric", + "less_equal": "real numeric", + "log": "floating-point", + "logaddexp": "real floating-point", + "log10": "floating-point", + "log1p": "floating-point", + "log2": "floating-point", + "logical_and": "boolean", + "logical_not": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "maximum": "real numeric", + "minimum": "real numeric", + "multiply": "numeric", + "negative": "numeric", + "not_equal": "all", + "positive": "numeric", + "pow": "numeric", + "real": "complex floating-point", + "remainder": "real numeric", + "round": "numeric", + "sign": "numeric", + "signbit": "real floating-point", + "sin": "floating-point", + "sinh": "floating-point", + "sqrt": "floating-point", + "square": "numeric", + "subtract": "numeric", + "tan": "floating-point", + "tanh": "floating-point", + "trunc": "real numeric", +} + +def test_missing_functions(): + # Ensure the above dictionary is complete. + import array_api_strict._elementwise_functions as mod + mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod] + assert set(mod_funcs) == set(elementwise_function_input_types) + def test_function_types(): # Test that every function accepts only the required input types. We only # test the negative cases here (error). The positive cases are tested in # the array API test suite. - elementwise_function_input_types = { - "abs": "numeric", - "acos": "floating-point", - "acosh": "floating-point", - "add": "numeric", - "asin": "floating-point", - "asinh": "floating-point", - "atan": "floating-point", - "atan2": "real floating-point", - "atanh": "floating-point", - "bitwise_and": "integer or boolean", - "bitwise_invert": "integer or boolean", - "bitwise_left_shift": "integer", - "bitwise_or": "integer or boolean", - "bitwise_right_shift": "integer", - "bitwise_xor": "integer or boolean", - "ceil": "real numeric", - "conj": "complex floating-point", - "cos": "floating-point", - "cosh": "floating-point", - "divide": "floating-point", - "equal": "all", - "exp": "floating-point", - "expm1": "floating-point", - "floor": "real numeric", - "floor_divide": "real numeric", - "greater": "real numeric", - "greater_equal": "real numeric", - "imag": "complex floating-point", - "isfinite": "numeric", - "isinf": "numeric", - "isnan": "numeric", - "less": "real numeric", - "less_equal": "real numeric", - "log": "floating-point", - "logaddexp": "real floating-point", - "log10": "floating-point", - "log1p": "floating-point", - "log2": "floating-point", - "logical_and": "boolean", - "logical_not": "boolean", - "logical_or": "boolean", - "logical_xor": "boolean", - "multiply": "numeric", - "negative": "numeric", - "not_equal": "all", - "positive": "numeric", - "pow": "numeric", - "real": "complex floating-point", - "remainder": "real numeric", - "round": "numeric", - "sign": "numeric", - "sin": "floating-point", - "sinh": "floating-point", - "sqrt": "floating-point", - "square": "numeric", - "subtract": "numeric", - "tan": "floating-point", - "tanh": "floating-point", - "trunc": "real numeric", - } - def _array_vals(): for d in _integer_dtypes: yield asarray(1, dtype=d) @@ -91,6 +105,10 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) + # Use the latest version of the standard so all functions are included + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2023.12") + for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index dcf4522..86ad8e2 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -1,8 +1,18 @@ +import sys +import subprocess + from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) +from .._info import (capabilities, default_device, default_dtypes, devices, + dtypes) +from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft, + ihfft, fftfreq, rfftfreq, fftshift, ifftshift) +from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv, + matmul, matrix_norm, matrix_power, matrix_rank, matrix_transpose, outer, pinv, + qr, slogdet, solve, svd, svdvals, tensordot, trace, vecdot, vector_norm) from .. import (asarray, unique_all, unique_counts, unique_inverse, - unique_values, nonzero) + unique_values, nonzero, repeat) import array_api_strict as xp @@ -46,17 +56,43 @@ def test_flags(): 'api_version': '2021.12', 'boolean_indexing': True, 'data_dependent_shapes': False, + 'enabled_extensions': (), + } + reset_array_api_strict_flags() + + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2021.12') + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2021.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, 'enabled_extensions': ('linalg',), } + reset_array_api_strict_flags() + + # 2023.12 should issue a warning + with pytest.warns(UserWarning) as record: + set_array_api_strict_flags(api_version='2023.12') + assert len(record) == 1 + assert '2023.12' in str(record[0].message) + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2023.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } # Test setting flags with invalid values pytest.raises(ValueError, lambda: set_array_api_strict_flags(api_version='2020.12')) pytest.raises(ValueError, lambda: set_array_api_strict_flags( enabled_extensions=('linalg', 'fft', 'invalid'))) - pytest.raises(ValueError, lambda: set_array_api_strict_flags( - api_version='2021.12', - enabled_extensions=('linalg', 'fft'))) + with pytest.warns(UserWarning): + pytest.raises(ValueError, lambda: set_array_api_strict_flags( + api_version='2021.12', + enabled_extensions=('linalg', 'fft'))) # Test resetting flags with pytest.warns(UserWarning): @@ -79,12 +115,17 @@ def test_api_version(): assert xp.__array_api_version__ == '2022.12' # Test setting the version - set_array_api_strict_flags(api_version='2021.12') + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2021.12') assert xp.__array_api_version__ == '2021.12' def test_data_dependent_shapes(): + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') # to enable repeat() + a = asarray([0, 0, 1, 2, 2]) mask = asarray([True, False, True, False, True]) + repeats = asarray([1, 1, 2, 2, 2]) # Should not error unique_all(a) @@ -93,7 +134,8 @@ def test_data_dependent_shapes(): unique_values(a) nonzero(a) a[mask] - # TODO: add repeat when it is implemented + repeat(a, repeats) + repeat(a, 2) set_array_api_strict_flags(data_dependent_shapes=False) @@ -102,6 +144,8 @@ def test_data_dependent_shapes(): pytest.raises(RuntimeError, lambda: unique_inverse(a)) pytest.raises(RuntimeError, lambda: unique_values(a)) pytest.raises(RuntimeError, lambda: nonzero(a)) + pytest.raises(RuntimeError, lambda: repeat(a, repeats)) + repeat(a, 2) # Should never error a[mask] # No error (boolean indexing is a separate flag) def test_boolean_indexing(): @@ -116,29 +160,29 @@ def test_boolean_indexing(): pytest.raises(RuntimeError, lambda: a[mask]) linalg_examples = { - 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), - 'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), - 'det': lambda: xp.linalg.det(xp.eye(3)), - 'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)), - 'eigh': lambda: xp.linalg.eigh(xp.eye(3)), - 'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)), - 'inv': lambda: xp.linalg.inv(xp.eye(3)), - 'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)), - 'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)), - 'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2), - 'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)), - 'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)), - 'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), - 'pinv': lambda: xp.linalg.pinv(xp.eye(3)), - 'qr': lambda: xp.linalg.qr(xp.eye(3)), - 'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)), - 'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)), - 'svd': lambda: xp.linalg.svd(xp.eye(3)), - 'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)), - 'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)), - 'trace': lambda: xp.linalg.trace(xp.eye(3)), - 'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), - 'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])), + 'cholesky': lambda: cholesky(xp.eye(3)), + 'cross': lambda: cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), + 'det': lambda: det(xp.eye(3)), + 'diagonal': lambda: diagonal(xp.eye(3)), + 'eigh': lambda: eigh(xp.eye(3)), + 'eigvalsh': lambda: eigvalsh(xp.eye(3)), + 'inv': lambda: inv(xp.eye(3)), + 'matmul': lambda: matmul(xp.eye(3), xp.eye(3)), + 'matrix_norm': lambda: matrix_norm(xp.eye(3)), + 'matrix_power': lambda: matrix_power(xp.eye(3), 2), + 'matrix_rank': lambda: matrix_rank(xp.eye(3)), + 'matrix_transpose': lambda: matrix_transpose(xp.eye(3)), + 'outer': lambda: outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'pinv': lambda: pinv(xp.eye(3)), + 'qr': lambda: qr(xp.eye(3)), + 'slogdet': lambda: slogdet(xp.eye(3)), + 'solve': lambda: solve(xp.eye(3), xp.eye(3)), + 'svd': lambda: svd(xp.eye(3)), + 'svdvals': lambda: svdvals(xp.eye(3)), + 'tensordot': lambda: tensordot(xp.eye(3), xp.eye(3)), + 'trace': lambda: trace(xp.eye(3)), + 'vecdot': lambda: vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'vector_norm': lambda: vector_norm(xp.asarray([1., 2., 3.])), } assert set(linalg_examples) == set(xp.linalg.__all__) @@ -148,9 +192,10 @@ def test_boolean_indexing(): 'matrix_transpose': lambda: xp.matrix_transpose(xp.eye(3)), 'tensordot': lambda: xp.tensordot(xp.eye(3), xp.eye(3)), 'vecdot': lambda: xp.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'mT': lambda: xp.eye(3).mT, } -assert set(linalg_main_namespace_examples) == set(xp.__all__) & set(xp.linalg.__all__) +assert set(linalg_main_namespace_examples) == (set(xp.__all__) & set(xp.linalg.__all__)) | {"mT"} @pytest.mark.parametrize('func_name', linalg_examples.keys()) def test_linalg(func_name): @@ -173,20 +218,20 @@ def test_linalg(func_name): main_namespace_func() fft_examples = { - 'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])), - 'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])), - 'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])), - 'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), - 'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])), - 'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])), - 'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])), - 'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), - 'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])), - 'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])), - 'fftfreq': lambda: xp.fft.fftfreq(4), - 'rfftfreq': lambda: xp.fft.rfftfreq(4), - 'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])), - 'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])), + 'fft': lambda: fft(xp.asarray([0j, 1j, 0j, 0j])), + 'ifft': lambda: ifft(xp.asarray([0j, 1j, 0j, 0j])), + 'fftn': lambda: fftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'ifftn': lambda: ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'rfft': lambda: rfft(xp.asarray([0., 1., 0., 0.])), + 'irfft': lambda: irfft(xp.asarray([0j, 1j, 0j, 0j])), + 'rfftn': lambda: rfftn(xp.asarray([[0., 1.], [0., 0.]])), + 'irfftn': lambda: irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'hfft': lambda: hfft(xp.asarray([0j, 1j, 0j, 0j])), + 'ihfft': lambda: ihfft(xp.asarray([0., 1., 0., 0.])), + 'fftfreq': lambda: fftfreq(4), + 'rfftfreq': lambda: rfftfreq(4), + 'fftshift': lambda: fftshift(xp.asarray([0j, 1j, 0j, 0j])), + 'ifftshift': lambda: ifftshift(xp.asarray([0j, 1j, 0j, 0j])), } assert set(fft_examples) == set(xp.fft.__all__) @@ -203,3 +248,246 @@ def test_fft(func_name): set_array_api_strict_flags(enabled_extensions=('fft',)) func() + +api_version_2023_12_examples = { + '__array_namespace_info__': lambda: xp.__array_namespace_info__(), + # Test these functions directly to ensure they are properly decorated + 'capabilities': capabilities, + 'default_device': default_device, + 'default_dtypes': default_dtypes, + 'devices': devices, + 'dtypes': dtypes, + 'clip': lambda: xp.clip(xp.asarray([1, 2, 3]), 1, 2), + 'copysign': lambda: xp.copysign(xp.asarray([1., 2., 3.]), xp.asarray([-1., -1., -1.])), + 'cumulative_sum': lambda: xp.cumulative_sum(xp.asarray([1, 2, 3])), + 'hypot': lambda: xp.hypot(xp.asarray([3., 4.]), xp.asarray([4., 3.])), + 'maximum': lambda: xp.maximum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])), + 'minimum': lambda: xp.minimum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])), + 'moveaxis': lambda: xp.moveaxis(xp.ones((3, 3)), 0, 1), + 'repeat': lambda: xp.repeat(xp.asarray([1, 2, 3]), 3), + 'searchsorted': lambda: xp.searchsorted(xp.asarray([1, 2, 3]), xp.asarray([0, 1, 2, 3, 4])), + 'signbit': lambda: xp.signbit(xp.asarray([-1., 0., 1.])), + 'tile': lambda: xp.tile(xp.ones((3, 3)), (2, 3)), + 'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0), +} + +@pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) +def test_api_version_2023_12(func_name): + func = api_version_2023_12_examples[func_name] + + # By default, these functions should error + pytest.raises(RuntimeError, func) + + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + func() + + set_array_api_strict_flags(api_version='2022.12') + pytest.raises(RuntimeError, func) + +def test_disabled_extensions(): + # Test that xp.extension errors when an extension is disabled, and that + # xp.__all__ is updated properly. + + # First test that things are correct on the initial import. Since we have + # already called set_array_api_strict_flags many times throughout running + # the tests, we have to test this in a subprocess. + subprocess_tests = [('''\ +import array_api_strict + +array_api_strict.linalg # No error +array_api_strict.fft # No error +assert "linalg" in array_api_strict.__all__ +assert "fft" in array_api_strict.__all__ +assert len(array_api_strict.__all__) == len(set(array_api_strict.__all__)) +''', {}), +# Test that the initial population of __all__ works correctly +('''\ +from array_api_strict import * # No error +linalg # Should have been imported by the previous line +fft +''', {}), +('''\ +from array_api_strict import * # No error +linalg # Should have been imported by the previous line +assert 'fft' not in globals() +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "linalg"}), +('''\ +from array_api_strict import * # No error +fft # Should have been imported by the previous line +assert 'linalg' not in globals() +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "fft"}), +('''\ +from array_api_strict import * # No error +assert 'linalg' not in globals() +assert 'fft' not in globals() +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": ""}), +] + for test, env in subprocess_tests: + try: + subprocess.run([sys.executable, '-c', test], check=True, + capture_output=True, encoding='utf-8', env=env) + except subprocess.CalledProcessError as e: + print(e.stdout, end='') + # Ensure the exception is shown in the output log + raise AssertionError(e.stderr) + + assert 'linalg' in xp.__all__ + assert 'fft' in xp.__all__ + xp.linalg # No error + xp.fft # No error + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' in ns + assert 'fft' in ns + + set_array_api_strict_flags(enabled_extensions=('linalg',)) + assert 'linalg' in xp.__all__ + assert 'fft' not in xp.__all__ + xp.linalg # No error + pytest.raises(AttributeError, lambda: xp.fft) + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' in ns + assert 'fft' not in ns + + set_array_api_strict_flags(enabled_extensions=('fft',)) + assert 'linalg' not in xp.__all__ + assert 'fft' in xp.__all__ + pytest.raises(AttributeError, lambda: xp.linalg) + xp.fft # No error + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' not in ns + assert 'fft' in ns + + set_array_api_strict_flags(enabled_extensions=()) + assert 'linalg' not in xp.__all__ + assert 'fft' not in xp.__all__ + pytest.raises(AttributeError, lambda: xp.linalg) + pytest.raises(AttributeError, lambda: xp.fft) + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' not in ns + assert 'fft' not in ns + + +def test_environment_variables(): + # Test that the environment variables work as expected + subprocess_tests = [ + # ARRAY_API_STRICT_API_VERSION + ('''\ +import array_api_strict as xp +assert xp.__array_api_version__ == '2022.12' + +assert xp.get_array_api_strict_flags()['api_version'] == '2022.12' + +''', {}), + *[ + (f'''\ +import array_api_strict as xp +assert xp.__array_api_version__ == '{version}' + +assert xp.get_array_api_strict_flags()['api_version'] == '{version}' + +if {version} == '2021.12': + assert hasattr(xp, 'linalg') + assert not hasattr(xp, 'fft') + +''', {"ARRAY_API_STRICT_API_VERSION": version}) for version in ('2021.12', '2022.12', '2023.12')], + + # ARRAY_API_STRICT_BOOLEAN_INDEXING + ('''\ +import array_api_strict as xp + +a = xp.ones(3) +mask = xp.asarray([True, False, True]) + +assert xp.all(a[mask] == xp.asarray([1., 1.])) +assert xp.get_array_api_strict_flags()['boolean_indexing'] == True +''', {}), + *[(f'''\ +import array_api_strict as xp + +a = xp.ones(3) +mask = xp.asarray([True, False, True]) + +if {boolean_indexing}: + assert xp.all(a[mask] == xp.asarray([1., 1.])) +else: + try: + a[mask] + except RuntimeError: + pass + else: + assert False + +assert xp.get_array_api_strict_flags()['boolean_indexing'] == {boolean_indexing} +''', {"ARRAY_API_STRICT_BOOLEAN_INDEXING": boolean_indexing}) + for boolean_indexing in ('True', 'False')], + + # ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES + ('''\ +import array_api_strict as xp + +a = xp.ones(3) +xp.unique_all(a) + +assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == True +''', {}), + *[(f'''\ +import array_api_strict as xp + +a = xp.ones(3) +if {data_dependent_shapes}: + xp.unique_all(a) +else: + try: + xp.unique_all(a) + except RuntimeError: + pass + else: + assert False + +assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == {data_dependent_shapes} +''', {"ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES": data_dependent_shapes}) + for data_dependent_shapes in ('True', 'False')], + + # ARRAY_API_STRICT_ENABLED_EXTENSIONS + ('''\ +import array_api_strict as xp +assert hasattr(xp, 'linalg') +assert hasattr(xp, 'fft') + +assert xp.get_array_api_strict_flags()['enabled_extensions'] == ('linalg', 'fft') +''', {}), + *[(f'''\ +import array_api_strict as xp + +assert hasattr(xp, 'linalg') == ('linalg' in {extensions.split(',')}) +assert hasattr(xp, 'fft') == ('fft' in {extensions.split(',')}) + +assert sorted(xp.get_array_api_strict_flags()['enabled_extensions']) == {sorted(set(extensions.split(','))-{''})} +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": extensions}) + for extensions in ('', 'linalg', 'fft', 'linalg,fft')], + ] + + for test, env in subprocess_tests: + try: + subprocess.run([sys.executable, '-c', test], check=True, + capture_output=True, encoding='utf-8', env=env) + except subprocess.CalledProcessError as e: + print(e.stdout, end='') + # Ensure the exception is shown in the output log + raise AssertionError(f"""\ +STDOUT: +{e.stderr} + +STDERR: +{e.stderr} + +TEST: +{test} + +ENV: +{env}""") diff --git a/array_api_strict/tests/test_linalg.py b/array_api_strict/tests/test_linalg.py new file mode 100644 index 0000000..5e6cda2 --- /dev/null +++ b/array_api_strict/tests/test_linalg.py @@ -0,0 +1,133 @@ +import pytest + +from .._flags import set_array_api_strict_flags + +import array_api_strict as xp + +# TODO: Maybe all of these exceptions should be IndexError? + +# Technically this is linear_algebra, not linalg, but it's simpler to keep +# both of these tests together +def test_vecdot_2023_12(): + # Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >= + # 0 behavior (which is primarily kept for backwards compatibility). + + a = xp.ones((2, 3, 4, 5)) + b = xp.ones(( 3, 4, 1)) + + # 2022.12 behavior, which is to apply axis >= 0 after broadcasting + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) + assert xp.linalg.vecdot(a, b, axis=1).shape == (2, 4, 5) + assert xp.linalg.vecdot(a, b, axis=2).shape == (2, 3, 5) + # This is disallowed because the arrays must have the same values before + # broadcasting + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3)) + + # Out-of-bounds axes even after broadcasting + pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=4)) + pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=-5)) + + # negative axis behavior is unambiguous when it's within the bounds of + # both arrays before broadcasting + assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5) + assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5) + + # 2023.12 behavior, which is to only allow axis < 0 and axis >= + # min(x1.ndim, x2.ndim), which is unambiguous + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=1)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=2)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=4)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-5)) + + assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5) + assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5) + +@pytest.mark.parametrize('api_version', ['2021.12', '2022.12', '2023.12']) +def test_cross(api_version): + # This test tests everything that should be the same across all supported + # API versions. + + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = xp.ones((2, 4, 5, 3)) + b = xp.ones(( 4, 1, 3)) + assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3) + + a = xp.ones((2, 4, 3, 5)) + b = xp.ones(( 4, 3, 1)) + assert xp.linalg.cross(a, b, axis=-2).shape == (2, 4, 3, 5) + + # This is disallowed because the axes must equal 3 before broadcasting + a = xp.ones((3, 2, 3, 5)) + b = xp.ones(( 2, 1, 1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4)) + + # Out-of-bounds axes even after broadcasting + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4)) + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5)) + +@pytest.mark.parametrize('api_version', ['2021.12', '2022.12']) +def test_cross_2022_12(api_version): + # Test the 2022.12 axis >= 0 behavior, which is primarily kept for + # backwards compatibility. Note that unlike vecdot, array_api_strict + # cross() never implemented the "after broadcasting" axis behavior, but + # just reused NumPy cross(), which applies axes before broadcasting. + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = xp.ones((3, 2, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5) + + # ambiguous case + a = xp.ones(( 3, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5) + +def test_cross_2023_12(): + # 2023.12 behavior, which is to only allow axis < 0 and axis >= + # min(x1.ndim, x2.ndim), which is unambiguous + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + + a = xp.ones((3, 2, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0)) + + a = xp.ones(( 3, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + pytest.raises(ValueError, lambda: xp. linalg.cross(a, b, axis=0)) + + a = xp.ones((2, 4, 5, 3)) + b = xp.ones(( 4, 1, 3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=2)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4)) + + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4)) + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5)) + + assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3) diff --git a/array_api_strict/tests/test_manipulation_functions.py b/array_api_strict/tests/test_manipulation_functions.py index 70b42f3..9969651 100644 --- a/array_api_strict/tests/test_manipulation_functions.py +++ b/array_api_strict/tests/test_manipulation_functions.py @@ -25,7 +25,7 @@ def test_reshape_copy(): a = asarray(np.ones((2, 3))) b = reshape(a, (3, 2), copy=True) assert not np.shares_memory(a._array, b._array) - + a = asarray(np.ones((2, 3))) b = reshape(a, (3, 2), copy=False) assert np.shares_memory(a._array, b._array) diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py new file mode 100644 index 0000000..61e848c --- /dev/null +++ b/array_api_strict/tests/test_statistical_functions.py @@ -0,0 +1,29 @@ +import pytest + +from .._flags import set_array_api_strict_flags + +import array_api_strict as xp + +@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) +def test_sum_prod_trace_2023_12(func_name): + # sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes + # with dtype=None + if func_name == 'trace': + func = getattr(xp.linalg, func_name) + else: + func = getattr(xp, func_name) + + a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32) + a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64) + a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32) + + assert func(a_real).dtype == xp.float64 + assert func(a_complex).dtype == xp.complex128 + assert func(a_int).dtype == xp.int64 + + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + + assert func(a_real).dtype == xp.float32 + assert func(a_complex).dtype == xp.complex64 + assert func(a_int).dtype == xp.int64 diff --git a/docs/api.rst b/docs/api.rst index 15ce4e9..ed702dc 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -11,6 +11,8 @@ Array API Strict Flags .. currentmodule:: array_api_strict .. autofunction:: get_array_api_strict_flags + +.. _set_array_api_strict_flags: .. autofunction:: set_array_api_strict_flags .. autofunction:: reset_array_api_strict_flags .. autoclass:: ArrayAPIStrictFlags diff --git a/docs/index.md b/docs/index.md index 6e84efa..fc385d4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -15,9 +15,12 @@ libraries. Consuming library code should use the support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. -array-api-strict currently supports the 2022.12 version of the standard. -2023.12 support is planned and is tracked by [this -issue](https://github.com/data-apis/array-api-strict/issues/25). +array-api-strict currently supports the +[2022.12](https://data-apis.org/array-api/latest/changelog.html#v2022-12) +version of the standard. Experimental +[2023.12](https://data-apis.org/array-api/latest/changelog.html#v2023-12) +support is implemented, [but must be enabled with a +flag](set_array_api_strict_flags). ## Install @@ -179,9 +182,11 @@ issue, but this hasn't necessarily been tested thoroughly. function. array-api-strict currently implements all of these. In the future, [there may be a way to disable them](https://github.com/data-apis/array-api-strict/issues/7). -6. array-api-strict currently only supports the 2022.12 version of the array - API standard. [Support for 2023.12 is - planned](https://github.com/data-apis/array-api-strict/issues/25). +6. array-api-strict currently uses the 2022.12 version of the array API + standard. Support for 2023.12 is implemented but is still experimental and + not fully tested. It can be enabled with + [`array_api_strict.set_array_api_strict_flags(api_version='2023.12')`](set_array_api_strict_flags). + (numpy.array_api)= ## Relationship to `numpy.array_api` diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..0c84ee3 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +filterwarnings = error