diff --git a/extensions/nanoarrow_device/src/nanoarrow/nanoarrow_device.c b/extensions/nanoarrow_device/src/nanoarrow/nanoarrow_device.c index 0c76d961e..3896283fa 100644 --- a/extensions/nanoarrow_device/src/nanoarrow/nanoarrow_device.c +++ b/extensions/nanoarrow_device/src/nanoarrow/nanoarrow_device.c @@ -115,7 +115,7 @@ struct ArrowDevice* ArrowDeviceCpu(void) { void ArrowDeviceInitCpu(struct ArrowDevice* device) { device->device_type = ARROW_DEVICE_CPU; - device->device_id = 0; + device->device_id = -1; device->array_init = NULL; device->array_move = NULL; device->buffer_init = &ArrowDeviceCpuBufferInit; @@ -135,7 +135,7 @@ struct ArrowDevice* ArrowDeviceCuda(ArrowDeviceType device_type, int64_t device_ #endif struct ArrowDevice* ArrowDeviceResolve(ArrowDeviceType device_type, int64_t device_id) { - if (device_type == ARROW_DEVICE_CPU && device_id == 0) { + if (device_type == ARROW_DEVICE_CPU) { return ArrowDeviceCpu(); } diff --git a/extensions/nanoarrow_device/src/nanoarrow/nanoarrow_device_test.cc b/extensions/nanoarrow_device/src/nanoarrow/nanoarrow_device_test.cc index f437b3698..8ed39a24b 100644 --- a/extensions/nanoarrow_device/src/nanoarrow/nanoarrow_device_test.cc +++ b/extensions/nanoarrow_device/src/nanoarrow/nanoarrow_device_test.cc @@ -28,7 +28,7 @@ TEST(NanoarrowDevice, CheckRuntime) { TEST(NanoarrowDevice, CpuDevice) { struct ArrowDevice* cpu = ArrowDeviceCpu(); EXPECT_EQ(cpu->device_type, ARROW_DEVICE_CPU); - EXPECT_EQ(cpu->device_id, 0); + EXPECT_EQ(cpu->device_id, -1); EXPECT_EQ(cpu, ArrowDeviceCpu()); void* sync_event = nullptr; diff --git a/python/src/nanoarrow/_lib.pyx b/python/src/nanoarrow/_lib.pyx index a83e029c0..d9bda9d91 100644 --- a/python/src/nanoarrow/_lib.pyx +++ b/python/src/nanoarrow/_lib.pyx @@ -34,7 +34,6 @@ generally have better autocomplete + documentation available to IDEs). from libc.stdint cimport uintptr_t, uint8_t, int64_t from libc.string cimport memcpy from libc.stdio cimport snprintf -from libc.errno cimport ENOMEM from cpython.bytes cimport PyBytes_FromStringAndSize from cpython.pycapsule cimport PyCapsule_New, PyCapsule_GetPointer, PyCapsule_IsValid from cpython cimport ( @@ -51,6 +50,7 @@ from cpython.ref cimport Py_INCREF, Py_DECREF from nanoarrow_c cimport * from nanoarrow_device_c cimport * +from enum import Enum from sys import byteorder as sys_byteorder from struct import unpack_from, iter_unpack, calcsize, Struct from nanoarrow import _repr_utils @@ -183,7 +183,7 @@ cdef void c_array_shallow_copy(object base, const ArrowArray* c_array, c_array_out.release = arrow_array_release -cdef object alloc_c_array_shallow_copy(object base, const ArrowArray* c_array) noexcept: +cdef object alloc_c_array_shallow_copy(object base, const ArrowArray* c_array): """Make a shallow copy of an ArrowArray To more safely implement export of an ArrowArray whose address may be @@ -198,6 +198,30 @@ cdef object alloc_c_array_shallow_copy(object base, const ArrowArray* c_array) n return array_capsule +cdef void c_device_array_shallow_copy(object base, const ArrowDeviceArray* c_array, + ArrowDeviceArray* c_array_out) noexcept: + # shallow copy + memcpy(c_array_out, c_array, sizeof(ArrowDeviceArray)) + c_array_out.array.release = NULL + c_array_out.array.private_data = NULL + + # track original base + c_array_out.array.private_data = base + Py_INCREF(base) + c_array_out.array.release = arrow_array_release + + +cdef object alloc_c_device_array_shallow_copy(object base, const ArrowDeviceArray* c_array): + """Make a shallow copy of an ArrowDeviceArray + + See :func:`arrow_c_array_shallow_copy()` + """ + cdef ArrowDeviceArray* c_array_out + array_capsule = alloc_c_device_array(&c_array_out) + c_device_array_shallow_copy(base, c_array, c_array_out) + return array_capsule + + cdef void pycapsule_buffer_deleter(object stream_capsule) noexcept: cdef ArrowBuffer* buffer = PyCapsule_GetPointer( stream_capsule, 'nanoarrow_buffer' @@ -207,11 +231,12 @@ cdef void pycapsule_buffer_deleter(object stream_capsule) noexcept: ArrowFree(buffer) -cdef object alloc_c_buffer(ArrowBuffer** c_buffer) noexcept: +cdef object alloc_c_buffer(ArrowBuffer** c_buffer): c_buffer[0] = ArrowMalloc(sizeof(ArrowBuffer)) ArrowBufferInit(c_buffer[0]) return PyCapsule_New(c_buffer[0], 'nanoarrow_buffer', &pycapsule_buffer_deleter) + cdef void c_deallocate_pybuffer(ArrowBufferAllocator* allocator, uint8_t* ptr, int64_t size) noexcept with gil: cdef Py_buffer* buffer = allocator.private_data PyBuffer_Release(buffer) @@ -499,8 +524,34 @@ cdef class CArrowTimeUnit: NANO = NANOARROW_TIME_UNIT_NANO +class DeviceType(Enum): + """ + An enumerator providing access to the device constant values + defined in the Arrow C Device interface. Unlike the other enum + accessors, this Python Enum is defined in Cython so that we can use + the bulit-in functionality to do better printing of device identifiers + for classes defined in Cython. Unlike the other enums, users don't + typically need to specify these (but would probably like them printed + nicely). + """ -cdef class CDevice: + CPU = ARROW_DEVICE_CPU + CUDA = ARROW_DEVICE_CUDA + CUDA_HOST = ARROW_DEVICE_CUDA_HOST + OPENCL = ARROW_DEVICE_OPENCL + VULKAN = ARROW_DEVICE_VULKAN + METAL = ARROW_DEVICE_METAL + VPI = ARROW_DEVICE_VPI + ROCM = ARROW_DEVICE_ROCM + ROCM_HOST = ARROW_DEVICE_ROCM_HOST + EXT_DEV = ARROW_DEVICE_EXT_DEV + CUDA_MANAGED = ARROW_DEVICE_CUDA_MANAGED + ONEAPI = ARROW_DEVICE_ONEAPI + WEBGPU = ARROW_DEVICE_WEBGPU + HEXAGON = ARROW_DEVICE_HEXAGON + + +cdef class Device: """ArrowDevice wrapper The ArrowDevice structure is a nanoarrow internal struct (i.e., @@ -530,6 +581,10 @@ cdef class CDevice: @property def device_type(self): + return DeviceType(self._ptr.device_type) + + @property + def device_type_id(self): return self._ptr.device_type @property @@ -537,16 +592,16 @@ cdef class CDevice: return self._ptr.device_id @staticmethod - def resolve(ArrowDeviceType device_type, int64_t device_id): - if device_type == ARROW_DEVICE_CPU: - return CDEVICE_CPU + def resolve(device_type, int64_t device_id): + if int(device_type) == ARROW_DEVICE_CPU: + return DEVICE_CPU else: raise ValueError(f"Device not found for type {device_type}/{device_id}") # Cache the CPU device # The CPU device is statically allocated (so base is None) -CDEVICE_CPU = CDevice(None, ArrowDeviceCpu()) +DEVICE_CPU = Device(None, ArrowDeviceCpu()) cdef class CSchema: @@ -1027,6 +1082,8 @@ cdef class CArray: cdef object _base cdef ArrowArray* _ptr cdef CSchema _schema + cdef ArrowDeviceType _device_type + cdef int _device_id @staticmethod def allocate(CSchema schema): @@ -1038,6 +1095,12 @@ cdef class CArray: self._base = base self._ptr = addr self._schema = schema + self._device_type = ARROW_DEVICE_CPU + self._device_id = 0 + + cdef _set_device(self, ArrowDeviceType device_type, int64_t device_id): + self._device_type = device_type + self._device_id = device_id @staticmethod def _import_from_c_capsule(schema_capsule, array_capsule): @@ -1095,7 +1158,9 @@ cdef class CArray: c_array_out.offset = c_array_out.offset + start c_array_out.length = stop - start - return CArray(base, c_array_out, self._schema) + cdef CArray out = CArray(base, c_array_out, self._schema) + out._set_device(self._device_type, self._device_id) + return out def __arrow_c_array__(self, requested_schema=None): """ @@ -1115,6 +1180,11 @@ cdef class CArray: """ self._assert_valid() + if self._device_type != ARROW_DEVICE_CPU: + raise ValueError( + "Can't invoke __arrow_c_array__ on non-CPU array " + f"with device_type {self._device_type}") + if requested_schema is not None: raise NotImplementedError("requested_schema") @@ -1137,10 +1207,26 @@ cdef class CArray: if self._ptr.release == NULL: raise RuntimeError("CArray is released") + def view(self): + device = Device.resolve(self._device_type, self._device_id) + return CArrayView.from_array(self, device) + @property def schema(self): return self._schema + @property + def device_type(self): + return DeviceType(self._device_type) + + @property + def device_type_id(self): + return self._device_type + + @property + def device_id(self): + return self._device_id + @property def length(self): self._assert_valid() @@ -1175,7 +1261,13 @@ cdef class CArray: self._assert_valid() if i < 0 or i >= self._ptr.n_children: raise IndexError(f"{i} out of range [0, {self._ptr.n_children})") - return CArray(self._base, self._ptr.children[i], self._schema.child(i)) + cdef CArray out = CArray( + self._base, + self._ptr.children[i], + self._schema.child(i) + ) + out._set_device(self._device_type, self._device_id) + return out @property def children(self): @@ -1185,8 +1277,11 @@ cdef class CArray: @property def dictionary(self): self._assert_valid() + cdef CArray out if self._ptr.dictionary != NULL: - return CArray(self, self._ptr.dictionary, self._schema.dictionary) + out = CArray(self, self._ptr.dictionary, self._schema.dictionary) + out._set_device(self._device_type, self._device_id) + return out else: return None @@ -1206,18 +1301,18 @@ cdef class CArrayView: cdef object _base cdef object _array_base cdef ArrowArrayView* _ptr - cdef CDevice _device + cdef Device _device def __cinit__(self, object base, uintptr_t addr): self._base = base self._ptr = addr - self._device = CDEVICE_CPU + self._device = DEVICE_CPU - def _set_array(self, CArray array, CDevice device=CDEVICE_CPU): + def _set_array(self, CArray array, Device device=DEVICE_CPU): cdef Error error = Error() cdef int code - if device is CDEVICE_CPU: + if device is DEVICE_CPU: code = ArrowArrayViewSetArray(self._ptr, array._ptr, &error.c_error) else: code = ArrowArrayViewSetArrayMinimal(self._ptr, array._ptr, &error.c_error) @@ -1349,7 +1444,7 @@ cdef class CArrayView: return CArrayView(base, c_array_view) @staticmethod - def from_array(CArray array, CDevice device=CDEVICE_CPU): + def from_array(CArray array, Device device=DEVICE_CPU): out = CArrayView.from_schema(array._schema) return out._set_array(array, device) @@ -1396,7 +1491,7 @@ cdef class CBufferView: cdef object _base cdef ArrowBufferView _ptr cdef ArrowType _data_type - cdef CDevice _device + cdef Device _device cdef Py_ssize_t _element_size_bits cdef Py_ssize_t _shape cdef Py_ssize_t _strides @@ -1404,7 +1499,7 @@ cdef class CBufferView: def __cinit__(self, object base, uintptr_t addr, int64_t size_bytes, ArrowType data_type, - Py_ssize_t element_size_bits, CDevice device): + Py_ssize_t element_size_bits, Device device): self._base = base self._ptr.data.data = addr self._ptr.size_bytes = size_bytes @@ -1564,7 +1659,7 @@ cdef class CBufferView: self._do_releasebuffer(buffer) cdef _do_getbuffer(self, Py_buffer *buffer, int flags): - if self._device is not CDEVICE_CPU: + if self._device is not DEVICE_CPU: raise RuntimeError("CBufferView is not a CPU buffer") if flags & PyBUF_WRITABLE: @@ -1606,7 +1701,7 @@ cdef class CBuffer: cdef ArrowType _data_type cdef int _element_size_bits cdef char _format[32] - cdef CDevice _device + cdef Device _device cdef CBufferView _view cdef int _get_buffer_count @@ -1615,7 +1710,7 @@ cdef class CBuffer: self._ptr = NULL self._data_type = NANOARROW_TYPE_BINARY self._element_size_bits = 0 - self._device = CDEVICE_CPU + self._device = DEVICE_CPU # Set initial format to "B" (Cython makes this hard) self._format[0] = 66 self._format[1] = 0 @@ -1652,7 +1747,7 @@ cdef class CBuffer: cdef CBuffer out = CBuffer() out._base = alloc_c_buffer(&out._ptr) out._set_format(c_buffer_set_pybuffer(obj, &out._ptr)) - out._device = CDEVICE_CPU + out._device = DEVICE_CPU out._populate_view() return out @@ -2330,8 +2425,16 @@ cdef class CDeviceArray: self._ptr = addr self._schema = schema + @property + def schema(self): + return self._schema + @property def device_type(self): + return DeviceType(self._ptr.device_type) + + @property + def device_type_id(self): return self._ptr.device_type @property @@ -2340,7 +2443,53 @@ cdef class CDeviceArray: @property def array(self): - return CArray(self, &self._ptr.array, self._schema) + # TODO: We lose access to the sync_event here, so we probably need to + # synchronize (or propagate it, or somehow prevent data access downstream) + cdef CArray array = CArray(self, &self._ptr.array, self._schema) + array._set_device(self._ptr.device_type, self._ptr.device_id) + return array + + def view(self): + return self.array.view() + + def __arrow_c_array__(self, requested_schema=None): + return self.array.__arrow_c_array__(requested_schema=requested_schema) + + def __arrow_c_device_array__(self, requested_schema=None): + if requested_schema is not None: + raise NotImplementedError("requested_schema") + + # TODO: evaluate whether we need to synchronize here or whether we should + # move device arrays instead of shallow-copying them + device_array_capsule = alloc_c_device_array_shallow_copy(self._base, self._ptr) + return self._schema.__arrow_c_schema__(), device_array_capsule + + @staticmethod + def _import_from_c_capsule(schema_capsule, device_array_capsule): + """ + Import from an ArrowSchema and ArrowArray PyCapsule tuple. + + Parameters + ---------- + schema_capsule : PyCapsule + A valid PyCapsule with name 'arrow_schema' containing an + ArrowSchema pointer. + device_array_capsule : PyCapsule + A valid PyCapsule with name 'arrow_device_array' containing an + ArrowDeviceArray pointer. + """ + cdef: + CSchema out_schema + CDeviceArray out + + out_schema = CSchema._import_from_c_capsule(schema_capsule) + out = CDeviceArray( + device_array_capsule, + PyCapsule_GetPointer(device_array_capsule, 'arrow_device_array'), + out_schema + ) + + return out def __repr__(self): return _repr_utils.device_array_repr(self) diff --git a/python/src/nanoarrow/_repr_utils.py b/python/src/nanoarrow/_repr_utils.py index 99b11fde0..3209a3413 100644 --- a/python/src/nanoarrow/_repr_utils.py +++ b/python/src/nanoarrow/_repr_utils.py @@ -169,7 +169,7 @@ def buffer_view_repr(buffer_view, max_char_width=80): prefix = f"{buffer_view.data_type}" prefix += f"[{buffer_view.size_bytes} b]" - if buffer_view.device.device_type == 1: + if buffer_view.device.device_type_id == 1: return ( prefix + " " @@ -232,7 +232,10 @@ def device_array_repr(device_array): class_label = make_class_label(device_array, module="nanoarrow.device") title_line = f"<{class_label}>" - device_type = f"- device_type: {device_array.device_type}" + device_type = ( + f"- device_type: {device_array.device_type.name} " + f"<{device_array.device_type_id}>" + ) device_id = f"- device_id: {device_array.device_id}" array = f"- array: {array_repr(device_array.array, indent=2)}" return "\n".join((title_line, device_type, device_id, array)) @@ -242,6 +245,6 @@ def device_repr(device): class_label = make_class_label(device, module="nanoarrow.device") title_line = f"<{class_label}>" - device_type = f"- device_type: {device.device_type}" + device_type = f"- device_type: {device.device_type.name} <{device.device_type_id}>" device_id = f"- device_id: {device.device_id}" return "\n".join([title_line, device_type, device_id]) diff --git a/python/src/nanoarrow/array.py b/python/src/nanoarrow/array.py index 78756e150..af2e3cd47 100644 --- a/python/src/nanoarrow/array.py +++ b/python/src/nanoarrow/array.py @@ -18,13 +18,7 @@ from functools import cached_property from typing import Iterable, Tuple -from nanoarrow._lib import ( - CDEVICE_CPU, - CArray, - CBuffer, - CDevice, - CMaterializedArrayStream, -) +from nanoarrow._lib import DEVICE_CPU, CArray, CBuffer, CMaterializedArrayStream, Device from nanoarrow.c_lib import c_array, c_array_stream, c_array_view from nanoarrow.iterator import iter_py, iter_tuples from nanoarrow.schema import Schema @@ -65,7 +59,7 @@ def __init__(self): self._device = None @property - def device(self) -> CDevice: + def device(self) -> Device: return self._device @property @@ -121,7 +115,7 @@ class Array: :func:`c_array_stream`. schema : schema-like, optional An optional schema, passed to :func:`c_array_stream`. - device : CDevice, optional + device : Device, optional The device associated with the buffers held by this Array. Defaults to the CPU device. @@ -138,11 +132,11 @@ class Array: def __init__(self, obj, schema=None, device=None) -> None: if device is None: - self._device = CDEVICE_CPU - elif isinstance(device, CDevice): + self._device = DEVICE_CPU + elif isinstance(device, Device): self._device = device else: - raise TypeError("device must be CDevice") + raise TypeError("device must be Device") if isinstance(obj, CMaterializedArrayStream) and schema is None: self._data = obj @@ -164,7 +158,7 @@ def _assert_one_chunk(self, op): raise ValueError(f"Can't {op} with non-contiguous Array") def _assert_cpu(self, op): - if self._device != CDEVICE_CPU: + if self._device != DEVICE_CPU: raise ValueError(f"Can't {op} with Array on non-CPU device") def __arrow_c_stream__(self, requested_schema=None): @@ -186,7 +180,7 @@ def __arrow_c_array__(self, requested_schema=None): self._assert_one_chunk("export ArrowArray") @property - def device(self) -> CDevice: + def device(self) -> Device: """Get the device on which the buffers for this array are allocated. Examples @@ -195,9 +189,9 @@ def device(self) -> CDevice: >>> import nanoarrow as na >>> array = na.Array([1, 2, 3], na.int32()) >>> array.device - - - device_type: 1 - - device_id: 0 + + - device_type: CPU <1> + - device_id: -1 """ return self._device diff --git a/python/src/nanoarrow/c_lib.py b/python/src/nanoarrow/c_lib.py index 68a53b782..0acc0a9d0 100644 --- a/python/src/nanoarrow/c_lib.py +++ b/python/src/nanoarrow/c_lib.py @@ -427,7 +427,7 @@ def c_array_view(obj, schema=None) -> CArrayView: if isinstance(obj, CArrayView) and schema is None: return obj - return CArrayView.from_array(c_array(obj, schema)) + return c_array(obj, schema).view() def c_buffer(obj, schema=None) -> CBuffer: diff --git a/python/src/nanoarrow/device.py b/python/src/nanoarrow/device.py index 2bc5d408c..7bf0dcea0 100644 --- a/python/src/nanoarrow/device.py +++ b/python/src/nanoarrow/device.py @@ -15,23 +15,32 @@ # specific language governing permissions and limitations # under the License. -from nanoarrow._lib import CDEVICE_CPU, CDevice, CDeviceArray -from nanoarrow.c_lib import c_array +from nanoarrow._lib import DEVICE_CPU, CDeviceArray, Device, DeviceType # noqa: F401 +from nanoarrow.c_lib import c_array, c_schema def cpu(): - return CDEVICE_CPU + return DEVICE_CPU def resolve(device_type, device_id): - return CDevice.resolve(device_type, device_id) + return Device.resolve(device_type, device_id) -def c_device_array(obj): - if isinstance(obj, CDeviceArray): +def c_device_array(obj, schema=None): + if schema is not None: + schema = c_schema(schema) + + if isinstance(obj, CDeviceArray) and schema is None: return obj - # Only CPU for now - cpu_array = c_array(obj) + if hasattr(obj, "__arrow_c_device_array__"): + schema_capsule = None if schema is None else schema.__arrow_c_schema__() + schema_capsule, device_array_capsule = obj.__arrow_c_device_array__( + requested_schema=schema_capsule + ) + return CDeviceArray._import_from_c_capsule(schema_capsule, device_array_capsule) + # Attempt to create a CPU array and wrap it + cpu_array = c_array(obj, schema=schema) return cpu()._array_init(cpu_array._addr(), cpu_array.schema) diff --git a/python/src/nanoarrow/nanoarrow_device_c.pxd b/python/src/nanoarrow/nanoarrow_device_c.pxd index f2a65a905..5c8a12ef8 100644 --- a/python/src/nanoarrow/nanoarrow_device_c.pxd +++ b/python/src/nanoarrow/nanoarrow_device_c.pxd @@ -26,7 +26,17 @@ cdef extern from "nanoarrow_device.h" nogil: int32_t ARROW_DEVICE_CPU int32_t ARROW_DEVICE_CUDA int32_t ARROW_DEVICE_CUDA_HOST + int32_t ARROW_DEVICE_OPENCL + int32_t ARROW_DEVICE_VULKAN int32_t ARROW_DEVICE_METAL + int32_t ARROW_DEVICE_VPI + int32_t ARROW_DEVICE_ROCM + int32_t ARROW_DEVICE_ROCM_HOST + int32_t ARROW_DEVICE_EXT_DEV + int32_t ARROW_DEVICE_CUDA_MANAGED + int32_t ARROW_DEVICE_ONEAPI + int32_t ARROW_DEVICE_WEBGPU + int32_t ARROW_DEVICE_HEXAGON struct ArrowDeviceArray: ArrowArray array diff --git a/python/tests/test_array.py b/python/tests/test_array.py index fe590e607..ee88d20d0 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -31,7 +31,7 @@ def test_array_construct(): array2 = na.Array(array._data) assert array2._data is array._data - with pytest.raises(TypeError, match="device must be CDevice"): + with pytest.raises(TypeError, match="device must be Device"): na.Array([], na.int32(), device=1234) with pytest.raises(NotImplementedError): diff --git a/python/tests/test_c_array.py b/python/tests/test_c_array.py index 75ab2aa7b..1536d159e 100644 --- a/python/tests/test_c_array.py +++ b/python/tests/test_c_array.py @@ -39,6 +39,8 @@ def test_c_array_from_c_array(): assert c_array_from_c_array.length == c_array.length assert c_array_from_c_array.buffers == c_array.buffers + assert list(c_array.view().buffer(1)) == [1, 2, 3] + def test_c_array_from_capsule_protocol(): class CArrayWrapper: @@ -54,6 +56,8 @@ def __arrow_c_array__(self, *args, **kwargs): assert c_array_from_protocol.length == c_array.length assert c_array_from_protocol.buffers == c_array.buffers + assert list(c_array_from_protocol.view().buffer(1)) == [1, 2, 3] + def test_c_array_from_old_pyarrow(): # Simulate a pyarrow Array with no __arrow_c_array__ @@ -73,6 +77,8 @@ def _export_to_c(self, *args): assert c_array.length == 3 assert c_array.schema.format == "i" + assert list(c_array.view().buffer(1)) == [1, 2, 3] + # Make sure that this heuristic won't result in trying to import # something else that has an _export_to_c method with pytest.raises(TypeError, match="Can't convert object of type DataType"): @@ -97,6 +103,8 @@ def test_c_array_from_bare_capsule(): assert c_array_from_capsule.length == c_array.length assert c_array_from_capsule.buffers == c_array.buffers + assert list(c_array_from_capsule.view().buffer(1)) == [1, 2, 3] + def test_c_array_type_not_supported(): with pytest.raises(TypeError, match="Can't convert object of type NoneType"): diff --git a/python/tests/test_device.py b/python/tests/test_device.py index 93028816e..1158337a2 100644 --- a/python/tests/test_device.py +++ b/python/tests/test_device.py @@ -17,26 +17,66 @@ import pytest +import nanoarrow as na from nanoarrow import device -pa = pytest.importorskip("pyarrow") - def test_cpu_device(): cpu = device.cpu() - assert cpu.device_type == 1 - assert cpu.device_id == 0 - assert "device_type: 1" in repr(cpu) + assert cpu.device_type_id == 1 + assert cpu.device_type == device.DeviceType.CPU + assert cpu.device_id == -1 + assert "device_type: CPU <1>" in repr(cpu) + + cpu2 = device.resolve(1, 0) + assert cpu2 is cpu + - cpu = device.resolve(1, 0) - assert cpu.device_type == 1 +def test_c_device_array(): + # Unrecognized arguments should be passed to c_array() to generate CPU array + darray = device.c_device_array([1, 2, 3], na.int32()) - pa_array = pa.array([1, 2, 3]) + assert darray.device_type_id == 1 + assert darray.device_type == device.DeviceType.CPU + assert darray.device_id == -1 + assert "device_type: CPU <1>" in repr(darray) + + assert darray.schema.format == "i" - darray = device.c_device_array(pa_array) - assert darray.device_type == 1 - assert darray.device_id == 0 assert darray.array.length == 3 - assert "device_type: 1" in repr(darray) + assert darray.array.device_type == device.cpu().device_type + assert darray.array.device_id == device.cpu().device_id + + darray_view = darray.view() + assert darray_view.length == 3 + assert list(darray_view.buffer(1)) == [1, 2, 3] + # A CDeviceArray should be returned as is assert device.c_device_array(darray) is darray + + # A CPU device array should be able to export to a regular array + array = na.c_array(darray) + assert array.schema.format == "i" + assert array.buffers == darray.array.buffers + + +def test_c_device_array_protocol(): + # Wrapper to prevent c_device_array() from returning early when it detects the + # input is already a CDeviceArray + class CDeviceArrayWrapper: + def __init__(self, obj): + self.obj = obj + + def __arrow_c_device_array__(self, requested_schema=None): + return self.obj.__arrow_c_device_array__(requested_schema=requested_schema) + + darray = device.c_device_array([1, 2, 3], na.int32()) + wrapper = CDeviceArrayWrapper(darray) + + darray2 = device.c_device_array(wrapper) + assert darray2.schema.format == "i" + assert darray2.array.length == 3 + assert darray2.array.buffers == darray.array.buffers + + with pytest.raises(NotImplementedError): + device.c_device_array(wrapper, na.int64())