Skip to content

Commit 476f45a

Browse files
committed
Fix dlpack test and use NumpyVersion
1 parent ca387b2 commit 476f45a

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

array_api_strict/_array_object.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def __dlpack__(
586586
if copy is not _default:
587587
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")
588588

589-
if np.__version__[0] < '2.1':
589+
if np.lib.NumpyVersion(np.__version__) < '2.1.0':
590590
if max_version not in [_default, None]:
591591
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
592592
if dl_device not in [_default, None]:

array_api_strict/tests/test_array_object.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def test_iter():
448448
pytest.raises(TypeError, lambda: iter(ones((3, 3))))
449449

450450
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
451-
def dlpack_2023_12(api_version):
451+
def test_dlpack_2023_12(api_version):
452452
if api_version == '2021.12':
453453
with pytest.warns(UserWarning):
454454
set_array_api_strict_flags(api_version=api_version)
@@ -459,25 +459,35 @@ def dlpack_2023_12(api_version):
459459
# Never an error
460460
a.__dlpack__()
461461

462-
463-
if np.__version__ < '2.1':
464-
exception = NotImplementedError if api_version >= '2023.12' else ValueError
465-
pytest.raises(exception, lambda:
466-
a.__dlpack__(dl_device=CPU_DEVICE))
467-
pytest.raises(exception, lambda:
462+
if api_version < '2023.12':
463+
pytest.raises(ValueError, lambda:
464+
a.__dlpack__(dl_device=a.__dlpack_device__()))
465+
pytest.raises(ValueError, lambda:
468466
a.__dlpack__(dl_device=None))
469-
pytest.raises(exception, lambda:
467+
pytest.raises(ValueError, lambda:
470468
a.__dlpack__(max_version=(1, 0)))
471-
pytest.raises(exception, lambda:
469+
pytest.raises(ValueError, lambda:
472470
a.__dlpack__(max_version=None))
473-
pytest.raises(exception, lambda:
474-
a.__dlpack__(copy=False))
475-
pytest.raises(exception, lambda:
476-
a.__dlpack__(copy=True))
477-
pytest.raises(exception, lambda:
478-
a.__dlpack__(copy=None))
471+
pytest.raises(ValueError, lambda:
472+
a.__dlpack__(copy=False))
473+
pytest.raises(ValueError, lambda:
474+
a.__dlpack__(copy=True))
475+
pytest.raises(ValueError, lambda:
476+
a.__dlpack__(copy=None))
477+
elif np.lib.NumpyVersion(np.__version__) < '2.1.0':
478+
pytest.raises(NotImplementedError, lambda:
479+
a.__dlpack__(dl_device=CPU_DEVICE))
480+
a.__dlpack__(dl_device=None)
481+
pytest.raises(NotImplementedError, lambda:
482+
a.__dlpack__(max_version=(1, 0)))
483+
a.__dlpack__(max_version=None)
484+
pytest.raises(NotImplementedError, lambda:
485+
a.__dlpack__(copy=False))
486+
pytest.raises(NotImplementedError, lambda:
487+
a.__dlpack__(copy=True))
488+
a.__dlpack__(copy=None)
479489
else:
480-
a.__dlpack__(dl_device=CPU_DEVICE)
490+
a.__dlpack__(dl_device=a.__dlpack_device__())
481491
a.__dlpack__(dl_device=None)
482492
a.__dlpack__(max_version=(1, 0))
483493
a.__dlpack__(max_version=None)

0 commit comments

Comments
 (0)