Skip to content

release branch 2.3 #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions array_api_strict/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
"2021.12",
"2022.12",
"2023.12",
"2024.12"
)

draft_version = "2024.12"
draft_version = "2025.12"

API_VERSION = default_version = "2023.12"
API_VERSION = default_version = "2024.12"

BOOLEAN_INDEXING = True

Expand Down
11 changes: 6 additions & 5 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,10 @@ def test_array_keys_use_private_array():
def test_array_namespace():
a = ones((3, 3))
assert a.__array_namespace__() == array_api_strict
assert array_api_strict.__array_api_version__ == "2023.12"
assert array_api_strict.__array_api_version__ == "2024.12"

assert a.__array_namespace__(api_version=None) is array_api_strict
assert array_api_strict.__array_api_version__ == "2023.12"
assert array_api_strict.__array_api_version__ == "2024.12"

assert a.__array_namespace__(api_version="2022.12") is array_api_strict
assert array_api_strict.__array_api_version__ == "2022.12"
Expand All @@ -525,11 +525,12 @@ def test_array_namespace():
assert array_api_strict.__array_api_version__ == "2021.12"

with pytest.warns(UserWarning):
assert a.__array_namespace__(api_version="2024.12") is array_api_strict
assert array_api_strict.__array_api_version__ == "2024.12"
assert a.__array_namespace__(api_version="2025.12") is array_api_strict
assert array_api_strict.__array_api_version__ == "2025.12"


pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2025.12"))
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2026.12"))

def test_iter():
pytest.raises(TypeError, lambda: iter(asarray(3)))
Expand Down
13 changes: 6 additions & 7 deletions array_api_strict/tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,11 @@ def test_result_type_py_scalars(api_version):
with pytest.raises(TypeError):
result_type(int16, 3)
else:
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version=api_version)
set_array_api_strict_flags(api_version=api_version)

assert result_type(int8, 3) == int8
assert result_type(uint8, 3) == uint8
assert result_type(float64, 3) == float64
assert result_type(int8, 3) == int8
assert result_type(uint8, 3) == uint8
assert result_type(float64, 3) == float64

with pytest.raises(TypeError):
result_type(int64, True)
with pytest.raises(TypeError):
result_type(int64, True)
10 changes: 3 additions & 7 deletions array_api_strict/tests/test_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pytest import raises as assert_raises
from numpy.testing import suppress_warnings

import pytest

from .. import asarray, _elementwise_functions
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
Expand Down Expand Up @@ -134,8 +133,7 @@ def _array_vals(dtypes):
yield asarray(1., dtype=d)

# Use the latest version of the standard so all functions are included
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version="2024.12")
set_array_api_strict_flags(api_version="2024.12")

for func_name, types in elementwise_function_input_types.items():
dtypes = _dtype_categories[types]
Expand Down Expand Up @@ -171,8 +169,7 @@ def _array_vals():
yield asarray(1.0, dtype=d)

# Use the latest version of the standard so all functions are included
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version="2024.12")
set_array_api_strict_flags(api_version="2024.12")

for x in _array_vals():
for func_name, types in elementwise_function_input_types.items():
Expand Down Expand Up @@ -216,8 +213,7 @@ def test_scalars():
# arguments, and reject (scalar, scalar) arguments.

# Use the latest version of the standard so that scalars are actually allowed
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version="2024.12")
set_array_api_strict_flags(api_version="2024.12")

def _array_vals():
for d in _integer_dtypes:
Expand Down
45 changes: 27 additions & 18 deletions array_api_strict/tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def test_flag_defaults():
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2023.12',
'api_version': '2024.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft'),
Expand All @@ -36,7 +36,7 @@ def test_reset_flags():
reset_array_api_strict_flags()
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2023.12',
'api_version': '2024.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft'),
Expand All @@ -47,15 +47,15 @@ def test_setting_flags():
set_array_api_strict_flags(data_dependent_shapes=False)
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2023.12',
'api_version': '2024.12',
'boolean_indexing': True,
'data_dependent_shapes': False,
'enabled_extensions': ('linalg', 'fft'),
}
set_array_api_strict_flags(enabled_extensions=('fft',))
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2023.12',
'api_version': '2024.12',
'boolean_indexing': True,
'data_dependent_shapes': False,
'enabled_extensions': ('fft',),
Expand Down Expand Up @@ -98,15 +98,26 @@ def test_flags_api_version_2023_12():
}

def test_flags_api_version_2024_12():
# Make sure setting the version to 2024.12 issues a warning.
set_array_api_strict_flags(api_version='2024.12')
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2024.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft'),
}


def test_flags_api_version_2025_12():
# Make sure setting the version to 2025.12 issues a warning.
with pytest.warns(UserWarning) as record:
set_array_api_strict_flags(api_version='2024.12')
set_array_api_strict_flags(api_version='2025.12')
assert len(record) == 1
assert '2024.12' in str(record[0].message)
assert '2025.12' in str(record[0].message)
assert 'draft' in str(record[0].message)
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2024.12',
'api_version': '2025.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft'),
Expand All @@ -125,9 +136,12 @@ def test_setting_flags_invalid():

def test_api_version():
# Test defaults
assert xp.__array_api_version__ == '2023.12'
assert xp.__array_api_version__ == '2024.12'

# Test setting the version
set_array_api_strict_flags(api_version='2023.12')
assert xp.__array_api_version__ == '2023.12'

set_array_api_strict_flags(api_version='2022.12')
assert xp.__array_api_version__ == '2022.12'

Expand Down Expand Up @@ -315,20 +329,15 @@ def test_api_version_2023_12(func_name):
def test_api_version_2024_12(func_name):
func = api_version_2024_12_examples[func_name]

# By default, these functions should error
pytest.raises(RuntimeError, func)
# By default, these functions should not error
func()

# In 2022.12 and 2023.12, these functions should error
set_array_api_strict_flags(api_version='2022.12')
pytest.raises(RuntimeError, func)
set_array_api_strict_flags(api_version='2023.12')
pytest.raises(RuntimeError, func)

# They should not error in 2024.12
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2024.12')
func()

# Test the behavior gets updated properly
set_array_api_strict_flags(api_version='2023.12')
pytest.raises(RuntimeError, func)
Expand Down Expand Up @@ -435,9 +444,9 @@ def test_environment_variables():
# ARRAY_API_STRICT_API_VERSION
('''\
import array_api_strict as xp
assert xp.__array_api_version__ == '2023.12'
assert xp.__array_api_version__ == '2024.12'

assert xp.get_array_api_strict_flags()['api_version'] == '2023.12'
assert xp.get_array_api_strict_flags()['api_version'] == '2024.12'

''', {}),
*[
Expand Down
28 changes: 11 additions & 17 deletions array_api_strict/tests/test_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,22 @@
import array_api_strict as xp

from array_api_strict import ArrayAPIStrictFlags
from array_api_strict._flags import draft_version


def test_where_with_scalars():
x = xp.asarray([1, 2, 3, 1])

# Versions up to and including 2023.12 don't support scalar arguments
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
xp.where(x == 1, 42, 44)
with ArrayAPIStrictFlags(api_version='2023.12'):
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
xp.where(x == 1, 42, 44)

# Versions after 2023.12 support scalar arguments
with (pytest.warns(
UserWarning,
match="The 2024.12 version of the array API specification is in draft status"
),
ArrayAPIStrictFlags(api_version=draft_version),
):
x_where = xp.where(x == 1, xp.asarray(42), 44)

expected = xp.asarray([42, 44, 44, 42])
assert xp.all(x_where == expected)

# The spec does not allow both x1 and x2 to be scalars
with pytest.raises(ValueError, match="One of"):
xp.where(x == 1, 42, 44)
x_where = xp.where(x == 1, xp.asarray(42), 44)

expected = xp.asarray([42, 44, 44, 42])
assert xp.all(x_where == expected)

# The spec does not allow both x1 and x2 to be scalars
with pytest.raises(ValueError, match="One of"):
xp.where(x == 1, 42, 44)
10 changes: 4 additions & 6 deletions array_api_strict/tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import cmath
import pytest

from .._flags import set_array_api_strict_flags
from .._flags import set_array_api_strict_flags, ArrayAPIStrictFlags

import array_api_strict as xp

Expand Down Expand Up @@ -44,12 +44,10 @@ def test_sum_prod_trace_2023_12(func_name):
def test_mean_complex():
a = xp.asarray([1j, 2j, 3j])

set_array_api_strict_flags(api_version='2023.12')
with pytest.raises(TypeError):
xp.mean(a)
with ArrayAPIStrictFlags(api_version='2023.12'):
with pytest.raises(TypeError):
xp.mean(a)

with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2024.12')
m = xp.mean(a)
assert cmath.isclose(complex(m), 2j)

Expand Down
68 changes: 68 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,73 @@
# Changelog

## 2.3 (2025-XX-XX)

### Major Changes

- The default version of the array API standard is now 2024.12. Previous versions can
still be enabled via the [flags API](array-api-strict-flags).

Note that this support is still relatively untested. Please [report any
issues](https://github.com/data-apis/array-api-strict/issues) you find.

- Binary elementwise functions now accept python scalars: the only requirement is that
at least one of the arguments must be an array; the other argument may be either
a python scalar or an array. Python scalars are handled in accordance with the
type promotion rules, as specified by the standard.
This change unifies the behavior of binary functions and their matching operators,
(where available), such as `multiply(x1, x2)` and `__mul__(self, other)`.

`where` accepts arrays or scalars as its 2nd and 3rd arguments, `x1` and `x2`.
The first argument, `condition`, must be an array.

`result_type` accepts arrays and scalars and computes the result dtype according
to the promotion rules.

- Ergonomics of working with complex values has been improved:

- binary operators accept complex scalars and real arrays and preserve the floating point
precision: `1j*f32_array` returns a `complex64` array
- `mean` accepts complex floating-point arrays.
- `real` and `conj` accept numeric arguments, including real floating point data.
Note that `imag` still requires its input to be a complex array.

- The following functions, new in the 2024.12 standard revision, are implemented:

- `count_nonzero`
- `cumulative_prod`

- `fftfreq` and `rfftfreq` functions accept a new `dtype` argument to control the
the data type of their output.


### Minor Changes

- `vecdot` now conjugates the first argument, in accordance with the standard.

- `astype` now raises a `TypeError` instead of casting a complex floating-point
array to a real-valued or an integral data type.

- `where` requires that its first argument, `condition` has a boolean data dtype,
and raises a `TypeError` otherwise.

- `isdtype` raises a `TypeError` is its argument is not a dtype object.

- arrays created with `from_dlpack` now correctly set their `device` attribute.

- the build system now uses `pyproject.toml`, not `setup.py`.

### Contributors

The following users contributed to this release:

Aaron Meurer
Clément Robert
Guido Imperiale
Evgeni Burovski
Lucas Colley
Tim Head


## 2.2 (2024-11-11)

### Major Changes
Expand Down
Loading