-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy path_helpers.py
955 lines (770 loc) · 27.9 KB
/
_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
"""
Various helper functions which are not part of the spec.
Functions which start with an underscore are for internal use only but helpers
that are in __all__ are intended as additional helper functions for use by end
users of the compat library.
"""
from __future__ import annotations
import sys
import math
import inspect
import warnings
from typing import Optional, Union, Any
from ._typing import Array, Device, Namespace
def _is_jax_zero_gradient_array(x: object) -> bool:
"""Return True if `x` is a zero-gradient array.
These arrays are a design quirk of Jax that may one day be removed.
See https://github.com/google/jax/issues/20620.
"""
if 'numpy' not in sys.modules or 'jax' not in sys.modules:
return False
import numpy as np
import jax
return isinstance(x, np.ndarray) and x.dtype == jax.float0
def is_numpy_array(x: object) -> bool:
    """
    Return True if `x` is a NumPy array.

    NumPy is never imported by this check: if nothing has imported it yet,
    the answer is simply False, which keeps the call cheap.

    ``ndarray`` subclasses and NumPy scalar objects (``np.generic``) also
    count. JAX zero-gradient arrays (``np.ndarray`` with dtype
    ``jax.float0``) are excluded; `is_jax_array` reports those instead.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_cupy_array
    is_torch_array
    is_ndonnx_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    if 'numpy' in sys.modules:
        import numpy as np

        # TODO: Should we reject ndarray subclasses?
        numpy_types = (np.ndarray, np.generic)
        return isinstance(x, numpy_types) and not _is_jax_zero_gradient_array(x)
    # NumPy was never imported, so x cannot be one of its arrays.
    return False
def is_cupy_array(x: object) -> bool:
    """
    Return True if `x` is a CuPy array.

    CuPy is only consulted when it has already been imported elsewhere, so
    this check never pays the cost of importing it.

    Subclasses of ``cupy.ndarray`` are also reported as CuPy arrays.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_torch_array
    is_ndonnx_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    if 'cupy' in sys.modules:
        import cupy as cp

        # TODO: Should we reject ndarray subclasses?
        return isinstance(x, cp.ndarray)
    return False
def is_torch_array(x: object) -> bool:
    """
    Return True if `x` is a PyTorch tensor.

    PyTorch is only consulted when it has already been imported elsewhere,
    so this check never pays the cost of importing it.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_ndonnx_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    if 'torch' in sys.modules:
        import torch

        # TODO: Should we reject ndarray subclasses?
        return isinstance(x, torch.Tensor)
    return False
def is_ndonnx_array(x: object) -> bool:
    """
    Return True if `x` is a ndonnx Array.

    This function does not import ndonnx if it has not already been imported
    and is therefore cheap to use.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    # Avoid importing ndonnx if it isn't already
    if 'ndonnx' not in sys.modules:
        return False

    import ndonnx as ndx

    return isinstance(x, ndx.Array)
def is_dask_array(x: object) -> bool:
    """
    Return True if `x` is a dask.array Array.

    ``dask.array`` is only consulted when it has already been imported
    elsewhere, so this check never pays the cost of importing dask.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_ndonnx_array
    is_jax_array
    is_pydata_sparse_array
    """
    if 'dask.array' in sys.modules:
        import dask.array
        return isinstance(x, dask.array.Array)
    return False
def is_jax_array(x: object) -> bool:
    """
    Return True if `x` is a JAX array.

    JAX is only consulted when it has already been imported elsewhere, so
    this check never pays the cost of importing it. JAX zero-gradient arrays
    (NumPy arrays with dtype ``jax.float0``) are also reported as JAX arrays.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_ndonnx_array
    is_dask_array
    is_pydata_sparse_array
    """
    if 'jax' in sys.modules:
        import jax
        return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
    return False
def is_pydata_sparse_array(x: object) -> bool:
    """
    Return True if `x` is an array from the `sparse` package.

    This function does not import `sparse` if it has not already been imported
    and is therefore cheap to use.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_ndonnx_array
    is_dask_array
    is_jax_array
    """
    # Avoid importing sparse if it isn't already
    if 'sparse' not in sys.modules:
        return False

    import sparse

    # TODO: Account for other backends.
    return isinstance(x, sparse.SparseArray)
def is_array_api_obj(x: object) -> bool:
    """
    Return True if `x` is an array API compatible array object.

    An object qualifies if it belongs to one of the libraries this package
    knows about, or if it exposes an ``__array_namespace__`` method.

    See Also
    --------

    array_namespace
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_ndonnx_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    # Checked in this order, stopping at the first match.
    known_library_checks = (
        is_numpy_array,
        is_cupy_array,
        is_torch_array,
        is_dask_array,
        is_jax_array,
        is_pydata_sparse_array,
    )
    if any(check(x) for check in known_library_checks):
        return True
    return hasattr(x, '__array_namespace__')
def _compat_module_name() -> str:
assert __name__.endswith('.common._helpers')
return __name__.removesuffix('.common._helpers')
def is_numpy_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a NumPy namespace.

    Both the bare ``numpy`` module and the array-api-compat wrapper around it
    are recognised.

    See Also
    --------

    array_namespace
    is_cupy_namespace
    is_torch_namespace
    is_ndonnx_namespace
    is_dask_namespace
    is_jax_namespace
    is_pydata_sparse_namespace
    is_array_api_strict_namespace
    """
    name = xp.__name__
    return name in {'numpy', f"{_compat_module_name()}.numpy"}
def is_cupy_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a CuPy namespace.

    Both the bare ``cupy`` module and the array-api-compat wrapper around it
    are recognised.

    See Also
    --------

    array_namespace
    is_numpy_namespace
    is_torch_namespace
    is_ndonnx_namespace
    is_dask_namespace
    is_jax_namespace
    is_pydata_sparse_namespace
    is_array_api_strict_namespace
    """
    name = xp.__name__
    return name in {'cupy', f"{_compat_module_name()}.cupy"}
def is_torch_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a PyTorch namespace.

    Both the bare ``torch`` module and the array-api-compat wrapper around it
    are recognised.

    See Also
    --------

    array_namespace
    is_numpy_namespace
    is_cupy_namespace
    is_ndonnx_namespace
    is_dask_namespace
    is_jax_namespace
    is_pydata_sparse_namespace
    is_array_api_strict_namespace
    """
    name = xp.__name__
    return name in {'torch', f"{_compat_module_name()}.torch"}
def is_ndonnx_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is an NDONNX namespace.

    Only the ``ndonnx`` module itself is recognised; array-api-compat ships
    no wrapper for it.

    See Also
    --------

    array_namespace
    is_numpy_namespace
    is_cupy_namespace
    is_torch_namespace
    is_dask_namespace
    is_jax_namespace
    is_pydata_sparse_namespace
    is_array_api_strict_namespace
    """
    name = xp.__name__
    return name == 'ndonnx'
def is_dask_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a Dask namespace.

    Both ``dask.array`` itself and the array-api-compat wrapper around it are
    recognised.

    See Also
    --------

    array_namespace
    is_numpy_namespace
    is_cupy_namespace
    is_torch_namespace
    is_ndonnx_namespace
    is_jax_namespace
    is_pydata_sparse_namespace
    is_array_api_strict_namespace
    """
    name = xp.__name__
    return name in {'dask.array', f"{_compat_module_name()}.dask.array"}
def is_jax_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a JAX namespace.

    Recognises ``jax.numpy`` as well as ``jax.experimental.array_api``, the
    module through which older JAX releases exposed the array API.

    See Also
    --------

    array_namespace
    is_numpy_namespace
    is_cupy_namespace
    is_torch_namespace
    is_ndonnx_namespace
    is_dask_namespace
    is_pydata_sparse_namespace
    is_array_api_strict_namespace
    """
    jax_module_names = ('jax.numpy', 'jax.experimental.array_api')
    return xp.__name__ in jax_module_names
def is_pydata_sparse_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is a pydata/sparse namespace.

    Only the top-level ``sparse`` module is recognised (not, for example,
    ``scipy.sparse``).

    See Also
    --------

    array_namespace
    is_numpy_namespace
    is_cupy_namespace
    is_torch_namespace
    is_ndonnx_namespace
    is_dask_namespace
    is_jax_namespace
    is_array_api_strict_namespace
    """
    name = xp.__name__
    return name == 'sparse'
def is_array_api_strict_namespace(xp: Namespace) -> bool:
    """
    Returns True if `xp` is an array-api-strict namespace.

    The check is an exact match on the module name ``array_api_strict``.

    See Also
    --------

    array_namespace
    is_numpy_namespace
    is_cupy_namespace
    is_torch_namespace
    is_ndonnx_namespace
    is_dask_namespace
    is_jax_namespace
    is_pydata_sparse_namespace
    """
    name = xp.__name__
    return name == 'array_api_strict'
def _check_api_version(api_version: str) -> None:
if api_version in ['2021.12', '2022.12', '2023.12']:
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12")
elif api_version is not None and api_version not in ['2021.12', '2022.12',
'2023.12', '2024.12']:
raise ValueError("Only the 2024.12 version of the array API specification is currently supported")
def array_namespace(
    *xs: Union[Array, bool, int, float, complex, None],
    api_version: Optional[str] = None,
    use_compat: Optional[bool] = None,
) -> Namespace:
    """
    Get the array API compatible namespace for the arrays `xs`.

    Parameters
    ----------
    xs: arrays
        one or more arrays. xs can also be Python scalars (bool, int, float,
        complex, or None), which are ignored.

    api_version: str
        The newest version of the spec that you need support for (currently
        the compat library wrapped APIs support v2024.12).

    use_compat: bool or None
        If None (the default), the native namespace will be returned if it is
        already array API compatible, otherwise a compat wrapper is used. If
        True, the compat library wrapped library will be returned. If False,
        the native library namespace is returned.

    Returns
    -------

    out: namespace
        The array API compatible namespace corresponding to the arrays in `xs`.

    Raises
    ------
    TypeError
        If `xs` contains arrays from different array libraries, contains a
        non-array, or contains no arrays at all (only Python scalars/None).
    ValueError
        If `use_compat` is not None, True, or False; if ``use_compat=True`` is
        requested for an array with no array-api-compat wrapper (JAX,
        `sparse`, or any other object that only defines
        ``__array_namespace__``); or, in the branches that validate it, if
        `api_version` is not a recognised version.

    See Also
    --------

    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array

    Notes
    -----
    Typical usage is to pass the arguments of a function to
    `array_namespace()` at the top of a function to get the corresponding
    array API namespace:

    .. code:: python

       def your_function(x, y):
           xp = array_api_compat.array_namespace(x, y)
           # Now use xp as the array library namespace
           return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)


    Wrapped array namespaces can also be imported directly. For example,
    `array_namespace(np.array(...))` will return `array_api_compat.numpy`.
    This function will also work for any array library not wrapped by
    array-api-compat if it explicitly defines `__array_namespace__
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html>`__
    (by default the wrapped namespace is preferred if it exists).
    """
    # NOTE(review): list membership compares with ==, so 0 and 1 also pass
    # this check. The numpy/jax/sparse/generic branches below test
    # `use_compat` with `is`, so 0 then behaves like None for NumPy
    # (compat wrapper) but like False for CuPy/PyTorch/Dask (native module).
    if use_compat not in [None, True, False]:
        raise ValueError("use_compat must be None, True, or False")

    # True for the default (None) and for an explicit True: the CuPy,
    # PyTorch and Dask branches then return the compat wrapper.
    _use_compat = use_compat in [None, True]

    # Collect one namespace per array input; more than one is an error.
    namespaces = set()
    for x in xs:
        # NumPy is tested before JAX on purpose: is_numpy_array() rejects
        # JAX float0 zero-gradient arrays (which are np.ndarray instances),
        # letting them reach the is_jax_array() branch instead.
        if is_numpy_array(x):
            from .. import numpy as numpy_namespace
            import numpy as np
            if use_compat is True:
                _check_api_version(api_version)
                namespaces.add(numpy_namespace)
            elif use_compat is False:
                namespaces.add(np)
            else:
                # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
                # compatible.
                # NOTE(review): unlike the CuPy/PyTorch/Dask branches, this
                # default path does not call _check_api_version, so
                # `api_version` is not validated for NumPy inputs here.
                namespaces.add(numpy_namespace)
        elif is_cupy_array(x):
            if _use_compat:
                _check_api_version(api_version)
                from .. import cupy as cupy_namespace
                namespaces.add(cupy_namespace)
            else:
                import cupy as cp
                namespaces.add(cp)
        elif is_torch_array(x):
            if _use_compat:
                _check_api_version(api_version)
                from .. import torch as torch_namespace
                namespaces.add(torch_namespace)
            else:
                import torch
                namespaces.add(torch)
        elif is_dask_array(x):
            if _use_compat:
                _check_api_version(api_version)
                from ..dask import array as dask_namespace
                namespaces.add(dask_namespace)
            else:
                import dask.array as da
                namespaces.add(da)
        elif is_jax_array(x):
            if use_compat is True:
                _check_api_version(api_version)
                raise ValueError("JAX does not have an array-api-compat wrapper")
            elif use_compat is False:
                import jax.numpy as jnp
            else:
                # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
                # For older JAX versions, it is available via jax.experimental.array_api.
                import jax.numpy

                if hasattr(jax.numpy, "__array_api_version__"):
                    jnp = jax.numpy
                else:
                    import jax.experimental.array_api as jnp
            namespaces.add(jnp)
        elif is_pydata_sparse_array(x):
            if use_compat is True:
                _check_api_version(api_version)
                raise ValueError("`sparse` does not have an array-api-compat wrapper")
            else:
                import sparse
            # `sparse` is already an array namespace. We do not have a wrapper
            # submodule for it.
            namespaces.add(sparse)
        elif hasattr(x, '__array_namespace__'):
            # Unknown library that implements the standard itself: there is
            # no wrapper to offer, and api_version is forwarded unvalidated.
            if use_compat is True:
                raise ValueError("The given array does not have an array-api-compat wrapper")
            namespaces.add(x.__array_namespace__(api_version=api_version))
        elif isinstance(x, (bool, int, float, complex, type(None))):
            # Python scalars and None carry no namespace information.
            continue
        else:
            # Anything else (lists, strings, arbitrary objects) is rejected.
            raise TypeError(f"{type(x).__name__} is not a supported array type")

    if not namespaces:
        raise TypeError("Unrecognized array input")

    if len(namespaces) != 1:
        raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")

    xp, = namespaces

    return xp

# backwards compatibility alias
get_namespace = array_namespace
def _check_device(bare_xp, device):
"""
Validate dummy device on device-less array backends.
Notes
-----
This function is also invoked by CuPy, which does have multiple devices
if there are multiple GPUs available.
However, CuPy multi-device support is currently impossible
without using the global device or a context manager:
https://github.com/data-apis/array-api-compat/pull/293
"""
if bare_xp is sys.modules.get('numpy'):
if device not in ("cpu", None):
raise ValueError(f"Unsupported device for NumPy: {device!r}")
elif bare_xp is sys.modules.get('dask.array'):
if device not in ("cpu", _DASK_DEVICE, None):
raise ValueError(f"Unsupported device for Dask: {device!r}")
# Placeholder object to represent the dask device
# when the array backend is not the CPU.
# (since it is not easy to tell which device a dask array is on)
class _dask_device:
def __repr__(self):
return "DASK_DEVICE"
_DASK_DEVICE = _dask_device()
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
def device(x: Array, /) -> Device:
    """
    Hardware device the array data resides on.

    Equivalent to `x.device` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
    This helper exists because some array libraries either lack the
    `device` attribute or expose it with an incompatible API.

    Parameters
    ----------
    x: array
        array instance from an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    Notes
    -----
    NumPy arrays always report `"cpu"`. Dask arrays report `"cpu"` when their
    metadata array is a NumPy array, and otherwise a special `DASK_DEVICE`
    placeholder. Jitted JAX arrays may report None (see the FIXME below).
    `sparse` arrays without a ``device`` attribute report the device of
    their constituent ``data`` array, or `"cpu"` if they have none.

    See Also
    --------

    to_device : Move array data to a different device.
    """
    if is_numpy_array(x):
        return "cpu"

    if is_dask_array(x):
        # The metadata array tells us the backing type; NumPy means CPU.
        return "cpu" if is_numpy_array(x._meta) else _DASK_DEVICE

    if is_jax_array(x):
        # FIXME Jitted JAX arrays do not have a device attribute
        # https://github.com/jax-ml/jax/issues/26000
        # Return None in this case. Note that this workaround breaks
        # the standard and will result in new arrays being created on the
        # default device instead of the same device as the input array(s).
        jax_device = getattr(x, 'device', None)
        # Older JAX releases exposed .device() as a method; newer ones use a
        # property, as the standard requires.
        return jax_device() if inspect.ismethod(jax_device) else jax_device

    if is_pydata_sparse_array(x):
        # `sparse` will gain `.device`, so prefer it when present.
        sparse_device = getattr(x, 'device', None)
        if sparse_device is not None:
            return sparse_device
        # Every format except DOK has a `.data` attribute.
        try:
            backing = x.data
        except AttributeError:
            return "cpu"
        return device(backing)

    return x.device

# Private alias so that to_device(), whose own `device` parameter shadows
# this function's name, can still call it.
_device = device
# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    """
    `to_device` implementation for CuPy arrays.

    Parameters
    ----------
    x: cupy.ndarray
        The array to move.
    device:
        The target: ``"cpu"``, or a ``cupy.cuda.Device`` instance.
    stream: int or cupy.cuda.Stream, optional
        Stream to make current while copying. An int is wrapped in
        ``cupy.cuda.ExternalStream``.

    Returns
    -------
    `x` itself when `device` equals ``x.device``; ``x.get()`` when `device`
    is ``"cpu"``; otherwise ``x.copy()`` made while `device` is the active
    CUDA device.

    Raises
    ------
    ValueError
        If `device` is neither ``"cpu"`` nor a ``cupy.cuda.Device``, or if
        `stream` is neither an int nor a ``cupy.cuda.Stream``.
    """
    import cupy as cp
    from cupy.cuda import Device as _Device
    from cupy.cuda import stream as stream_module
    from cupy_backends.cuda.api import runtime

    if device == x.device:
        return x
    elif device == "cpu":
        # allowing us to use `to_device(x, "cpu")`
        # is useful for portable test swapping between
        # host and device backends
        # NOTE(review): `x.get()` presumably returns a host (NumPy) copy;
        # confirm against the CuPy docs.
        return x.get()
    elif not isinstance(device, _Device):
        raise ValueError(f"Unsupported device {device!r}")
    else:
        # see cupy/cupy#5985 for the reason how we handle device/stream here
        # Remember the current device (and stream, if one is given) so that
        # both can be restored after the copy.
        prev_device = runtime.getDevice()
        prev_stream: stream_module.Stream = None
        if stream is not None:
            prev_stream = stream_module.get_current_stream()
            # stream can be an int as specified in __dlpack__, or a CuPy stream
            if isinstance(stream, int):
                stream = cp.cuda.ExternalStream(stream)
            elif isinstance(stream, cp.cuda.Stream):
                pass
            else:
                raise ValueError('the input stream is not recognized')
            stream.use()
        try:
            # The copy is made while the target device is the active one.
            runtime.setDevice(device.id)
            arr = x.copy()
        finally:
            # Restore the caller's device/stream even if the copy fails.
            runtime.setDevice(prev_device)
            if stream is not None:
                prev_stream.use()
        return arr
def _torch_to_device(x, device, /, stream=None):
if stream is not None:
raise NotImplementedError
return x.to(device)
def _cpu_only_to_device(x, device, stream):
    """Shared `to_device` logic for backends whose only device is the CPU.

    Rejects any non-None `stream`, returns `x` unchanged for ``'cpu'``, and
    raises ``ValueError`` for every other device.
    """
    if stream is not None:
        raise ValueError("The stream argument to to_device() is not supported")
    if device == 'cpu':
        return x
    raise ValueError(f"Unsupported device {device!r}")


def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
    """
    Copy the array from the device on which it currently resides to the specified ``device``.

    Equivalent to `x.to_device(device, stream=stream)` according to
    the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html>`__.
    This helper exists because some array libraries lack the `to_device`
    method.

    Parameters
    ----------
    x: array
        array instance from an array API compatible library.

    device: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    stream: Optional[Union[int, Any]]
        stream object to use during copy. In addition to the types supported
        in ``array.__dlpack__``, implementations may choose to support any
        library-specific stream object with the caveat that any code using
        such an object would not be portable.

    Returns
    -------
    out: array
        an array with the same data and data type as ``x`` and located on the
        specified ``device``.

    Notes
    -----
    NumPy and Dask arrays only accept the ``'cpu'`` device, for which `x`
    is returned unchanged. For CuPy, this function supports CuPy CUDA
    :external+cupy:class:`Device <cupy.cuda.Device>` and
    :external+cupy:class:`Stream <cupy.cuda.Stream>` objects. For PyTorch,
    this is the same as :external+torch:meth:`x.to(device) <torch.Tensor.to>`
    (the ``stream`` argument is not supported in PyTorch). JAX arrays that
    lack a ``to_device`` method (e.g. inside ``jax.jit``) are returned
    unchanged.

    See Also
    --------

    device : Hardware device the array data resides on.
    """
    if is_numpy_array(x):
        return _cpu_only_to_device(x, device, stream)

    if is_cupy_array(x):
        # cupy does not yet have to_device
        return _cupy_to_device(x, device, stream=stream)

    if is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)

    if is_dask_array(x):
        # TODO: What if our array is on the GPU already?
        return _cpu_only_to_device(x, device, stream)

    if is_jax_array(x):
        if not hasattr(x, "__array_namespace__"):
            # In JAX v0.4.31 and older, importing this module is what adds
            # the to_device method to x...
            import jax.experimental.array_api  # noqa: F401
        # ... and only on eager JAX; inside jax.jit the method is absent and
        # the array is handed back untouched.
        if not hasattr(x, "to_device"):
            return x
        return x.to_device(device, stream=stream)

    if is_pydata_sparse_array(x) and device == _device(x):
        # Already on the requested device: return the same array instead of
        # erroring.
        return x

    return x.to_device(device, stream=stream)
def size(x: Array) -> int | None:
    """
    Return the total number of elements of x, or None if it is unknown.

    Equivalent to `x.size` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
    This helper exists because PyTorch defines `size` in an
    :external+torch:meth:`incompatible way <torch.Tensor.size>`.

    Unknown sizes are reported as None, as the standard requires. This
    covers a None entry in ``x.shape`` (lazy arrays such as ndonnx) and NaN
    entries (which dask.array uses instead).
    """
    dims = x.shape
    if None in dims:
        return None
    total = math.prod(dims)
    # A NaN dimension (dask) propagates through the product.
    if math.isnan(total):
        return None
    return total
def is_writeable_array(x: object) -> bool:
    """
    Return False if ``x.__setitem__`` is expected to raise; True otherwise.

    NumPy arrays report their ``flags.writeable`` flag. JAX and `sparse`
    arrays are always read-only. Anything that is not an array API
    compatible object also returns False.

    Warning
    -------
    There is no standard way to tell whether an array is writeable short of
    writing to it, so every other (unknown) array type is assumed writeable.
    """
    if is_numpy_array(x):
        return x.flags.writeable
    immutable = is_jax_array(x) or is_pydata_sparse_array(x)
    return False if immutable else is_array_api_obj(x)
def is_lazy_array(x: object) -> bool:
    """Return True if x may be a future, or if eagerly reading its contents
    (e.g. via ``bool(x)`` or ``float(x)``) may be impossible or expensive,
    whatever its size.

    Return False when such reads are guaranteed to succeed and to be cheap,
    provided the array has the right dtype and size. Objects that are not
    array API compatible also return False.

    Note
    ----
    Array types that may or may not be lazy, such as JAX arrays, are treated
    cautiously: this function always returns True for them.
    """
    eager_checks = (is_numpy_array, is_cupy_array, is_torch_array, is_pydata_sparse_array)
    if any(check(x) for check in eager_checks):
        return False

    # **JAX note:** while it is possible to determine if you're inside or outside
    # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
    # as we do below for unknown arrays, this is not recommended by JAX best practices.

    # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
    # This behaviour, while impossible to change without breaking backwards
    # compatibility, is highly detrimental to performance as the whole graph will end
    # up being computed multiple times.
    lazy_checks = (is_jax_array, is_dask_array, is_ndonnx_array)
    if any(check(x) for check in lazy_checks):
        return True

    if not is_array_api_obj(x):
        return False

    # Unknown Array API compatible object. Note that this test may have dire consequences
    # in terms of performance, e.g. for a lazy object that eagerly computes the graph
    # on __bool__ (dask is one such example, which however is special-cased above).
    n_elements = size(x)
    if n_elements is None:
        return True

    xp = array_namespace(x)
    # Select a single point of the array.
    probe = xp.reshape(x, (-1,))[0] if n_elements > 1 else x
    # Cast to dtype=bool and deal with size 0 arrays.
    probe = xp.any(probe)
    try:
        bool(probe)
    except Exception:
        # The Array API standard dictates that __bool__ should raise TypeError if the
        # output cannot be defined.
        # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
        return True
    return False
# Helpers intended for end users of the compat library; names not listed
# here (including the underscore-prefixed helpers in this module) are for
# internal use only.
__all__ = [
    "array_namespace",
    "device",
    "get_namespace",
    "is_array_api_obj",
    "is_array_api_strict_namespace",
    "is_cupy_array",
    "is_cupy_namespace",
    "is_dask_array",
    "is_dask_namespace",
    "is_jax_array",
    "is_jax_namespace",
    "is_numpy_array",
    "is_numpy_namespace",
    "is_torch_array",
    "is_torch_namespace",
    "is_ndonnx_array",
    "is_ndonnx_namespace",
    "is_pydata_sparse_array",
    "is_pydata_sparse_namespace",
    "is_writeable_array",
    "is_lazy_array",
    "size",
    "to_device",
]

# Stdlib modules imported by this file that are not part of its API.
# NOTE(review): presumably consumed by tooling elsewhere (e.g. a test that
# compares dir() against __all__); confirm against the reader of this name.
_all_ignore = ['sys', 'math', 'inspect', 'warnings']