Skip to content

Commit 5485345

Browse files
committed
Merge branch 'main' into rm-__array__
2 parents d630ee5 + dec8c22 commit 5485345

File tree

1 file changed

+46
-8
lines changed

1 file changed

+46
-8
lines changed

array_api_strict/_elementwise_functions.py

+46-8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ._array_object import Array
1515
from ._flags import requires_api_version
1616
from ._creation_functions import asarray
17+
from ._data_type_functions import broadcast_to, iinfo
1718

1819
from typing import Optional, Union
1920

@@ -325,14 +326,51 @@ def clip(
325326
if min is not None and max is not None and np.any(min > max):
326327
raise ValueError("min must be less than or equal to max")
327328

328-
result = np.clip(x._array, min, max)
329-
# Note: NumPy applies type promotion, but the standard specifies the
330-
# return dtype should be the same as x
331-
if result.dtype != x.dtype._np_dtype:
332-
# TODO: I'm not completely sure this always gives the correct thing
333-
# for integer dtypes. See https://github.com/numpy/numpy/issues/24976
334-
result = result.astype(x.dtype._np_dtype)
335-
return Array._new(result, device=x.device)
329+
# np.clip does type promotion but the array API clip requires that the
330+
# output have the same dtype as x. We do this instead of just downcasting
331+
# the result of xp.clip() to handle some corner cases better (e.g.,
332+
# avoiding uint64 -> float64 promotion).
333+
334+
# Note: cases where min or max overflow (integer) or round (float) in the
335+
# wrong direction when downcasting to x.dtype are unspecified. This code
336+
# just does whatever NumPy does when it downcasts in the assignment, but
337+
# other behavior could be preferred, especially for integers. For example,
338+
# this code produces:
339+
340+
# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
341+
# -128
342+
343+
# but an answer of 0 might be preferred. See
344+
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
345+
346+
# At least handle the case of Python integers correctly (see
347+
# https://github.com/numpy/numpy/pull/26892).
348+
if type(min) is int and min <= iinfo(x.dtype).min:
349+
min = None
350+
if type(max) is int and max >= iinfo(x.dtype).max:
351+
max = None
352+
353+
def _isscalar(a):
354+
return isinstance(a, (int, float, type(None)))
355+
min_shape = () if _isscalar(min) else min.shape
356+
max_shape = () if _isscalar(max) else max.shape
357+
358+
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
359+
360+
out = asarray(broadcast_to(x, result_shape), copy=True)._array
361+
device = x.device
362+
x = x._array
363+
364+
if min is not None:
365+
a = np.broadcast_to(np.asarray(min), result_shape)
366+
ia = (out < a) | np.isnan(a)
367+
368+
out[ia] = a[ia]
369+
if max is not None:
370+
b = np.broadcast_to(np.asarray(max), result_shape)
371+
ib = (out > b) | np.isnan(b)
372+
out[ib] = b[ib]
373+
return Array._new(out, device=device)
336374

337375
def conj(x: Array, /) -> Array:
338376
"""

0 commit comments

Comments
 (0)