Skip to content

Commit 71b01c1

Browse files
leofangrgommers
andauthored
Support copy and device keywords in from_dlpack (#741)
* support copy in from_dlpack * specify copy stream * allow 3-way copy arg to align all constructors * update to reflect the discussions * clairfy a bit and fix typos * sync the copy docs * clarify what's 'on CPU' * try to make linter happy * remove namespace leak clause, clean up, and add an example * make linter happy * fix Sphinx complaint about Enum * add/update v2023-specific notes on device * remove a note on kDLCPU --------- Co-authored-by: Ralf Gommers <[email protected]>
1 parent 474ec2b commit 71b01c1

File tree

2 files changed

+89
-8
lines changed

2 files changed

+89
-8
lines changed

src/array_api_stubs/_draft/array_object.py

+46-5
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ def __dlpack__(
293293
*,
294294
stream: Optional[Union[int, Any]] = None,
295295
max_version: Optional[tuple[int, int]] = None,
296+
dl_device: Optional[tuple[Enum, int]] = None,
297+
copy: Optional[bool] = None,
296298
) -> PyCapsule:
297299
"""
298300
Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
@@ -324,6 +326,12 @@ def __dlpack__(
324326
- ``> 2``: stream number represented as a Python integer.
325327
- Using ``1`` and ``2`` is not supported.
326328
329+
.. note::
330+
When ``dl_device`` is provided explicitly, ``stream`` must be a valid
331+
construct for the specified device type. In particular, when ``kDLCPU``
332+
is in use, ``stream`` must be ``None`` and a synchronization must be
333+
performed to ensure data safety.
334+
327335
.. admonition:: Tip
328336
:class: important
329337
@@ -333,12 +341,40 @@ def __dlpack__(
333341
not want to think about stream handling at all, potentially at the
334342
cost of more synchronizations than necessary.
335343
max_version: Optional[tuple[int, int]]
336-
The maximum DLPack version that the *consumer* (i.e., the caller of
344+
the maximum DLPack version that the *consumer* (i.e., the caller of
337345
``__dlpack__``) supports, in the form of a 2-tuple ``(major, minor)``.
338346
This method may return a capsule of version ``max_version`` (recommended
339347
if it does support that), or of a different version.
340348
This means the consumer must verify the version even when
341349
`max_version` is passed.
350+
dl_device: Optional[tuple[enum.Enum, int]]
351+
the DLPack device type. Default is ``None``, meaning the exported capsule
352+
should be on the same device as ``self`` is. When specified, the format
353+
must be a 2-tuple, following that of the return value of :meth:`array.__dlpack_device__`.
354+
If the device type cannot be handled by the producer, this function must
355+
raise ``BufferError``.
356+
357+
The v2023.12 standard only mandates that a compliant library should offer a way for
358+
``__dlpack__`` to return a capsule referencing an array whose underlying memory is
359+
accessible to the Python interpreter (represented by the ``kDLCPU`` enumerator in DLPack).
360+
If a copy must be made to enable this support but ``copy`` is set to ``False``, the
361+
function must raise ``ValueError``.
362+
363+
Other device kinds will be considered for standardization in a future version of this
364+
API standard.
365+
copy: Optional[bool]
366+
boolean indicating whether or not to copy the input. If ``True``, the
367+
function must always copy (performed by the producer). If ``False``, the
368+
function must never copy, and raise a ``BufferError`` in case a copy is
369+
deemed necessary (e.g. if a cross-device data movement is requested, and
370+
it is not possible without a copy). If ``None``, the function must reuse
371+
the existing memory buffer if possible and copy otherwise. Default: ``None``.
372+
373+
When a copy happens, the ``DLPACK_FLAG_BITMASK_IS_COPIED`` flag must be set.
374+
375+
.. note::
376+
If a copy happens, and if the consumer-provided ``stream`` and ``dl_device``
377+
can be understood by the producer, the copy must be performed over ``stream``.
342378
343379
Returns
344380
-------
@@ -394,22 +430,25 @@ def __dlpack__(
394430
# here to tell users that the consumer's max_version is too
395431
# old to allow the data exchange to happen.
396432
397-
And this logic for the consumer in ``from_dlpack``:
433+
And this logic for the consumer in :func:`~array_api.from_dlpack`:
398434
399435
.. code:: python
400436
401437
try:
402-
x.__dlpack__(max_version=(1, 0))
438+
x.__dlpack__(max_version=(1, 0), ...)
403439
# if it succeeds, store info from the capsule named "dltensor_versioned",
404440
# and need to set the name to "used_dltensor_versioned" when we're done
405441
except TypeError:
406-
x.__dlpack__()
442+
x.__dlpack__(...)
443+
444+
This logic is also applicable to handling of the new ``dl_device`` and ``copy``
445+
keywords.
407446
408447
.. versionchanged:: 2022.12
409448
Added BufferError.
410449
411450
.. versionchanged:: 2023.12
412-
Added the ``max_version`` keyword.
451+
Added the ``max_version``, ``dl_device``, and ``copy`` keywords.
413452
"""
414453

415454
def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
@@ -436,6 +475,8 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
436475
METAL = 8
437476
VPI = 9
438477
ROCM = 10
478+
CUDA_MANAGED = 13
479+
ONE_API = 14
439480
"""
440481

441482
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:

src/array_api_stubs/_draft/creation_functions.py

+43-3
Original file line numberDiff line numberDiff line change
@@ -214,19 +214,36 @@ def eye(
214214
"""
215215

216216

217-
def from_dlpack(x: object, /) -> array:
217+
def from_dlpack(
218+
x: object,
219+
/,
220+
*,
221+
device: Optional[device] = None,
222+
copy: Optional[bool] = None,
223+
) -> array:
218224
"""
219225
Returns a new array containing the data from another (array) object with a ``__dlpack__`` method.
220226
221227
Parameters
222228
----------
223229
x: object
224230
input (array) object.
231+
device: Optional[device]
232+
device on which to place the created array. If ``device`` is ``None`` and ``x`` supports DLPack, the output array must be on the same device as ``x``. Default: ``None``.
233+
234+
The v2023.12 standard only mandates that a compliant library should offer a way for ``from_dlpack`` to return an array
235+
whose underlying memory is accessible to the Python interpreter, when the corresponding ``device`` is provided. If the
236+
array library does not support such cases at all, the function must raise ``BufferError``. If a copy must be made to
237+
enable this support but ``copy`` is set to ``False``, the function must raise ``ValueError``.
238+
239+
Other device kinds will be considered for standardization in a future version of this API standard.
240+
copy: Optional[bool]
241+
boolean indicating whether or not to copy the input. If ``True``, the function must always copy. If ``False``, the function must never copy, and raise ``BufferError`` in case a copy is deemed necessary (e.g. if a cross-device data movement is requested, and it is not possible without a copy). If ``None``, the function must reuse the existing memory buffer if possible and copy otherwise. Default: ``None``.
225242
226243
Returns
227244
-------
228245
out: array
229-
an array containing the data in `x`.
246+
an array containing the data in ``x``.
230247
231248
.. admonition:: Note
232249
:class: note
@@ -238,19 +255,42 @@ def from_dlpack(x: object, /) -> array:
238255
BufferError
239256
The ``__dlpack__`` and ``__dlpack_device__`` methods on the input array
240257
may raise ``BufferError`` when the data cannot be exported as DLPack
241-
(e.g., incompatible dtype or strides). It may also raise other errors
258+
(e.g., incompatible dtype, strides, or device). It may also raise other errors
242259
when export fails for other reasons (e.g., not enough memory available
243260
to materialize the data). ``from_dlpack`` must propagate such
244261
exceptions.
245262
AttributeError
246263
If the ``__dlpack__`` and ``__dlpack_device__`` methods are not present
247264
on the input array. This may happen for libraries that are never able
248265
to export their data with DLPack.
266+
ValueError
267+
If data exchange is possible via an explicit copy but ``copy`` is set to ``False``.
249268
250269
Notes
251270
-----
252271
See :meth:`array.__dlpack__` for implementation suggestions for `from_dlpack` in
253272
order to handle DLPack versioning correctly.
273+
274+
A way to move data from two array libraries to the same device (assumed supported by both libraries) in
275+
a library-agnostic fashion is illustrated below:
276+
277+
.. code:: python
278+
279+
def func(x, y):
280+
xp_x = x.__array_namespace__()
281+
xp_y = y.__array_namespace__()
282+
283+
# Other functions than `from_dlpack` only work if both arrays are from the same library. So if
284+
# `y` is from a different one than `x`, let's convert `y` into an array of the same type as `x`:
285+
if not xp_x == xp_y:
286+
y = xp_x.from_dlpack(y, copy=True, device=x.device)
287+
288+
# From now on use `xp_x.xxxxx` functions, as both arrays are from the library `xp_x`
289+
...
290+
291+
292+
.. versionchanged:: 2023.12
293+
Added device and copy support.
254294
"""
255295

256296

0 commit comments

Comments
 (0)