Commit 5f8b5d6

Improvements to the clip wrapper
- Ensure that the arrays created internally are created on the same device as x. (fixes #177)
- Make clip() work with dask.array; the workaround to avoid uint64 -> float64 promotion does not work there. (fixes #176)
- Fix loss of precision when clipping a float64 tensor with torch, due to the scalar bound being converted to a float32 tensor.
1 parent b96e84b commit 5f8b5d6
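
For context on the third fix: torch's default floating-point dtype is float32, so a bare Python float passed as a clip bound gets materialized as a float32 tensor before it is ever compared against a float64 tensor. A minimal sketch of the failure mode (illustrative only, not code from this commit):

```python
import torch

# A Python float converted without an explicit dtype becomes float32
# under torch's default dtype...
bound = torch.asarray(0.1)
print(bound.dtype)            # torch.float32

# ...and 0.1 is not exactly representable in float32, so the bound no
# longer equals the value the caller passed in.
print(float(bound) == 0.1)    # False

# The wrapper now converts scalar bounds explicitly for float64 tensors:
bound64 = torch.asarray(0.1, dtype=torch.float64)
print(float(bound64) == 0.1)  # True
```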

2 files changed (+51, -8 lines)

Diff for: array_api_compat/common/_aliases.py (+14, -7)
```diff
@@ -12,7 +12,7 @@
 from typing import NamedTuple
 import inspect
 
-from ._helpers import array_namespace, _check_device
+from ._helpers import array_namespace, _check_device, device, is_torch_array
 
 # These functions are modified from the NumPy versions.
 
@@ -281,10 +281,11 @@ def _isscalar(a):
         return isinstance(a, (int, float, type(None)))
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
-    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
 
     wrapped_xp = array_namespace(x)
 
+    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
+
     # np.clip does type promotion but the array API clip requires that the
     # output have the same dtype as x. We do this instead of just downcasting
     # the result of xp.clip() to handle some corner cases better (e.g.,
@@ -305,20 +306,26 @@ def _isscalar(a):
 
     # At least handle the case of Python integers correctly (see
     # https://github.com/numpy/numpy/pull/26892).
-    if type(min) is int and min <= xp.iinfo(x.dtype).min:
+    if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
         min = None
-    if type(max) is int and max >= xp.iinfo(x.dtype).max:
+    if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
         max = None
 
     if out is None:
-        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
+        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
+                                 copy=True, device=device(x))
     if min is not None:
-        a = xp.broadcast_to(xp.asarray(min), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
+            # Avoid loss of precision due to torch defaulting to float32
+            min = wrapped_xp.asarray(min, dtype=xp.float64)
+        a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
         ia = (out < a) | xp.isnan(a)
         # torch requires an explicit cast here
         out[ia] = wrapped_xp.astype(a[ia], out.dtype)
     if max is not None:
-        b = xp.broadcast_to(xp.asarray(max), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
+            max = wrapped_xp.asarray(max, dtype=xp.float64)
+        b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
         ib = (out > b) | xp.isnan(b)
         out[ib] = wrapped_xp.astype(b[ib], out.dtype)
     # Return a scalar for 0-D
```
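To make the masking strategy above concrete, here is a plain-NumPy sketch of what the wrapped clip() does (a simplification, not the library's implementation verbatim): out starts as a writable broadcast copy of x, and each bound is applied through a boolean mask so the result keeps x's dtype instead of going through np.clip's type promotion.

```python
import numpy as np

def masked_clip(x, min=None, max=None):
    # Writable copy of x; the real wrapper also broadcasts x, min, and max
    # to a common result shape and pins the copy to device(x).
    out = np.broadcast_to(x, x.shape).copy()
    if min is not None:
        a = np.broadcast_to(np.asarray(min), out.shape)
        ia = (out < a) | np.isnan(a)   # NaN bounds propagate, as in the wrapper
        out[ia] = a[ia].astype(out.dtype)
    if max is not None:
        b = np.broadcast_to(np.asarray(max), out.shape)
        ib = (out > b) | np.isnan(b)
        out[ib] = b[ib].astype(out.dtype)
    return out

x = np.asarray([-5.0, 0.5, 7.0])
print(masked_clip(x, min=0.0, max=1.0))  # [0.  0.5 1. ]
```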

Diff for: array_api_compat/dask/array/_aliases.py (+37, -1)
```diff
@@ -89,7 +89,6 @@ def _dask_arange(
 permute_dims = get_xp(da)(_aliases.permute_dims)
 std = get_xp(da)(_aliases.std)
 var = get_xp(da)(_aliases.var)
-clip = get_xp(da)(_aliases.clip)
 empty = get_xp(da)(_aliases.empty)
 empty_like = get_xp(da)(_aliases.empty_like)
 full = get_xp(da)(_aliases.full)
@@ -167,6 +166,43 @@ def asarray(
     concatenate as concat,
 )
 
+# dask.array.clip does not work unless all three arguments are provided.
+# Furthermore, the masking workaround in common._aliases.clip cannot work with
+# dask (meaning uint64 promoting to float64 is going to just be unfixed for
+# now).
+@get_xp(da)
+def clip(
+    x: Array,
+    /,
+    min: Optional[Union[int, float, Array]] = None,
+    max: Optional[Union[int, float, Array]] = None,
+    *,
+    xp,
+) -> Array:
+    def _isscalar(a):
+        return isinstance(a, (int, float, type(None)))
+    min_shape = () if _isscalar(min) else min.shape
+    max_shape = () if _isscalar(max) else max.shape
+
+    # TODO: This won't handle dask unknown shapes
+    import numpy as np
+    result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
+
+    if min is not None:
+        min = xp.broadcast_to(xp.asarray(min), result_shape)
+    if max is not None:
+        max = xp.broadcast_to(xp.asarray(max), result_shape)
+
+    if min is None and max is None:
+        return xp.positive(x)
+
+    if min is None:
+        return astype(xp.minimum(x, max), x.dtype)
+    if max is None:
+        return astype(xp.maximum(x, min), x.dtype)
+
+    return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)
+
 # exclude these from all since
 _da_unsupported = ['sort', 'argsort']
```
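Because dask arrays do not support the boolean-mask item assignment used in the common wrapper, the new dask clip composes xp.maximum and xp.minimum and casts the result back to x.dtype. A small usage sketch of that same composition, assuming dask is installed:

```python
import dask.array as da
import numpy as np

x = da.from_array(np.arange(10), chunks=3)

# Equivalent of the new wrapper body for min=2, max=7: apply the bounds
# with element-wise maximum/minimum, then cast back to x.dtype since the
# comparison against the bounds may otherwise promote the result.
clipped = da.minimum(da.maximum(x, 2), 7).astype(x.dtype)
print(clipped.compute())  # [2 2 2 3 4 5 6 7 7 7]
```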
