 from typing import NamedTuple
 import inspect
 
-from ._helpers import array_namespace, _check_device
+from ._helpers import array_namespace, _check_device, device, is_torch_array
 
 
 # These functions are modified from the NumPy versions.
@@ -281,10 +281,11 @@ def _isscalar(a):
         return isinstance(a, (int, float, type(None)))
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
-    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
 
     wrapped_xp = array_namespace(x)
 
+    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
+
     # np.clip does type promotion but the array API clip requires that the
     # output have the same dtype as x. We do this instead of just downcasting
     # the result of xp.clip() to handle some corner cases better (e.g.,
@@ -305,20 +306,26 @@ def _isscalar(a):
 
     # At least handle the case of Python integers correctly (see
     # https://github.com/numpy/numpy/pull/26892).
-    if type(min) is int and min <= xp.iinfo(x.dtype).min:
+    if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
         min = None
-    if type(max) is int and max >= xp.iinfo(x.dtype).max:
+    if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
         max = None
 
     if out is None:
-        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
+        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
+                                 copy=True, device=device(x))
     if min is not None:
-        a = xp.broadcast_to(xp.asarray(min), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
+            # Avoid loss of precision due to torch defaulting to float32
+            min = wrapped_xp.asarray(min, dtype=xp.float64)
+        a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
         ia = (out < a) | xp.isnan(a)
         # torch requires an explicit cast here
         out[ia] = wrapped_xp.astype(a[ia], out.dtype)
     if max is not None:
-        b = xp.broadcast_to(xp.asarray(max), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
+            max = wrapped_xp.asarray(max, dtype=xp.float64)
+        b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
         ib = (out > b) | xp.isnan(b)
         out[ib] = wrapped_xp.astype(b[ib], out.dtype)
     # Return a scalar for 0-D