forked from data-apis/array-api-strict
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path_creation_functions.py
455 lines (362 loc) · 12.7 KB
/
_creation_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
if TYPE_CHECKING:
from ._typing import (
Array,
Device,
Dtype,
NestedSequence,
SupportsBufferProtocol,
)
from ._dtypes import _DType, _all_dtypes
from ._flags import get_array_api_strict_flags
import numpy as np
@contextmanager
def allow_array():
    """
    Context manager that turns on ``Array.__array__`` while its body runs.

    ``np.array`` needs this to parse a list of lists of Array objects, so
    such conversions are wrapped in ``with allow_array():``. Whatever value
    the ``_array_object._allow_array`` flag held beforehand is put back on
    exit, including when the body raises.
    """
    from . import _array_object as array_module

    saved_flag = array_module._allow_array
    try:
        array_module._allow_array = True
        yield
    finally:
        array_module._allow_array = saved_flag
def _check_valid_dtype(dtype):
    """
    Validate a user-supplied ``dtype`` argument.

    ``None`` and the dtype objects listed in ``_all_dtypes`` are accepted;
    anything else raises ``ValueError``.
    """
    # Note: Only spelling dtypes as the dtype objects is supported.
    accepted = (None,) + _all_dtypes
    if dtype in accepted:
        return
    raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}")
def _supports_buffer_protocol(obj):
try:
memoryview(obj)
except TypeError:
return False
return True
def _check_device(device):
    """
    Validate a user-supplied ``device`` argument.

    ``None`` is always accepted. Any other value must be a ``Device``
    instance that is also a member of ``ALL_DEVICES``; otherwise
    ``ValueError`` is raised.
    """
    # _array_object imports in this file are inside the functions to avoid
    # circular imports
    from ._array_object import Device, ALL_DEVICES

    if device is None:
        return
    if not isinstance(device, Device) or device not in ALL_DEVICES:
        raise ValueError(f"Unsupported device {device!r}")
def asarray(
    obj: Union[
        Array,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: Optional[bool] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.

    See its docstring for more information.

    Parameters
    ----------
    obj
        An existing ``Array``, a Python scalar, a (nested) sequence of Python
        scalars, or an object supporting the buffer protocol.
    dtype
        One of the supported dtype objects, or ``None`` to let NumPy infer
        the dtype.
    device
        Target device. When ``None`` and *obj* is an ``Array``, the result is
        placed on ``obj.device``.
    copy
        Forwarded to ``np.array``. ``copy=False`` means "never copy". NumPy 1
        has no such mode, so it is emulated below.

    Raises
    ------
    ValueError
        If *dtype* or *device* is not supported, or (on NumPy 1) if
        ``copy=False`` cannot be honored without copying.
    OverflowError
        If *dtype* is ``None`` and *obj* is a Python ``int`` outside
        ``[-2**63, 2**64 - 1]``, the combined range of int64 and uint64.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _np_dtype = None
    if dtype is not None:
        _np_dtype = dtype._np_dtype
    _check_device(device)

    if isinstance(obj, Array) and device is None:
        device = obj.device

    if np.__version__[0] < '2':
        if copy is False:
            # Note: copy=False is not yet implemented in np.asarray for
            # NumPy 1
            # Work around it by creating the new array and seeing if NumPy
            # copies it.
            if isinstance(obj, Array):
                new_array = np.array(obj._array, copy=copy, dtype=_np_dtype)
                if new_array is not obj._array:
                    raise ValueError("Unable to avoid copy while creating an array from given array.")
                return Array._new(new_array, device=device)
            elif _supports_buffer_protocol(obj):
                # Buffer protocol will always support no-copy
                return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device)
            else:
                # No-copy is unsupported for Python built-in types.
                raise ValueError("Unable to avoid copy while creating an array from given object.")

        if copy is None:
            # NumPy 1 treats copy=False the same as the standard copy=None
            copy = False

    if isinstance(obj, Array):
        return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device)
    if dtype is None and isinstance(obj, int) and (obj >= 2 ** 64 or obj < -(2 ** 63)):
        # Give a better error message in this case. NumPy would convert this
        # to an object array. The representable range runs from the int64
        # minimum (-2**63) to the uint64 maximum (2**64 - 1), so 2**64 itself
        # must be rejected as well.
        # TODO: This won't handle large integers in lists.
        raise OverflowError("Integer out of bounds for array dtypes")

    with allow_array():
        res = np.array(obj, dtype=_np_dtype, copy=copy)
    return Array._new(res, device=device)
def arange(
    start: Union[int, float],
    /,
    stop: Optional[Union[int, float]] = None,
    step: Union[int, float] = 1,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`.

    See its docstring for more information.

    *dtype* and *device* are validated first. The values themselves come
    from ``np.arange``.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    np_dtype = None if dtype is None else dtype._np_dtype
    values = np.arange(start, stop=stop, step=step, dtype=np_dtype)
    return Array._new(values, device=device)
def empty(
    shape: Union[int, Tuple[int, ...]],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.

    See its docstring for more information.

    The contents of the returned array are uninitialized.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    np_dtype = None if dtype is None else dtype._np_dtype
    buffer = np.empty(shape, dtype=np_dtype)
    return Array._new(buffer, device=device)
def empty_like(
    x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`.

    See its docstring for more information.

    When *device* is ``None``, the result is placed on ``x.device``.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    target_device = x.device if device is None else device
    np_dtype = None if dtype is None else dtype._np_dtype
    return Array._new(np.empty_like(x._array, dtype=np_dtype), device=target_device)
def eye(
    n_rows: int,
    n_cols: Optional[int] = None,
    /,
    *,
    k: int = 0,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`.

    See its docstring for more information.

    *n_cols* is forwarded as NumPy's ``M`` argument, and *k* selects the
    diagonal.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    np_dtype = None if dtype is None else dtype._np_dtype
    diagonal_matrix = np.eye(n_rows, M=n_cols, k=k, dtype=np_dtype)
    return Array._new(diagonal_matrix, device=device)
# Private sentinel used to tell "keyword not passed" apart from an explicit
# ``None`` for from_dlpack's ``device`` and ``copy`` arguments.
_default = object()


def from_dlpack(
    x: object,
    /,
    *,
    device: Optional[Device] = _default,
    copy: Optional[bool] = _default,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.from_dlpack <numpy.from_dlpack>`.

    See its docstring for more information.

    The ``device`` and ``copy`` keywords require array API version 2023.12.
    Passing either one while an older ``api_version`` flag is active raises
    ``ValueError``. A ``copy`` value other than ``None`` raises
    ``NotImplementedError``.
    """
    from ._array_object import Array

    if get_array_api_strict_flags()['api_version'] < '2023.12':
        if device is not _default:
            raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API")
        if copy is not _default:
            raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API")

    # Going to wait for upstream numpy support
    if device is _default:
        # Never let the private sentinel reach Array._new. Treat "not passed"
        # exactly like an explicit device=None.
        device = None
    else:
        _check_device(device)
    if copy is not _default and copy is not None:
        raise NotImplementedError("The copy argument to from_dlpack is not yet implemented")

    return Array._new(np.from_dlpack(x), device=device)
def full(
    shape: Union[int, Tuple[int, ...]],
    fill_value: Union[int, float],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.full <numpy.full>`.

    See its docstring for more information.

    A 0-D ``Array`` is also accepted as *fill_value*. ``TypeError`` is raised
    when the resulting NumPy dtype is not one of the supported dtypes.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    # Unwrap a 0-D Array fill value to its underlying ``_array``.
    if isinstance(fill_value, Array) and fill_value.ndim == 0:
        fill_value = fill_value._array
    np_dtype = None if dtype is None else dtype._np_dtype
    res = np.full(shape, fill_value, dtype=np_dtype)
    if _DType(res.dtype) in _all_dtypes:
        return Array._new(res, device=device)
    # This will happen if the fill value is not something that NumPy
    # coerces to one of the acceptable dtypes.
    raise TypeError("Invalid input to full")
def full_like(
    x: Array,
    /,
    fill_value: Union[int, float],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`.

    See its docstring for more information.

    When *device* is ``None``, the result is placed on ``x.device``.
    ``TypeError`` is raised when the resulting NumPy dtype is not one of the
    supported dtypes.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    target_device = x.device if device is None else device
    np_dtype = None if dtype is None else dtype._np_dtype
    res = np.full_like(x._array, fill_value, dtype=np_dtype)
    if _DType(res.dtype) in _all_dtypes:
        return Array._new(res, device=target_device)
    # This will happen if the fill value is not something that NumPy
    # coerces to one of the acceptable dtypes.
    raise TypeError("Invalid input to full_like")
def linspace(
    start: Union[int, float],
    stop: Union[int, float],
    /,
    num: int,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    endpoint: bool = True,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`.

    See its docstring for more information.

    *endpoint* controls whether *stop* is included in the samples.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    np_dtype = None if dtype is None else dtype._np_dtype
    samples = np.linspace(start, stop, num, dtype=np_dtype, endpoint=endpoint)
    return Array._new(samples, device=device)
def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
    """
    Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.

    See its docstring for more information.

    All inputs must share a single dtype and a single device; otherwise
    ``ValueError`` is raised. Every output grid is placed on that device, or
    on ``None`` when no arrays are given.
    """
    from ._array_object import Array

    # Note: unlike np.meshgrid, only inputs with all the same dtype are
    # allowed
    distinct_dtypes = {a.dtype for a in arrays}
    if len(distinct_dtypes) > 1:
        raise ValueError("meshgrid inputs must all have the same dtype")
    distinct_devices = {a.device for a in arrays}
    if len(distinct_devices) > 1:
        raise ValueError("meshgrid inputs must all be on the same device")

    # arrays is allowed to be empty
    shared_device = arrays[0].device if arrays else None

    grids = np.meshgrid(*(a._array for a in arrays), indexing=indexing)
    return [Array._new(grid, device=shared_device) for grid in grids]
def ones(
    shape: Union[int, Tuple[int, ...]],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.

    See its docstring for more information.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    np_dtype = None if dtype is None else dtype._np_dtype
    filled = np.ones(shape, dtype=np_dtype)
    return Array._new(filled, device=device)
def ones_like(
    x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`.

    See its docstring for more information.

    When *device* is ``None``, the result is placed on ``x.device``.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    target_device = x.device if device is None else device
    np_dtype = None if dtype is None else dtype._np_dtype
    return Array._new(np.ones_like(x._array, dtype=np_dtype), device=target_device)
def tril(x: Array, /, *, k: int = 0) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.tril <numpy.tril>`.

    See its docstring for more information.

    Raises ``ValueError`` if *x* has fewer than two dimensions. The result
    stays on ``x.device``.
    """
    from ._array_object import Array

    # Note: Unlike np.tril, x must be at least 2-D
    if x.ndim >= 2:
        return Array._new(np.tril(x._array, k=k), device=x.device)
    raise ValueError("x must be at least 2-dimensional for tril")
def triu(x: Array, /, *, k: int = 0) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.triu <numpy.triu>`.

    See its docstring for more information.

    Raises ``ValueError`` if *x* has fewer than two dimensions. The result
    stays on ``x.device``.
    """
    from ._array_object import Array

    # Note: Unlike np.triu, x must be at least 2-D
    if x.ndim >= 2:
        return Array._new(np.triu(x._array, k=k), device=x.device)
    raise ValueError("x must be at least 2-dimensional for triu")
def zeros(
    shape: Union[int, Tuple[int, ...]],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`.

    See its docstring for more information.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    np_dtype = None if dtype is None else dtype._np_dtype
    filled = np.zeros(shape, dtype=np_dtype)
    return Array._new(filled, device=device)
def zeros_like(
    x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`.

    See its docstring for more information.

    When *device* is ``None``, the result is placed on ``x.device``.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    _check_device(device)
    target_device = x.device if device is None else device
    np_dtype = None if dtype is None else dtype._np_dtype
    return Array._new(np.zeros_like(x._array, dtype=np_dtype), device=target_device)