Skip to content

Commit 973f1a2

Browse files
committed
Guard scalar arguments with API_VERSION>=2024.12
1 parent 866cedb commit 973f1a2

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

array_api_strict/_helpers.py

+9
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,20 @@
22
"""
33
import numpy as np
44

5+
from ._flags import get_array_api_strict_flags
6+
7+
58
_py_scalars = (bool, int, float, complex)
69

10+
711
def _maybe_normalize_py_scalars(x1, x2):
812
from ._array_object import Array
913

14+
flags = get_array_api_strict_flags()
15+
if flags["api_version"] < "2024.12": # XXX: string comparison for versions
16+
# scalars will fail at the call site
17+
return x1, x2
18+
1019
if isinstance(x1, _py_scalars):
1120
if isinstance(x2, _py_scalars):
1221
raise TypeError(f"Two scalars not allowed, {type(x1) = } and {type(x2) =}")

array_api_strict/tests/test_elementwise_functions.py

+4
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ def _sample_scalar(category):
218218
else:
219219
raise ValueError(f'Unknown {category = }')
220220

221+
# Use the latest version of the standard so that scalars are actually allowed
222+
with pytest.warns(UserWarning):
223+
set_array_api_strict_flags(api_version="2024.12")
224+
221225
for func_name, types in elementwise_function_input_types.items():
222226
dtypes = _dtype_categories[types]
223227
func = getattr(_elementwise_functions, func_name)

0 commit comments

Comments
 (0)