Skip to content

Commit 2d56d32

Browse files
authored
Merge pull request #123 from ev-br/release_2.3
release branch 2.3
2 parents 4b8fbef + a86d0bf commit 2d56d32

8 files changed

+128
-62
lines changed

array_api_strict/_flags.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
"2021.12",
2323
"2022.12",
2424
"2023.12",
25+
"2024.12"
2526
)
2627

27-
draft_version = "2024.12"
28+
draft_version = "2025.12"
2829

29-
API_VERSION = default_version = "2023.12"
30+
API_VERSION = default_version = "2024.12"
3031

3132
BOOLEAN_INDEXING = True
3233

array_api_strict/tests/test_array_object.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -509,10 +509,10 @@ def test_array_keys_use_private_array():
509509
def test_array_namespace():
510510
a = ones((3, 3))
511511
assert a.__array_namespace__() == array_api_strict
512-
assert array_api_strict.__array_api_version__ == "2023.12"
512+
assert array_api_strict.__array_api_version__ == "2024.12"
513513

514514
assert a.__array_namespace__(api_version=None) is array_api_strict
515-
assert array_api_strict.__array_api_version__ == "2023.12"
515+
assert array_api_strict.__array_api_version__ == "2024.12"
516516

517517
assert a.__array_namespace__(api_version="2022.12") is array_api_strict
518518
assert array_api_strict.__array_api_version__ == "2022.12"
@@ -525,11 +525,12 @@ def test_array_namespace():
525525
assert array_api_strict.__array_api_version__ == "2021.12"
526526

527527
with pytest.warns(UserWarning):
528-
assert a.__array_namespace__(api_version="2024.12") is array_api_strict
529-
assert array_api_strict.__array_api_version__ == "2024.12"
528+
assert a.__array_namespace__(api_version="2025.12") is array_api_strict
529+
assert array_api_strict.__array_api_version__ == "2025.12"
530+
530531

531532
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
532-
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2025.12"))
533+
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2026.12"))
533534

534535
def test_iter():
535536
pytest.raises(TypeError, lambda: iter(asarray(3)))

array_api_strict/tests/test_data_type_functions.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,11 @@ def test_result_type_py_scalars(api_version):
8080
with pytest.raises(TypeError):
8181
result_type(int16, 3)
8282
else:
83-
with pytest.warns(UserWarning):
84-
set_array_api_strict_flags(api_version=api_version)
83+
set_array_api_strict_flags(api_version=api_version)
8584

86-
assert result_type(int8, 3) == int8
87-
assert result_type(uint8, 3) == uint8
88-
assert result_type(float64, 3) == float64
85+
assert result_type(int8, 3) == int8
86+
assert result_type(uint8, 3) == uint8
87+
assert result_type(float64, 3) == float64
8988

90-
with pytest.raises(TypeError):
91-
result_type(int64, True)
89+
with pytest.raises(TypeError):
90+
result_type(int64, True)

array_api_strict/tests/test_elementwise_functions.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from pytest import raises as assert_raises
44
from numpy.testing import suppress_warnings
55

6-
import pytest
76

87
from .. import asarray, _elementwise_functions
98
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
@@ -134,8 +133,7 @@ def _array_vals(dtypes):
134133
yield asarray(1., dtype=d)
135134

136135
# Use the latest version of the standard so all functions are included
137-
with pytest.warns(UserWarning):
138-
set_array_api_strict_flags(api_version="2024.12")
136+
set_array_api_strict_flags(api_version="2024.12")
139137

140138
for func_name, types in elementwise_function_input_types.items():
141139
dtypes = _dtype_categories[types]
@@ -171,8 +169,7 @@ def _array_vals():
171169
yield asarray(1.0, dtype=d)
172170

173171
# Use the latest version of the standard so all functions are included
174-
with pytest.warns(UserWarning):
175-
set_array_api_strict_flags(api_version="2024.12")
172+
set_array_api_strict_flags(api_version="2024.12")
176173

177174
for x in _array_vals():
178175
for func_name, types in elementwise_function_input_types.items():
@@ -216,8 +213,7 @@ def test_scalars():
216213
# arguments, and reject (scalar, scalar) arguments.
217214

218215
# Use the latest version of the standard so that scalars are actually allowed
219-
with pytest.warns(UserWarning):
220-
set_array_api_strict_flags(api_version="2024.12")
216+
set_array_api_strict_flags(api_version="2024.12")
221217

222218
def _array_vals():
223219
for d in _integer_dtypes:

array_api_strict/tests/test_flags.py

+27-18
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
def test_flag_defaults():
2020
flags = get_array_api_strict_flags()
2121
assert flags == {
22-
'api_version': '2023.12',
22+
'api_version': '2024.12',
2323
'boolean_indexing': True,
2424
'data_dependent_shapes': True,
2525
'enabled_extensions': ('linalg', 'fft'),
@@ -36,7 +36,7 @@ def test_reset_flags():
3636
reset_array_api_strict_flags()
3737
flags = get_array_api_strict_flags()
3838
assert flags == {
39-
'api_version': '2023.12',
39+
'api_version': '2024.12',
4040
'boolean_indexing': True,
4141
'data_dependent_shapes': True,
4242
'enabled_extensions': ('linalg', 'fft'),
@@ -47,15 +47,15 @@ def test_setting_flags():
4747
set_array_api_strict_flags(data_dependent_shapes=False)
4848
flags = get_array_api_strict_flags()
4949
assert flags == {
50-
'api_version': '2023.12',
50+
'api_version': '2024.12',
5151
'boolean_indexing': True,
5252
'data_dependent_shapes': False,
5353
'enabled_extensions': ('linalg', 'fft'),
5454
}
5555
set_array_api_strict_flags(enabled_extensions=('fft',))
5656
flags = get_array_api_strict_flags()
5757
assert flags == {
58-
'api_version': '2023.12',
58+
'api_version': '2024.12',
5959
'boolean_indexing': True,
6060
'data_dependent_shapes': False,
6161
'enabled_extensions': ('fft',),
@@ -98,15 +98,26 @@ def test_flags_api_version_2023_12():
9898
}
9999

100100
def test_flags_api_version_2024_12():
101-
# Make sure setting the version to 2024.12 issues a warning.
101+
set_array_api_strict_flags(api_version='2024.12')
102+
flags = get_array_api_strict_flags()
103+
assert flags == {
104+
'api_version': '2024.12',
105+
'boolean_indexing': True,
106+
'data_dependent_shapes': True,
107+
'enabled_extensions': ('linalg', 'fft'),
108+
}
109+
110+
111+
def test_flags_api_version_2025_12():
112+
# Make sure setting the version to 2025.12 issues a warning.
102113
with pytest.warns(UserWarning) as record:
103-
set_array_api_strict_flags(api_version='2024.12')
114+
set_array_api_strict_flags(api_version='2025.12')
104115
assert len(record) == 1
105-
assert '2024.12' in str(record[0].message)
116+
assert '2025.12' in str(record[0].message)
106117
assert 'draft' in str(record[0].message)
107118
flags = get_array_api_strict_flags()
108119
assert flags == {
109-
'api_version': '2024.12',
120+
'api_version': '2025.12',
110121
'boolean_indexing': True,
111122
'data_dependent_shapes': True,
112123
'enabled_extensions': ('linalg', 'fft'),
@@ -125,9 +136,12 @@ def test_setting_flags_invalid():
125136

126137
def test_api_version():
127138
# Test defaults
128-
assert xp.__array_api_version__ == '2023.12'
139+
assert xp.__array_api_version__ == '2024.12'
129140

130141
# Test setting the version
142+
set_array_api_strict_flags(api_version='2023.12')
143+
assert xp.__array_api_version__ == '2023.12'
144+
131145
set_array_api_strict_flags(api_version='2022.12')
132146
assert xp.__array_api_version__ == '2022.12'
133147

@@ -315,20 +329,15 @@ def test_api_version_2023_12(func_name):
315329
def test_api_version_2024_12(func_name):
316330
func = api_version_2024_12_examples[func_name]
317331

318-
# By default, these functions should error
319-
pytest.raises(RuntimeError, func)
332+
# By default, these functions should not error
333+
func()
320334

321335
# In 2022.12 and 2023.12, these functions should error
322336
set_array_api_strict_flags(api_version='2022.12')
323337
pytest.raises(RuntimeError, func)
324338
set_array_api_strict_flags(api_version='2023.12')
325339
pytest.raises(RuntimeError, func)
326340

327-
# They should not error in 2024.12
328-
with pytest.warns(UserWarning):
329-
set_array_api_strict_flags(api_version='2024.12')
330-
func()
331-
332341
# Test the behavior gets updated properly
333342
set_array_api_strict_flags(api_version='2023.12')
334343
pytest.raises(RuntimeError, func)
@@ -435,9 +444,9 @@ def test_environment_variables():
435444
# ARRAY_API_STRICT_API_VERSION
436445
('''\
437446
import array_api_strict as xp
438-
assert xp.__array_api_version__ == '2023.12'
447+
assert xp.__array_api_version__ == '2024.12'
439448
440-
assert xp.get_array_api_strict_flags()['api_version'] == '2023.12'
449+
assert xp.get_array_api_strict_flags()['api_version'] == '2024.12'
441450
442451
''', {}),
443452
*[

array_api_strict/tests/test_searching_functions.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,22 @@
33
import array_api_strict as xp
44

55
from array_api_strict import ArrayAPIStrictFlags
6-
from array_api_strict._flags import draft_version
76

87

98
def test_where_with_scalars():
109
x = xp.asarray([1, 2, 3, 1])
1110

1211
# Versions up to and including 2023.12 don't support scalar arguments
13-
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
14-
xp.where(x == 1, 42, 44)
12+
with ArrayAPIStrictFlags(api_version='2023.12'):
13+
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
14+
xp.where(x == 1, 42, 44)
1515

1616
# Versions after 2023.12 support scalar arguments
17-
with (pytest.warns(
18-
UserWarning,
19-
match="The 2024.12 version of the array API specification is in draft status"
20-
),
21-
ArrayAPIStrictFlags(api_version=draft_version),
22-
):
23-
x_where = xp.where(x == 1, xp.asarray(42), 44)
24-
25-
expected = xp.asarray([42, 44, 44, 42])
26-
assert xp.all(x_where == expected)
27-
28-
# The spec does not allow both x1 and x2 to be scalars
29-
with pytest.raises(ValueError, match="One of"):
30-
xp.where(x == 1, 42, 44)
17+
x_where = xp.where(x == 1, xp.asarray(42), 44)
18+
19+
expected = xp.asarray([42, 44, 44, 42])
20+
assert xp.all(x_where == expected)
21+
22+
# The spec does not allow both x1 and x2 to be scalars
23+
with pytest.raises(ValueError, match="One of"):
24+
xp.where(x == 1, 42, 44)

array_api_strict/tests/test_statistical_functions.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import cmath
22
import pytest
33

4-
from .._flags import set_array_api_strict_flags
4+
from .._flags import set_array_api_strict_flags, ArrayAPIStrictFlags
55

66
import array_api_strict as xp
77

@@ -44,12 +44,10 @@ def test_sum_prod_trace_2023_12(func_name):
4444
def test_mean_complex():
4545
a = xp.asarray([1j, 2j, 3j])
4646

47-
set_array_api_strict_flags(api_version='2023.12')
48-
with pytest.raises(TypeError):
49-
xp.mean(a)
47+
with ArrayAPIStrictFlags(api_version='2023.12'):
48+
with pytest.raises(TypeError):
49+
xp.mean(a)
5050

51-
with pytest.warns(UserWarning):
52-
set_array_api_strict_flags(api_version='2024.12')
5351
m = xp.mean(a)
5452
assert cmath.isclose(complex(m), 2j)
5553

docs/changelog.md

+68
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,73 @@
11
# Changelog
22

3+
## 2.3 (2025-XX-XX)
4+
5+
### Major Changes
6+
7+
- The default version of the array API standard is now 2024.12. Previous versions can
8+
still be enabled via the [flags API](array-api-strict-flags).
9+
10+
Note that this support is still relatively untested. Please [report any
11+
issues](https://github.com/data-apis/array-api-strict/issues) you find.
12+
13+
- Binary elementwise functions now accept python scalars: the only requirement is that
14+
at least one of the arguments must be an array; the other argument may be either
15+
a python scalar or an array. Python scalars are handled in accordance with the
16+
type promotion rules, as specified by the standard.
17+
This change unifies the behavior of binary functions and their matching operators,
18+
(where available), such as `multiply(x1, x2)` and `__mul__(self, other)`.
19+
20+
`where` accepts arrays or scalars as its 2nd and 3rd arguments, `x1` and `x2`.
21+
The first argument, `condition`, must be an array.
22+
23+
`result_type` accepts arrays and scalars and computes the result dtype according
24+
to the promotion rules.
25+
26+
- Ergonomics of working with complex values has been improved:
27+
28+
- binary operators accept complex scalars and real arrays and preserve the floating point
29+
precision: `1j*f32_array` returns a `complex64` array
30+
- `mean` accepts complex floating-point arrays.
31+
- `real` and `conj` accept numeric arguments, including real floating point data.
32+
Note that `imag` still requires its input to be a complex array.
33+
34+
- The following functions, new in the 2024.12 standard revision, are implemented:
35+
36+
- `count_nonzero`
37+
- `cumulative_prod`
38+
39+
- `fftfreq` and `rfftfreq` functions accept a new `dtype` argument to control the
40+
the data type of their output.
41+
42+
43+
### Minor Changes
44+
45+
- `vecdot` now conjugates the first argument, in accordance with the standard.
46+
47+
- `astype` now raises a `TypeError` instead of casting a complex floating-point
48+
array to a real-valued or an integral data type.
49+
50+
- `where` requires that its first argument, `condition` has a boolean data dtype,
51+
and raises a `TypeError` otherwise.
52+
53+
- `isdtype` raises a `TypeError` is its argument is not a dtype object.
54+
55+
- arrays created with `from_dlpack` now correctly set their `device` attribute.
56+
57+
- the build system now uses `pyproject.toml`, not `setup.py`.
58+
59+
### Contributors
60+
61+
The following users contributed to this release:
62+
63+
Aaron Meurer
64+
Clément Robert
65+
Guido Imperiale
66+
Evgeni Burovski
67+
Lucas Colley
68+
Tim Head
69+
70+
371
## 2.2 (2024-11-11)
472

573
### Major Changes

0 commit comments

Comments
 (0)