@@ -368,23 +368,24 @@ def _isscalar(a):
368
368
if type (max ) is int and max >= wrapped_xp .iinfo (x .dtype ).max :
369
369
max = None
370
370
371
+ dev = device (x )
371
372
if out is None :
372
- out = wrapped_xp .asarray (xp .broadcast_to (x , result_shape ),
373
- copy = True , device = device (x ))
373
+ out = wrapped_xp .empty (result_shape , dtype = x .dtype , device = dev )
374
+ out [()] = x
375
+
374
376
if min is not None :
375
- if is_torch_array (x ) and x .dtype == xp .float64 and _isscalar (min ):
376
- # Avoid loss of precision due to torch defaulting to float32
377
- min = wrapped_xp .asarray (min , dtype = xp .float64 )
378
- a = xp .broadcast_to (wrapped_xp .asarray (min , device = device (x )), result_shape )
377
+ a = wrapped_xp .asarray (min , dtype = x .dtype , device = dev )
378
+ a = xp .broadcast_to (a , result_shape )
379
379
ia = (out < a ) | xp .isnan (a )
380
380
# torch requires an explicit cast here
381
- out [ia ] = wrapped_xp .astype (a [ia ], out .dtype )
381
+ out [ia ] = a [ia ]
382
+
382
383
if max is not None :
383
- if is_torch_array (x ) and x .dtype == xp .float64 and _isscalar (max ):
384
- max = wrapped_xp .asarray (max , dtype = xp .float64 )
385
- b = xp .broadcast_to (wrapped_xp .asarray (max , device = device (x )), result_shape )
384
+ b = wrapped_xp .asarray (max , dtype = x .dtype , device = dev )
385
+ b = xp .broadcast_to (b , result_shape )
386
386
ib = (out > b ) | xp .isnan (b )
387
- out [ib ] = wrapped_xp .astype (b [ib ], out .dtype )
387
+ out [ib ] = b [ib ]
388
+
388
389
# Return a scalar for 0-D
389
390
return out [()]
390
391
0 commit comments