Skip to content

Commit a24aa23

Browse files
authored
refactor(python): Split device functionality into its own module (#548)
Similar to previous PRs, this PR moves the device class definition into its own module. Both arrays and buffers are device-aware, so splitting the device definition is a prerequisite.
1 parent 395aed4 commit a24aa23

File tree

6 files changed

+177
-98
lines changed

6 files changed

+177
-98
lines changed

python/setup.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,21 @@ def get_version(pkg_path):
9595

9696
setup(
9797
ext_modules=[
98+
Extension(
99+
name="nanoarrow._device",
100+
include_dirs=["src/nanoarrow", "vendor"],
101+
language="c",
102+
sources=[
103+
"src/nanoarrow/_device.pyx",
104+
"vendor/nanoarrow.c",
105+
"vendor/nanoarrow_device.c",
106+
],
107+
extra_compile_args=extra_compile_args,
108+
extra_link_args=extra_link_args,
109+
define_macros=extra_define_macros,
110+
library_dirs=library_dirs,
111+
libraries=libraries,
112+
),
98113
Extension(
99114
name="nanoarrow._types",
100115
include_dirs=["src/nanoarrow", "vendor"],

python/src/nanoarrow/_device.pxd

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# cython: language_level = 3
19+
20+
from nanoarrow_device_c cimport ArrowDevice
21+
22+
cdef class Device:
23+
cdef object _base
24+
cdef ArrowDevice* _ptr

python/src/nanoarrow/_device.pyx

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# cython: language_level = 3
19+
20+
from libc.stdint cimport uintptr_t, int64_t
21+
22+
from nanoarrow_device_c cimport (
23+
ARROW_DEVICE_CPU,
24+
ARROW_DEVICE_CUDA,
25+
ARROW_DEVICE_CUDA_HOST,
26+
ARROW_DEVICE_OPENCL,
27+
ARROW_DEVICE_VULKAN,
28+
ARROW_DEVICE_METAL,
29+
ARROW_DEVICE_VPI,
30+
ARROW_DEVICE_ROCM,
31+
ARROW_DEVICE_ROCM_HOST,
32+
ARROW_DEVICE_EXT_DEV,
33+
ARROW_DEVICE_CUDA_MANAGED,
34+
ARROW_DEVICE_ONEAPI,
35+
ARROW_DEVICE_WEBGPU,
36+
ARROW_DEVICE_HEXAGON,
37+
ArrowDevice,
38+
ArrowDeviceCpu,
39+
ArrowDeviceResolve
40+
)
41+
42+
from enum import Enum
43+
44+
from nanoarrow import _repr_utils
45+
46+
47+
class DeviceType(Enum):
48+
"""
49+
An enumerator providing access to the device constant values
50+
defined in the Arrow C Device interface. Unlike the other enum
51+
accessors, this Python Enum is defined in Cython so that we can use
52+
the built-in functionality to do better printing of device identifiers
53+
for classes defined in Cython. Unlike the other enums, users don't
54+
typically need to specify these (but would probably like them printed
55+
nicely).
56+
"""
57+
58+
CPU = ARROW_DEVICE_CPU
59+
CUDA = ARROW_DEVICE_CUDA
60+
CUDA_HOST = ARROW_DEVICE_CUDA_HOST
61+
OPENCL = ARROW_DEVICE_OPENCL
62+
VULKAN = ARROW_DEVICE_VULKAN
63+
METAL = ARROW_DEVICE_METAL
64+
VPI = ARROW_DEVICE_VPI
65+
ROCM = ARROW_DEVICE_ROCM
66+
ROCM_HOST = ARROW_DEVICE_ROCM_HOST
67+
EXT_DEV = ARROW_DEVICE_EXT_DEV
68+
CUDA_MANAGED = ARROW_DEVICE_CUDA_MANAGED
69+
ONEAPI = ARROW_DEVICE_ONEAPI
70+
WEBGPU = ARROW_DEVICE_WEBGPU
71+
HEXAGON = ARROW_DEVICE_HEXAGON
72+
73+
74+
cdef class Device:
75+
"""ArrowDevice wrapper
76+
77+
The ArrowDevice structure is a nanoarrow internal struct (i.e.,
78+
not ABI stable) that contains callbacks for device operations
79+
beyond its type and identifier (e.g., copy buffers to or from
80+
a device).
81+
"""
82+
83+
def __cinit__(self, object base, uintptr_t addr):
84+
self._base = base,
85+
self._ptr = <ArrowDevice*>addr
86+
87+
def __repr__(self):
88+
return _repr_utils.device_repr(self)
89+
90+
@property
91+
def device_type(self):
92+
return DeviceType(self._ptr.device_type)
93+
94+
@property
95+
def device_type_id(self):
96+
return self._ptr.device_type
97+
98+
@property
99+
def device_id(self):
100+
return self._ptr.device_id
101+
102+
@staticmethod
103+
def resolve(device_type, int64_t device_id):
104+
if int(device_type) == ARROW_DEVICE_CPU:
105+
return DEVICE_CPU
106+
107+
cdef ArrowDevice* c_device = ArrowDeviceResolve(device_type, device_id)
108+
if c_device == NULL:
109+
raise ValueError(f"Device not found for type {device_type}/{device_id}")
110+
111+
return Device(None, <uintptr_t>c_device)
112+
113+
114+
# Cache the CPU device
115+
# The CPU device is statically allocated (so base is None)
116+
DEVICE_CPU = Device(None, <uintptr_t>ArrowDeviceCpu())

python/src/nanoarrow/_lib.pyx

Lines changed: 17 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ from nanoarrow_c cimport *
5353
from nanoarrow_device_c cimport *
5454
from nanoarrow_dlpack cimport *
5555

56-
from enum import Enum
57-
from struct import unpack_from, iter_unpack, calcsize, Struct
58-
from nanoarrow import _repr_utils
56+
from nanoarrow._device cimport Device
5957

6058
from nanoarrow cimport _types
6159
from nanoarrow._utils cimport (
@@ -71,6 +69,11 @@ from nanoarrow._utils cimport (
7169
Error
7270
)
7371

72+
from struct import unpack_from, iter_unpack, calcsize, Struct
73+
from nanoarrow import _repr_utils
74+
from nanoarrow._device import DEVICE_CPU, DeviceType
75+
76+
7477
cdef void pycapsule_dlpack_deleter(object dltensor) noexcept:
7578
cdef DLManagedTensor* dlm_tensor
7679

@@ -176,91 +179,6 @@ cdef class CArrowTimeUnit:
176179
NANO = NANOARROW_TIME_UNIT_NANO
177180

178181

179-
class DeviceType(Enum):
180-
"""
181-
An enumerator providing access to the device constant values
182-
defined in the Arrow C Device interface. Unlike the other enum
183-
accessors, this Python Enum is defined in Cython so that we can use
184-
the built-in functionality to do better printing of device identifiers
185-
for classes defined in Cython. Unlike the other enums, users don't
186-
typically need to specify these (but would probably like them printed
187-
nicely).
188-
"""
189-
190-
CPU = ARROW_DEVICE_CPU
191-
CUDA = ARROW_DEVICE_CUDA
192-
CUDA_HOST = ARROW_DEVICE_CUDA_HOST
193-
OPENCL = ARROW_DEVICE_OPENCL
194-
VULKAN = ARROW_DEVICE_VULKAN
195-
METAL = ARROW_DEVICE_METAL
196-
VPI = ARROW_DEVICE_VPI
197-
ROCM = ARROW_DEVICE_ROCM
198-
ROCM_HOST = ARROW_DEVICE_ROCM_HOST
199-
EXT_DEV = ARROW_DEVICE_EXT_DEV
200-
CUDA_MANAGED = ARROW_DEVICE_CUDA_MANAGED
201-
ONEAPI = ARROW_DEVICE_ONEAPI
202-
WEBGPU = ARROW_DEVICE_WEBGPU
203-
HEXAGON = ARROW_DEVICE_HEXAGON
204-
205-
206-
cdef class Device:
207-
"""ArrowDevice wrapper
208-
209-
The ArrowDevice structure is a nanoarrow internal struct (i.e.,
210-
not ABI stable) that contains callbacks for device operations
211-
beyond its type and identifier (e.g., copy buffers to or from
212-
a device).
213-
"""
214-
215-
cdef object _base
216-
cdef ArrowDevice* _ptr
217-
218-
def __cinit__(self, object base, uintptr_t addr):
219-
self._base = base,
220-
self._ptr = <ArrowDevice*>addr
221-
222-
def _array_init(self, uintptr_t array_addr, CSchema schema):
223-
cdef ArrowArray* array_ptr = <ArrowArray*>array_addr
224-
cdef ArrowDeviceArray* device_array_ptr
225-
cdef void* sync_event = NULL
226-
holder = alloc_c_device_array(&device_array_ptr)
227-
cdef int code = ArrowDeviceArrayInit(self._ptr, device_array_ptr, array_ptr, sync_event)
228-
Error.raise_error_not_ok("ArrowDevice::init_array", code)
229-
230-
return CDeviceArray(holder, <uintptr_t>device_array_ptr, schema)
231-
232-
def __repr__(self):
233-
return _repr_utils.device_repr(self)
234-
235-
@property
236-
def device_type(self):
237-
return DeviceType(self._ptr.device_type)
238-
239-
@property
240-
def device_type_id(self):
241-
return self._ptr.device_type
242-
243-
@property
244-
def device_id(self):
245-
return self._ptr.device_id
246-
247-
@staticmethod
248-
def resolve(device_type, int64_t device_id):
249-
if int(device_type) == ARROW_DEVICE_CPU:
250-
return DEVICE_CPU
251-
252-
cdef ArrowDevice* c_device = ArrowDeviceResolve(device_type, device_id)
253-
if c_device == NULL:
254-
raise ValueError(f"Device not found for type {device_type}/{device_id}")
255-
256-
return Device(None, <uintptr_t>c_device)
257-
258-
259-
# Cache the CPU device
260-
# The CPU device is statically allocated (so base is None)
261-
DEVICE_CPU = Device(None, <uintptr_t>ArrowDeviceCpu())
262-
263-
264182
cdef class CSchema:
265183
"""Low-level ArrowSchema wrapper
266184
@@ -2808,6 +2726,17 @@ cdef class CDeviceArray:
28082726
self._ptr = <ArrowDeviceArray*>addr
28092727
self._schema = schema
28102728

2729+
@staticmethod
2730+
def _init_from_array(Device device, uintptr_t array_addr, CSchema schema):
2731+
cdef ArrowArray* array_ptr = <ArrowArray*>array_addr
2732+
cdef ArrowDeviceArray* device_array_ptr
2733+
cdef void* sync_event = NULL
2734+
holder = alloc_c_device_array(&device_array_ptr)
2735+
cdef int code = ArrowDeviceArrayInit(device._ptr, device_array_ptr, array_ptr, sync_event)
2736+
Error.raise_error_not_ok("ArrowDeviceArrayInit", code)
2737+
2738+
return CDeviceArray(holder, <uintptr_t>device_array_ptr, schema)
2739+
28112740
@property
28122741
def schema(self):
28132742
return self._schema

python/src/nanoarrow/array.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,8 @@
1919
from functools import cached_property
2020
from typing import Iterable, Tuple
2121

22-
from nanoarrow._lib import (
23-
DEVICE_CPU,
24-
CArray,
25-
CArrayView,
26-
CBuffer,
27-
CMaterializedArrayStream,
28-
Device,
29-
)
22+
from nanoarrow._device import DEVICE_CPU, Device
23+
from nanoarrow._lib import CArray, CArrayView, CBuffer, CMaterializedArrayStream
3024
from nanoarrow.c_array import c_array, c_array_view
3125
from nanoarrow.c_array_stream import c_array_stream
3226
from nanoarrow.c_schema import c_schema

python/src/nanoarrow/device.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from nanoarrow._lib import DEVICE_CPU, CDeviceArray, Device, DeviceType # noqa: F401
18+
from nanoarrow._device import DEVICE_CPU, Device, DeviceType # noqa: F401
19+
from nanoarrow._lib import CDeviceArray
1920
from nanoarrow.c_array import c_array
2021
from nanoarrow.c_schema import c_schema
2122

@@ -44,4 +45,4 @@ def c_device_array(obj, schema=None):
4445

4546
# Attempt to create a CPU array and wrap it
4647
cpu_array = c_array(obj, schema=schema)
47-
return cpu()._array_init(cpu_array._addr(), cpu_array.schema)
48+
return CDeviceArray._init_from_array(cpu(), cpu_array._addr(), cpu_array.schema)

0 commit comments

Comments
 (0)