Skip to content

Commit 729f32c

Browse files
committed
TST: update tests for 2024.12 revision (warnings, defaults)
2024.12 features do not emit warnings by default.
1 parent 755a285 commit 729f32c

5 files changed

+30
-42
lines changed

array_api_strict/tests/test_array_object.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -459,10 +459,10 @@ def test_array_keys_use_private_array():
459459
def test_array_namespace():
460460
a = ones((3, 3))
461461
assert a.__array_namespace__() == array_api_strict
462-
assert array_api_strict.__array_api_version__ == "2023.12"
462+
assert array_api_strict.__array_api_version__ == "2024.12"
463463

464464
assert a.__array_namespace__(api_version=None) is array_api_strict
465-
assert array_api_strict.__array_api_version__ == "2023.12"
465+
assert array_api_strict.__array_api_version__ == "2024.12"
466466

467467
assert a.__array_namespace__(api_version="2022.12") is array_api_strict
468468
assert array_api_strict.__array_api_version__ == "2022.12"
@@ -475,11 +475,12 @@ def test_array_namespace():
475475
assert array_api_strict.__array_api_version__ == "2021.12"
476476

477477
with pytest.warns(UserWarning):
478-
assert a.__array_namespace__(api_version="2024.12") is array_api_strict
479-
assert array_api_strict.__array_api_version__ == "2024.12"
478+
assert a.__array_namespace__(api_version="2025.12") is array_api_strict
479+
assert array_api_strict.__array_api_version__ == "2025.12"
480+
480481

481482
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
482-
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2025.12"))
483+
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2026.12"))
483484

484485
def test_iter():
485486
pytest.raises(TypeError, lambda: iter(asarray(3)))

array_api_strict/tests/test_data_type_functions.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,11 @@ def test_result_type_py_scalars(api_version):
8080
with pytest.raises(TypeError):
8181
result_type(int16, 3)
8282
else:
83-
with pytest.warns(UserWarning):
84-
set_array_api_strict_flags(api_version=api_version)
83+
set_array_api_strict_flags(api_version=api_version)
8584

86-
assert result_type(int8, 3) == int8
87-
assert result_type(uint8, 3) == uint8
88-
assert result_type(float64, 3) == float64
85+
assert result_type(int8, 3) == int8
86+
assert result_type(uint8, 3) == uint8
87+
assert result_type(float64, 3) == float64
8988

90-
with pytest.raises(TypeError):
91-
result_type(int64, True)
89+
with pytest.raises(TypeError):
90+
result_type(int64, True)

array_api_strict/tests/test_elementwise_functions.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from pytest import raises as assert_raises
44
from numpy.testing import suppress_warnings
55

6-
import pytest
76

87
from .. import asarray, _elementwise_functions
98
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
@@ -134,8 +133,7 @@ def _array_vals(dtypes):
134133
yield asarray(1., dtype=d)
135134

136135
# Use the latest version of the standard so all functions are included
137-
with pytest.warns(UserWarning):
138-
set_array_api_strict_flags(api_version="2024.12")
136+
set_array_api_strict_flags(api_version="2024.12")
139137

140138
for func_name, types in elementwise_function_input_types.items():
141139
dtypes = _dtype_categories[types]
@@ -171,8 +169,7 @@ def _array_vals():
171169
yield asarray(1.0, dtype=d)
172170

173171
# Use the latest version of the standard so all functions are included
174-
with pytest.warns(UserWarning):
175-
set_array_api_strict_flags(api_version="2024.12")
172+
set_array_api_strict_flags(api_version="2024.12")
176173

177174
for x in _array_vals():
178175
for func_name, types in elementwise_function_input_types.items():
@@ -216,8 +213,7 @@ def test_scalars():
216213
# arguments, and reject (scalar, scalar) arguments.
217214

218215
# Use the latest version of the standard so that scalars are actually allowed
219-
with pytest.warns(UserWarning):
220-
set_array_api_strict_flags(api_version="2024.12")
216+
set_array_api_strict_flags(api_version="2024.12")
221217

222218
def _array_vals():
223219
for d in _integer_dtypes:

array_api_strict/tests/test_searching_functions.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,22 @@
33
import array_api_strict as xp
44

55
from array_api_strict import ArrayAPIStrictFlags
6-
from array_api_strict._flags import draft_version
76

87

98
def test_where_with_scalars():
109
x = xp.asarray([1, 2, 3, 1])
1110

1211
# Versions up to and including 2023.12 don't support scalar arguments
13-
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
14-
xp.where(x == 1, 42, 44)
12+
with ArrayAPIStrictFlags(api_version='2023.12'):
13+
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
14+
xp.where(x == 1, 42, 44)
1515

1616
# Versions after 2023.12 support scalar arguments
17-
with (pytest.warns(
18-
UserWarning,
19-
match="The 2024.12 version of the array API specification is in draft status"
20-
),
21-
ArrayAPIStrictFlags(api_version=draft_version),
22-
):
23-
x_where = xp.where(x == 1, xp.asarray(42), 44)
24-
25-
expected = xp.asarray([42, 44, 44, 42])
26-
assert xp.all(x_where == expected)
27-
28-
# The spec does not allow both x1 and x2 to be scalars
29-
with pytest.raises(ValueError, match="One of"):
30-
xp.where(x == 1, 42, 44)
17+
x_where = xp.where(x == 1, xp.asarray(42), 44)
18+
19+
expected = xp.asarray([42, 44, 44, 42])
20+
assert xp.all(x_where == expected)
21+
22+
# The spec does not allow both x1 and x2 to be scalars
23+
with pytest.raises(ValueError, match="One of"):
24+
xp.where(x == 1, 42, 44)

array_api_strict/tests/test_statistical_functions.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import cmath
22
import pytest
33

4-
from .._flags import set_array_api_strict_flags
4+
from .._flags import set_array_api_strict_flags, ArrayAPIStrictFlags
55

66
import array_api_strict as xp
77

@@ -44,12 +44,10 @@ def test_sum_prod_trace_2023_12(func_name):
4444
def test_mean_complex():
4545
a = xp.asarray([1j, 2j, 3j])
4646

47-
set_array_api_strict_flags(api_version='2023.12')
48-
with pytest.raises(TypeError):
49-
xp.mean(a)
47+
with ArrayAPIStrictFlags(api_version='2023.12'):
48+
with pytest.raises(TypeError):
49+
xp.mean(a)
5050

51-
with pytest.warns(UserWarning):
52-
set_array_api_strict_flags(api_version='2024.12')
5351
m = xp.mean(a)
5452
assert cmath.isclose(complex(m), 2j)
5553

0 commit comments

Comments
 (0)