diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 9416e38..53669d1 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -38,7 +38,7 @@ if TYPE_CHECKING: from typing import Optional, Tuple, Union, Any - from ._typing import PyCapsule, Device, Dtype + from ._typing import PyCapsule, Dtype import numpy.typing as npt import numpy as np @@ -586,15 +586,24 @@ def __dlpack__( if copy is not _default: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") - # Going to wait for upstream numpy support - if max_version not in [_default, None]: - raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") - if dl_device not in [_default, None]: - raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") - if copy not in [_default, None]: - raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") + if np.__version__[0] < '2.1': + if max_version not in [_default, None]: + raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") + if dl_device not in [_default, None]: + raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") - return self._array.__dlpack__(stream=stream) + return self._array.__dlpack__(stream=stream) + else: + kwargs = {'stream': stream} + if max_version is not _default: + kwargs['max_version'] = max_version + if dl_device is not _default: + kwargs['dl_device'] = dl_device + if copy is not _default: + kwargs['copy'] = copy + return self._array.__dlpack__(**kwargs) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: """ diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index c7781d7..96fd31e 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -460,18 +460,27 @@ def dlpack_2023_12(api_version): a.__dlpack__() - exception = NotImplementedError if api_version >= '2023.12' else ValueError - pytest.raises(exception, lambda: - a.__dlpack__(dl_device=CPU_DEVICE)) - pytest.raises(exception, lambda: - a.__dlpack__(dl_device=None)) - pytest.raises(exception, lambda: - a.__dlpack__(max_version=(1, 0))) - pytest.raises(exception, lambda: - a.__dlpack__(max_version=None)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=False)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=True)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=None)) + if np.__version__ < '2.1': + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=CPU_DEVICE)) + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=None)) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=(1, 0))) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=None)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=False)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=True)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=None)) + else: + a.__dlpack__(dl_device=CPU_DEVICE) + a.__dlpack__(dl_device=None) + a.__dlpack__(max_version=(1, 0)) + a.__dlpack__(max_version=None) + a.__dlpack__(copy=False) + a.__dlpack__(copy=True) + a.__dlpack__(copy=None)