From f8a6a9eb6c66799557d3af182500dbb4fbef1af7 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 10 Dec 2024 22:34:55 +0000 Subject: [PATCH] BUG: `from_dlpack`: fix default device --- array_api_strict/_creation_functions.py | 2 ++ array_api_strict/tests/test_creation_functions.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index e506bcc..460dba9 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -226,6 +226,8 @@ def from_dlpack( # Going to wait for upstream numpy support if device is not _default: _check_device(device) + else: + device = None if copy not in [_default, None]: raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index c93a08a..fc4e3cb 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -236,3 +236,10 @@ def from_dlpack_2023_12(api_version): pytest.raises(exception, lambda: from_dlpack(capsule, copy=False)) pytest.raises(exception, lambda: from_dlpack(capsule, copy=True)) pytest.raises(exception, lambda: from_dlpack(capsule, copy=None)) + + +def test_from_dlpack_default_device(): + x = asarray([1, 2, 3]) + y = from_dlpack(x) + z = from_dlpack(np.asarray([1, 2, 3])) + assert x.device == y.device == z.device == CPU_DEVICE