Skip to content

Commit 7c3b948

Browse files
committed
Only check api_version in __array_namespace__ when a wrapper is used
1 parent 2ff0634 commit 7c3b948

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

array_api_compat/common/_helpers.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,10 @@ def your_function(x, y):
248248
namespaces = set()
249249
for x in xs:
250250
if is_numpy_array(x):
251-
_check_api_version(api_version)
252251
from .. import numpy as numpy_namespace
253252
import numpy as np
254253
if use_compat is True:
254+
_check_api_version(api_version)
255255
namespaces.add(numpy_namespace)
256256
elif use_compat is False:
257257
namespaces.add(np)
@@ -263,32 +263,32 @@ def your_function(x, y):
263263
else:
264264
namespaces.add(numpy_namespace)
265265
elif is_cupy_array(x):
266-
_check_api_version(api_version)
267266
if _use_compat:
267+
_check_api_version(api_version)
268268
from .. import cupy as cupy_namespace
269269
namespaces.add(cupy_namespace)
270270
else:
271271
import cupy as cp
272272
namespaces.add(cp)
273273
elif is_torch_array(x):
274-
_check_api_version(api_version)
275274
if _use_compat:
275+
_check_api_version(api_version)
276276
from .. import torch as torch_namespace
277277
namespaces.add(torch_namespace)
278278
else:
279279
import torch
280280
namespaces.add(torch)
281281
elif is_dask_array(x):
282-
_check_api_version(api_version)
283282
if _use_compat:
283+
_check_api_version(api_version)
284284
from ..dask import array as dask_namespace
285285
namespaces.add(dask_namespace)
286286
else:
287287
import dask.array as da
288288
namespaces.add(da)
289289
elif is_jax_array(x):
290-
_check_api_version(api_version)
291290
if use_compat is True:
291+
_check_api_version(api_version)
292292
raise ValueError("JAX does not have an array-api-compat wrapper")
293293
elif use_compat is False:
294294
import jax.numpy as jnp

tests/test_array_namespace.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ._helpers import import_, all_libraries, wrapped_libraries
1313

1414
@pytest.mark.parametrize("use_compat", [True, False, None])
15-
@pytest.mark.parametrize("api_version", [None, "2021.12"])
15+
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"])
1616
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
1717
def test_array_namespace(library, api_version, use_compat):
1818
xp = import_(library)
@@ -69,14 +69,14 @@ def test_array_namespace_errors_torch():
6969
pytest.raises(TypeError, lambda: array_namespace(x, y))
7070

7171
def test_api_version():
72-
x = np.asarray([1, 2])
73-
np_ = import_("numpy", wrapper=True)
74-
assert array_namespace(x, api_version="2022.12") == np_
75-
assert array_namespace(x, api_version=None) == np_
76-
assert array_namespace(x) == np_
72+
x = torch.asarray([1, 2])
73+
torch_ = import_("torch", wrapper=True)
74+
assert array_namespace(x, api_version="2022.12") == torch_
75+
assert array_namespace(x, api_version=None) == torch_
76+
assert array_namespace(x) == torch_
7777
# Should issue a warning
7878
with warnings.catch_warnings(record=True) as w:
79-
assert array_namespace(x, api_version="2021.12") == np_
79+
assert array_namespace(x, api_version="2021.12") == torch_
8080
assert len(w) == 1
8181
assert "2021.12" in str(w[0].message)
8282

0 commit comments

Comments
 (0)