|
14 | 14 | from ._array_object import Array
|
15 | 15 | from ._flags import requires_api_version
|
16 | 16 | from ._creation_functions import asarray
|
| 17 | +from ._data_type_functions import broadcast_to, iinfo |
17 | 18 |
|
18 | 19 | from typing import Optional, Union
|
19 | 20 |
|
@@ -325,14 +326,51 @@ def clip(
|
325 | 326 | if min is not None and max is not None and np.any(min > max):
|
326 | 327 | raise ValueError("min must be less than or equal to max")
|
327 | 328 |
|
328 |
| - result = np.clip(x._array, min, max) |
329 |
| - # Note: NumPy applies type promotion, but the standard specifies the |
330 |
| - # return dtype should be the same as x |
331 |
| - if result.dtype != x.dtype._np_dtype: |
332 |
| - # TODO: I'm not completely sure this always gives the correct thing |
333 |
| - # for integer dtypes. See https://github.com/numpy/numpy/issues/24976 |
334 |
| - result = result.astype(x.dtype._np_dtype) |
335 |
| - return Array._new(result, device=x.device) |
| 329 | + # np.clip does type promotion but the array API clip requires that the |
| 330 | + # output have the same dtype as x. We do this instead of just downcasting |
| 331 | + # the result of xp.clip() to handle some corner cases better (e.g., |
| 332 | + # avoiding uint64 -> float64 promotion). |
| 333 | + |
| 334 | + # Note: cases where min or max overflow (integer) or round (float) in the |
| 335 | + # wrong direction when downcasting to x.dtype are unspecified. This code |
| 336 | + # just does whatever NumPy does when it downcasts in the assignment, but |
| 337 | + # other behavior could be preferred, especially for integers. For example, |
| 338 | + # this code produces: |
| 339 | + |
| 340 | + # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None) |
| 341 | + # -128 |
| 342 | + |
| 343 | + # but an answer of 0 might be preferred. See |
| 344 | + # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. |
| 345 | + |
| 346 | + # At least handle the case of Python integers correctly (see |
| 347 | + # https://github.com/numpy/numpy/pull/26892). |
| 348 | + if type(min) is int and min <= iinfo(x.dtype).min: |
| 349 | + min = None |
| 350 | + if type(max) is int and max >= iinfo(x.dtype).max: |
| 351 | + max = None |
| 352 | + |
| 353 | + def _isscalar(a): |
| 354 | + return isinstance(a, (int, float, type(None))) |
| 355 | + min_shape = () if _isscalar(min) else min.shape |
| 356 | + max_shape = () if _isscalar(max) else max.shape |
| 357 | + |
| 358 | + result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) |
| 359 | + |
| 360 | + out = asarray(broadcast_to(x, result_shape), copy=True)._array |
| 361 | + device = x.device |
| 362 | + x = x._array |
| 363 | + |
| 364 | + if min is not None: |
| 365 | + a = np.broadcast_to(np.asarray(min), result_shape) |
| 366 | + ia = (out < a) | np.isnan(a) |
| 367 | + |
| 368 | + out[ia] = a[ia] |
| 369 | + if max is not None: |
| 370 | + b = np.broadcast_to(np.asarray(max), result_shape) |
| 371 | + ib = (out > b) | np.isnan(b) |
| 372 | + out[ib] = b[ib] |
| 373 | + return Array._new(out, device=device) |
336 | 374 |
|
337 | 375 | def conj(x: Array, /) -> Array:
|
338 | 376 | """
|
|
0 commit comments