From 4c80f6d8ccfcf2c9dbe5bdc2d3812ed1c9114525 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 16:10:18 -0700 Subject: [PATCH 1/4] Allow __dlpack__ to work with newer versions of NumPy --- array_api_strict/_array_object.py | 19 ++++++++++--------- array_api_strict/tests/test_array_object.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 9416e38..fa9cce8 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -586,15 +586,16 @@ def __dlpack__( if copy is not _default: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") - # Going to wait for upstream numpy support - if max_version not in [_default, None]: - raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") - if dl_device not in [_default, None]: - raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") - if copy not in [_default, None]: - raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") - - return self._array.__dlpack__(stream=stream) + if np.__version__ < '2.1': + if max_version not in [_default, None]: + raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") + if dl_device not in [_default, None]: + raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") + + return self._array.__dlpack__(stream=stream) + return self._array.__dlpack__(stream=stream, max_version=max_version, dl_device=dl_device, copy=copy) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: """ diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index c7781d7..aea24da 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -460,7 +460,7 @@ def dlpack_2023_12(api_version): a.__dlpack__() - exception = NotImplementedError if api_version >= '2023.12' else ValueError + exception = NotImplementedError if api_version >= '2023.12' and np.__version__ < '2.1' else ValueError pytest.raises(exception, lambda: a.__dlpack__(dl_device=CPU_DEVICE)) pytest.raises(exception, lambda: From fa71e9e5f0b81f9413da4f9581908e67a2971b07 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 16:14:38 -0700 Subject: [PATCH 2/4] Fix version check --- array_api_strict/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index fa9cce8..03993ab 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -586,7 +586,7 @@ def __dlpack__( if copy is not _default: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") - if np.__version__ < '2.1': + if np.__version__[0] < '2.1': if max_version not in [_default, None]: raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") if dl_device not in [_default, None]: From 67d9667ba1ae883c55235043e4e13a80c23ff4e0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 16:14:43 -0700 Subject: [PATCH 3/4] Remove unused import --- array_api_strict/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 03993ab..faea86c 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -38,7 +38,7 @@ if TYPE_CHECKING: from typing import Optional, Tuple, Union, Any - from ._typing import PyCapsule, Device, Dtype + from ._typing import PyCapsule, Dtype import numpy.typing as npt import numpy as np From 93201f134463a2fa56bebecc349dd67d4dc3d49f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 16:23:16 -0700 Subject: [PATCH 4/4] Fix passing of keyword arguments in __dlpack__ for NumPy 2.1 --- array_api_strict/_array_object.py | 10 +++++- array_api_strict/tests/test_array_object.py | 39 +++++++++++++-------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index faea86c..53669d1 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -595,7 +595,15 @@ def __dlpack__( raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") return self._array.__dlpack__(stream=stream) - return self._array.__dlpack__(stream=stream, max_version=max_version, dl_device=dl_device, copy=copy) + else: + kwargs = {'stream': stream} + if max_version is not _default: + kwargs['max_version'] = max_version + if dl_device is not _default: + kwargs['dl_device'] = dl_device + if copy is not _default: + kwargs['copy'] = copy + return self._array.__dlpack__(**kwargs) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: """ diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index aea24da..96fd31e 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -460,18 +460,27 @@ def dlpack_2023_12(api_version): a.__dlpack__() - exception = NotImplementedError if api_version >= '2023.12' and np.__version__ < '2.1' else ValueError - pytest.raises(exception, lambda: - a.__dlpack__(dl_device=CPU_DEVICE)) - pytest.raises(exception, lambda: - a.__dlpack__(dl_device=None)) - pytest.raises(exception, lambda: - a.__dlpack__(max_version=(1, 0))) - pytest.raises(exception, lambda: - a.__dlpack__(max_version=None)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=False)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=True)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=None)) + if np.__version__ < '2.1': + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=CPU_DEVICE)) + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=None)) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=(1, 0))) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=None)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=False)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=True)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=None)) + else: + a.__dlpack__(dl_device=CPU_DEVICE) + a.__dlpack__(dl_device=None) + a.__dlpack__(max_version=(1, 0)) + a.__dlpack__(max_version=None) + a.__dlpack__(copy=False) + a.__dlpack__(copy=True) + a.__dlpack__(copy=None)