From fc8e7314a55a6d1207bc6ae1ea5b17aa8fdb363e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 24 Feb 2025 10:36:31 +0100 Subject: [PATCH 1/3] ENH: set the default version to 2024.12 And adapt test_flags accordingly. --- array_api_strict/_flags.py | 5 ++-- array_api_strict/tests/test_flags.py | 45 +++++++++++++++++----------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 279b0e7..3fce8a0 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -22,11 +22,12 @@ "2021.12", "2022.12", "2023.12", + "2024.12" ) -draft_version = "2024.12" +draft_version = "2025.12" -API_VERSION = default_version = "2023.12" +API_VERSION = default_version = "2024.12" BOOLEAN_INDEXING = True diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index dcfc20d..764ca77 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -19,7 +19,7 @@ def test_flag_defaults(): flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2023.12', + 'api_version': '2024.12', 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), @@ -36,7 +36,7 @@ def test_reset_flags(): reset_array_api_strict_flags() flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2023.12', + 'api_version': '2024.12', 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), @@ -47,7 +47,7 @@ def test_setting_flags(): set_array_api_strict_flags(data_dependent_shapes=False) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2023.12', + 'api_version': '2024.12', 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('linalg', 'fft'), @@ -55,7 +55,7 @@ def test_setting_flags(): set_array_api_strict_flags(enabled_extensions=('fft',)) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2023.12', + 'api_version': '2024.12', 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('fft',), @@ -98,15 +98,26 @@ def test_flags_api_version_2023_12(): } def test_flags_api_version_2024_12(): - # Make sure setting the version to 2024.12 issues a warning. + set_array_api_strict_flags(api_version='2024.12') + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2024.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + + +def test_flags_api_version_2025_12(): + # Make sure setting the version to 2025.12 issues a warning. with pytest.warns(UserWarning) as record: - set_array_api_strict_flags(api_version='2024.12') + set_array_api_strict_flags(api_version='2025.12') assert len(record) == 1 - assert '2024.12' in str(record[0].message) + assert '2025.12' in str(record[0].message) assert 'draft' in str(record[0].message) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2024.12', + 'api_version': '2025.12', 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), @@ -125,9 +136,12 @@ def test_setting_flags_invalid(): def test_api_version(): # Test defaults - assert xp.__array_api_version__ == '2023.12' + assert xp.__array_api_version__ == '2024.12' # Test setting the version + set_array_api_strict_flags(api_version='2023.12') + assert xp.__array_api_version__ == '2023.12' + set_array_api_strict_flags(api_version='2022.12') assert xp.__array_api_version__ == '2022.12' @@ -315,8 +329,8 @@ def test_api_version_2023_12(func_name): def test_api_version_2024_12(func_name): func = api_version_2024_12_examples[func_name] - # By default, these functions should error - pytest.raises(RuntimeError, func) + # By default, these functions should not error + func() # In 2022.12 and 2023.12, these functions should error set_array_api_strict_flags(api_version='2022.12') @@ -324,11 +338,6 @@ def test_api_version_2024_12(func_name): set_array_api_strict_flags(api_version='2023.12') pytest.raises(RuntimeError, func) - # They should not error in 2024.12 - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2024.12') - func() - # Test the behavior gets updated properly set_array_api_strict_flags(api_version='2023.12') pytest.raises(RuntimeError, func) @@ -435,9 +444,9 @@ def test_environment_variables(): # ARRAY_API_STRICT_API_VERSION ('''\ import array_api_strict as xp -assert xp.__array_api_version__ == '2023.12' +assert xp.__array_api_version__ == '2024.12' -assert xp.get_array_api_strict_flags()['api_version'] == '2023.12' +assert xp.get_array_api_strict_flags()['api_version'] == '2024.12' ''', {}), *[ From 59a9ce726b94c8a832e9a0396f71025a40f2a371 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 24 Feb 2025 10:58:58 +0100 Subject: [PATCH 2/3] TST: update tests for 2024.12 revision (warnings, defaults) 2024.12 features do not emit warnings by default. --- array_api_strict/tests/test_array_object.py | 11 ++++---- .../tests/test_data_type_functions.py | 13 ++++----- .../tests/test_elementwise_functions.py | 10 ++----- .../tests/test_searching_functions.py | 28 ++++++++----------- .../tests/test_statistical_functions.py | 10 +++---- 5 files changed, 30 insertions(+), 42 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index ef76c28..e24a40f 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -509,10 +509,10 @@ def test_array_keys_use_private_array(): def test_array_namespace(): a = ones((3, 3)) assert a.__array_namespace__() == array_api_strict - assert array_api_strict.__array_api_version__ == "2023.12" + assert array_api_strict.__array_api_version__ == "2024.12" assert a.__array_namespace__(api_version=None) is array_api_strict - assert array_api_strict.__array_api_version__ == "2023.12" + assert array_api_strict.__array_api_version__ == "2024.12" assert a.__array_namespace__(api_version="2022.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2022.12" @@ -525,11 +525,12 @@ def test_array_namespace(): assert array_api_strict.__array_api_version__ == "2021.12" with pytest.warns(UserWarning): - assert a.__array_namespace__(api_version="2024.12") is array_api_strict - assert array_api_strict.__array_api_version__ == "2024.12" + assert a.__array_namespace__(api_version="2025.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2025.12" + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2025.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2026.12")) def test_iter(): pytest.raises(TypeError, lambda: iter(asarray(3))) diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 863d3d4..919c0b4 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -80,12 +80,11 @@ def test_result_type_py_scalars(api_version): with pytest.raises(TypeError): result_type(int16, 3) else: - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version=api_version) + set_array_api_strict_flags(api_version=api_version) - assert result_type(int8, 3) == int8 - assert result_type(uint8, 3) == uint8 - assert result_type(float64, 3) == float64 + assert result_type(int8, 3) == int8 + assert result_type(uint8, 3) == uint8 + assert result_type(float64, 3) == float64 - with pytest.raises(TypeError): - result_type(int64, True) + with pytest.raises(TypeError): + result_type(int64, True) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 93078ed..99596b4 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -3,7 +3,6 @@ from pytest import raises as assert_raises from numpy.testing import suppress_warnings -import pytest from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift @@ -134,8 +133,7 @@ def _array_vals(dtypes): yield asarray(1., dtype=d) # Use the latest version of the standard so all functions are included - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version="2024.12") + set_array_api_strict_flags(api_version="2024.12") for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] @@ -171,8 +169,7 @@ def _array_vals(): yield asarray(1.0, dtype=d) # Use the latest version of the standard so all functions are included - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version="2024.12") + set_array_api_strict_flags(api_version="2024.12") for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): @@ -216,8 +213,7 @@ def test_scalars(): # arguments, and reject (scalar, scalar) arguments. # Use the latest version of the standard so that scalars are actually allowed - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version="2024.12") + set_array_api_strict_flags(api_version="2024.12") def _array_vals(): for d in _integer_dtypes: diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index dfb3fe7..0e54d5f 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -3,28 +3,22 @@ import array_api_strict as xp from array_api_strict import ArrayAPIStrictFlags -from array_api_strict._flags import draft_version def test_where_with_scalars(): x = xp.asarray([1, 2, 3, 1]) # Versions up to and including 2023.12 don't support scalar arguments - with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): - xp.where(x == 1, 42, 44) + with ArrayAPIStrictFlags(api_version='2023.12'): + with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): + xp.where(x == 1, 42, 44) # Versions after 2023.12 support scalar arguments - with (pytest.warns( - UserWarning, - match="The 2024.12 version of the array API specification is in draft status" - ), - ArrayAPIStrictFlags(api_version=draft_version), - ): - x_where = xp.where(x == 1, xp.asarray(42), 44) - - expected = xp.asarray([42, 44, 44, 42]) - assert xp.all(x_where == expected) - - # The spec does not allow both x1 and x2 to be scalars - with pytest.raises(ValueError, match="One of"): - xp.where(x == 1, 42, 44) + x_where = xp.where(x == 1, xp.asarray(42), 44) + + expected = xp.asarray([42, 44, 44, 42]) + assert xp.all(x_where == expected) + + # The spec does not allow both x1 and x2 to be scalars + with pytest.raises(ValueError, match="One of"): + xp.where(x == 1, 42, 44) diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py index c97670d..d702b17 100644 --- a/array_api_strict/tests/test_statistical_functions.py +++ b/array_api_strict/tests/test_statistical_functions.py @@ -1,7 +1,7 @@ import cmath import pytest -from .._flags import set_array_api_strict_flags +from .._flags import set_array_api_strict_flags, ArrayAPIStrictFlags import array_api_strict as xp @@ -44,12 +44,10 @@ def test_sum_prod_trace_2023_12(func_name): def test_mean_complex(): a = xp.asarray([1j, 2j, 3j]) - set_array_api_strict_flags(api_version='2023.12') - with pytest.raises(TypeError): - xp.mean(a) + with ArrayAPIStrictFlags(api_version='2023.12'): + with pytest.raises(TypeError): + xp.mean(a) - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2024.12') m = xp.mean(a) assert cmath.isclose(complex(m), 2j) From a86d0bfe4b7b70ddc8f141b8ccf2c1cacde58bf2 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 2 Feb 2025 15:01:15 +0100 Subject: [PATCH 3/3] DOC: update the changelog for the 2.3 release --- docs/changelog.md | 68 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index d33dc24..7f6be2c 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,73 @@ # Changelog +## 2.3 (2025-XX-XX) + +### Major Changes + +- The default version of the array API standard is now 2024.12. Previous versions can + still be enabled via the [flags API](array-api-strict-flags). + + Note that this support is still relatively untested. Please [report any + issues](https://github.com/data-apis/array-api-strict/issues) you find. + +- Binary elementwise functions now accept python scalars: the only requirement is that + at least one of the arguments must be an array; the other argument may be either + a python scalar or an array. Python scalars are handled in accordance with the + type promotion rules, as specified by the standard. + This change unifies the behavior of binary functions and their matching operators, + (where available), such as `multiply(x1, x2)` and `__mul__(self, other)`. + + `where` accepts arrays or scalars as its 2nd and 3rd arguments, `x1` and `x2`. + The first argument, `condition`, must be an array. + + `result_type` accepts arrays and scalars and computes the result dtype according + to the promotion rules. + +- Ergonomics of working with complex values has been improved: + + - binary operators accept complex scalars and real arrays and preserve the floating point + precision: `1j*f32_array` returns a `complex64` array + - `mean` accepts complex floating-point arrays. + - `real` and `conj` accept numeric arguments, including real floating point data. + Note that `imag` still requires its input to be a complex array. + +- The following functions, new in the 2024.12 standard revision, are implemented: + + - `count_nonzero` + - `cumulative_prod` + +- `fftfreq` and `rfftfreq` functions accept a new `dtype` argument to control the + the data type of their output. + + +### Minor Changes + +- `vecdot` now conjugates the first argument, in accordance with the standard. + +- `astype` now raises a `TypeError` instead of casting a complex floating-point + array to a real-valued or an integral data type. + +- `where` requires that its first argument, `condition` has a boolean data dtype, + and raises a `TypeError` otherwise. + +- `isdtype` raises a `TypeError` is its argument is not a dtype object. + +- arrays created with `from_dlpack` now correctly set their `device` attribute. + +- the build system now uses `pyproject.toml`, not `setup.py`. + +### Contributors + +The following users contributed to this release: + +Aaron Meurer +Clément Robert +Guido Imperiale +Evgeni Burovski +Lucas Colley +Tim Head + + ## 2.2 (2024-11-11) ### Major Changes