Skip to content

Commit e14754b

Browse files
authored
BUG: clip(out=...) is broken (#261)
reviewed at #261
1 parent b0cc9dc commit e14754b

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

array_api_compat/common/_aliases.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import NamedTuple
1313
import inspect
1414

15-
from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace
15+
from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
1616

1717
# These functions are modified from the NumPy versions.
1818

@@ -368,23 +368,23 @@ def _isscalar(a):
368368
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
369369
max = None
370370

371+
dev = device(x)
371372
if out is None:
372-
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
373-
copy=True, device=device(x))
373+
out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
374+
out[()] = x
375+
374376
if min is not None:
375-
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
376-
# Avoid loss of precision due to torch defaulting to float32
377-
min = wrapped_xp.asarray(min, dtype=xp.float64)
378-
a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
377+
a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev)
378+
a = xp.broadcast_to(a, result_shape)
379379
ia = (out < a) | xp.isnan(a)
380-
# torch requires an explicit cast here
381-
out[ia] = wrapped_xp.astype(a[ia], out.dtype)
380+
out[ia] = a[ia]
381+
382382
if max is not None:
383-
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
384-
max = wrapped_xp.asarray(max, dtype=xp.float64)
385-
b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
383+
b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev)
384+
b = xp.broadcast_to(b, result_shape)
386385
ib = (out > b) | xp.isnan(b)
387-
out[ib] = wrapped_xp.astype(b[ib], out.dtype)
386+
out[ib] = b[ib]
387+
388388
# Return a scalar for 0-D
389389
return out[()]
390390

tests/test_common.py

+15
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,18 @@ def test_asarray_copy(library):
367367
assert all(b[0] == 1.0)
368368
else:
369369
assert all(b[0] == 0.0)
370+
371+
372+
@pytest.mark.parametrize("library", ["numpy", "cupy", "torch"])
373+
def test_clip_out(library):
374+
"""Test non-standard out= parameter for clip()
375+
376+
(see "Avoid Restricting Behavior that is Outside the Scope of the Standard"
377+
in https://data-apis.org/array-api-compat/dev/special-considerations.html)
378+
"""
379+
xp = import_(library, wrapper=True)
380+
x = xp.asarray([10, 20, 30])
381+
out = xp.zeros_like(x)
382+
xp.clip(x, 15, 25, out=out)
383+
expect = xp.asarray([15, 20, 25])
384+
assert xp.all(out == expect)

0 commit comments

Comments
 (0)