from functools import reduce as _reduce, wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any
- from typing import List, Optional, Sequence, Tuple, Union
+ from typing import Any, List, Optional, Sequence, Tuple, Union

import torch

from .._internal import get_xp
from ..common import _aliases
+ from ..common._typing import NestedSequence, SupportsBufferProtocol
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType
@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
remainder = _two_arg(torch.remainder)
subtract = _two_arg(torch.subtract)

+
+ def asarray(
+     obj: (
+         Array
+         | bool | int | float | complex
+         | NestedSequence[bool | int | float | complex]
+         | SupportsBufferProtocol
+     ),
+     /,
+     *,
+     dtype: DType | None = None,
+     device: Device | None = None,
+     copy: bool | None = None,
+     **kwargs: Any,
+ ) -> Array:
+     # torch.asarray does not respect input->output device propagation
+     # https://github.com/pytorch/pytorch/issues/150199
+     if device is None and isinstance(obj, torch.Tensor):
+         device = obj.device
+     return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
+
+
# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
# of 'axis'.
@@ -282,7 +305,6 @@ def prod(x: Array,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-     x = torch.asarray(x)
    ndim = x.ndim

    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -318,7 +340,6 @@ def sum(x: Array,
        dtype: Optional[DType] = None,
        keepdims: bool = False,
        **kwargs) -> Array:
-     x = torch.asarray(x)
    ndim = x.ndim

    # https://github.com/pytorch/pytorch/issues/29137.
@@ -348,7 +369,6 @@ def any(x: Array,
        axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdims: bool = False,
        **kwargs) -> Array:
-     x = torch.asarray(x)
    ndim = x.ndim
    if axis == ():
        return x.to(torch.bool)
@@ -373,7 +393,6 @@ def all(x: Array,
        axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdims: bool = False,
        **kwargs) -> Array:
-     x = torch.asarray(x)
    ndim = x.ndim
    if axis == ():
        return x.to(torch.bool)
@@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array:
    return out


- __all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
+ __all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
           'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
           'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
           'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
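
For reference, a minimal usage sketch of the behaviour the new asarray wrapper is meant to guarantee: when device is omitted and the input is already a tensor, the result stays on the input's device rather than falling back to torch's default device (the bare torch.asarray behaviour tracked in pytorch/pytorch#150199). The array_api_compat.torch import path and the CUDA device below are illustrative assumptions, not part of this commit.

import torch
import array_api_compat.torch as xp  # assumed import path for the compat namespace

# Assumes a CUDA build of PyTorch; any non-default device behaves the same way.
t = torch.ones(3, device="cuda")

# With the wrapper above, omitting `device` propagates the input tensor's device.
out = xp.asarray(t)
assert out.device == t.device

# An explicit `device` still takes precedence, exactly as with torch.asarray.
cpu_out = xp.asarray(t, device="cpu")
assert cpu_out.device.type == "cpu"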