@@ -12,7 +12,7 @@
 from typing import NamedTuple
 import inspect
 
-from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace
+from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
 
 # These functions are modified from the NumPy versions.
 
@@ -368,23 +368,23 @@ def _isscalar(a):
     if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
         max = None
 
+    dev = device(x)
     if out is None:
-        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
-                                 copy=True, device=device(x))
+        out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
+        out[()] = x
+
     if min is not None:
-        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
-            # Avoid loss of precision due to torch defaulting to float32
-            min = wrapped_xp.asarray(min, dtype=xp.float64)
-        a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
+        a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev)
+        a = xp.broadcast_to(a, result_shape)
         ia = (out < a) | xp.isnan(a)
-        # torch requires an explicit cast here
-        out[ia] = wrapped_xp.astype(a[ia], out.dtype)
+        out[ia] = a[ia]
+
     if max is not None:
-        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
-            max = wrapped_xp.asarray(max, dtype=xp.float64)
-        b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
+        b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev)
+        b = xp.broadcast_to(b, result_shape)
         ib = (out > b) | xp.isnan(b)
-        out[ib] = wrapped_xp.astype(b[ib], out.dtype)
+        out[ib] = b[ib]
+
     # Return a scalar for 0-D
     return out[()]
 
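To make the effect of the new body concrete, here is a minimal, self-contained sketch of the same pattern, using plain NumPy in place of both xp and wrapped_xp. The clip_sketch name and the shape handling below are illustrative assumptions, not the library's actual wrapper, which also handles devices, an explicit out= argument, and the integer-bound normalization shown at the top of the hunk.

import numpy as np

def clip_sketch(x, min=None, max=None):
    # Illustrative stand-in for the wrapped clip; NumPy plays the role of xp/wrapped_xp.
    x = np.asarray(x)
    result_shape = np.broadcast_shapes(
        x.shape,
        () if min is None else np.shape(min),
        () if max is None else np.shape(max),
    )

    # Allocate the result and fill it from x; out[()] = x broadcasts x into place
    # and also works when the result is 0-D.
    out = np.empty(result_shape, dtype=x.dtype)
    out[()] = x

    if min is not None:
        # Casting the bound to x.dtype up front mirrors the diff, which no longer
        # needs the torch-specific float64 handling or the explicit astype on assignment.
        a = np.broadcast_to(np.asarray(min, dtype=x.dtype), result_shape)
        ia = (out < a) | np.isnan(a)  # NaN bounds propagate into the result
        out[ia] = a[ia]

    if max is not None:
        b = np.broadcast_to(np.asarray(max, dtype=x.dtype), result_shape)
        ib = (out > b) | np.isnan(b)
        out[ib] = b[ib]

    # Return a scalar for 0-D
    return out[()]

print(clip_sketch(np.arange(5, dtype=np.float64), min=1, max=3))  # [1. 1. 2. 3. 3.]

The single dev = device(x) in the diff plays a similar role to the dtype cast in this sketch: the output and both bounds are created once with the input's dtype and device, so no per-backend casts are needed afterwards.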