Skip to content

Commit dbdd09c

Browse files
committed
Fix test failures
Adds explicit tests for api_version.
1 parent f62898d commit dbdd09c

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

Diff for: array_api_compat/common/_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -325,4 +325,4 @@ def size(x):
325325
"to_device",
326326
]
327327

328-
_all_ignore = ['sys', 'math', 'inspect']
328+
_all_ignore = ['sys', 'math', 'inspect', 'warnings']

Diff for: tests/test_array_namespace.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import subprocess
22
import sys
3+
import warnings
34

45
import numpy as np
56
import pytest
@@ -57,13 +58,24 @@ def test_array_namespace_errors():
5758
pytest.raises(TypeError, lambda: array_namespace((x, x)))
5859
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
5960

60-
6161
def test_array_namespace_errors_torch():
6262
y = torch.asarray([1, 2])
6363
x = np.asarray([1, 2])
6464
pytest.raises(TypeError, lambda: array_namespace(x, y))
65-
pytest.raises(ValueError, lambda: array_namespace(x, api_version="2022.12"))
6665

66+
def test_api_version():
67+
x = np.asarray([1, 2])
68+
np_ = import_("numpy", wrapper=True)
69+
assert array_namespace(x, api_version="2022.12") == np_
70+
assert array_namespace(x, api_version=None) == np_
71+
assert array_namespace(x) == np_
72+
# Should issue a warning
73+
with warnings.catch_warnings(record=True) as w:
74+
assert array_namespace(x, api_version="2021.12") == np_
75+
assert len(w) == 1
76+
assert "2021.12" in str(w[0].message)
77+
78+
pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
6779

6880
def test_get_namespace():
6981
# Backwards compatible wrapper

0 commit comments

Comments
 (0)