From 090f570ad350b8648c3ffcc610d427ef8181ec00 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 26 Jun 2024 15:00:08 -0600 Subject: [PATCH 1/5] Add a wrapper for torch.copysign --- array_api_compat/torch/_aliases.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index fb53e0ee..1d8b088c 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -155,6 +155,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: bitwise_or = _two_arg(torch.bitwise_or) bitwise_right_shift = _two_arg(torch.bitwise_right_shift) bitwise_xor = _two_arg(torch.bitwise_xor) +copysign = _two_arg(torch.copysign) divide = _two_arg(torch.divide) # Also a rename. torch.equal does not broadcast equal = _two_arg(torch.eq) @@ -702,9 +703,9 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', - 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide', - 'equal', 'floor_divide', 'greater', 'greater_equal', 'less', - 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', + 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', + 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', + 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', From 398fdf13ab078cd87e1f10833be533e98c87608e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 9 Jul 2024 15:07:49 -0600 Subject: [PATCH 2/5] Add clip() wrapper for NumPy and CuPy --- array_api_compat/common/_aliases.py | 54 +++++++++++++++++++++++++++-- array_api_compat/cupy/_aliases.py | 1 + array_api_compat/numpy/_aliases.py | 1 + 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 40a2c0b0..a15e7aad 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -12,7 +12,7 @@ from typing import NamedTuple import inspect -from ._helpers import _check_device +from ._helpers import array_namespace, _check_device # These functions are modified from the NumPy versions. @@ -264,6 +264,56 @@ def var( ) -> ndarray: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + +# The min and max argument names in clip are different and not optional in numpy, and type +# promotion behavior is different. +def clip( + x: ndarray, + /, + min: Optional[Union[int, float, ndarray]] = None, + max: Optional[Union[int, float, ndarray]] = None, + *, + xp, + # TODO: np.clip has other ufunc kwargs + out: Optional[ndarray] = None, +) -> ndarray: + def _isscalar(a): + return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape + max_shape = () if _isscalar(max) else max.shape + result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape) + + wrapped_xp = array_namespace(x) + + # np.clip does type promotion but the array API clip requires that the + # output have the same dtype as x. We do this instead of just downcasting + # the result of xp.clip() to handle some corner cases better (e.g., + # avoiding uint64 -> float64 promotion). + + # Note: cases where min or max overflow (integer) or round (float) in the + # wrong direction when downcasting to x.dtype are unspecified. This code + # just does whatever NumPy does when it downcasts in the assignment, but + # other behavior could be preferred, especially for integers. For example, + # this code produces: + + # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None) + # -128 + + # but an answer of 0 might be preferred. See + # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. + if out is None: + out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True) + if min is not None: + a = xp.broadcast_to(xp.asarray(min), result_shape) + ia = (out < a) | xp.isnan(a) + out[ia] = a[ia] + if max is not None: + b = xp.broadcast_to(xp.asarray(max), result_shape) + ib = (out > b) | xp.isnan(b) + out[ib] = b[ib] + # Return a scalar for 0-D + return out[()] + # Unlike transpose(), the axes argument to permute_dims() is required. def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: return xp.transpose(x, axes) @@ -465,6 +515,6 @@ def isdtype( 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', + 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort', 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index b1b513e6..d7e78fdd 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -47,6 +47,7 @@ astype = _aliases.astype std = get_xp(cp)(_aliases.std) var = get_xp(cp)(_aliases.var) +clip = get_xp(cp)(_aliases.clip) permute_dims = get_xp(cp)(_aliases.permute_dims) reshape = get_xp(cp)(_aliases.reshape) argsort = get_xp(cp)(_aliases.argsort) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 70378716..e29b0751 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -47,6 +47,7 @@ astype = _aliases.astype std = get_xp(np)(_aliases.std) var = get_xp(np)(_aliases.var) +clip = get_xp(np)(_aliases.clip) permute_dims = get_xp(np)(_aliases.permute_dims) reshape = get_xp(np)(_aliases.reshape) argsort = get_xp(np)(_aliases.argsort) From 4d591d0ac20e0bcc7fb4b150d86b41952fd644d3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 15 Jul 2024 15:07:31 -0600 Subject: [PATCH 3/5] Handle large Python integers better in clip --- array_api_compat/common/_aliases.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index a15e7aad..e82ed5a5 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -301,6 +301,15 @@ def _isscalar(a): # but an answer of 0 might be preferred. See # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. + + + # At least handle the case of Python integers correctly (see + # https://github.com/numpy/numpy/pull/26892). + if type(min) is int and min <= xp.iinfo(x.dtype).min: + min = None + if type(max) is int and max >= xp.iinfo(x.dtype).max: + max = None + if out is None: out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True) if min is not None: From 7da965cf4874e006c8c1dc9faa0dffaef4075527 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Jul 2024 12:16:15 -0600 Subject: [PATCH 4/5] Fix clip test failures for dask (xfailed for now) --- array_api_compat/dask/array/_aliases.py | 1 + dask-xfails.txt | 3 +++ 2 files changed, 4 insertions(+) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d2aac8b2..d26ec6a2 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -88,6 +88,7 @@ def _dask_arange( permute_dims = get_xp(da)(_aliases.permute_dims) std = get_xp(da)(_aliases.std) var = get_xp(da)(_aliases.var) +clip = get_xp(da)(_aliases.clip) empty = get_xp(da)(_aliases.empty) empty_like = get_xp(da)(_aliases.empty_like) full = get_xp(da)(_aliases.full) diff --git a/dask-xfails.txt b/dask-xfails.txt index 0d74ecbb..64be2e12 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -48,6 +48,9 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +# The clip helper uses boolean indexing +array_api_tests/test_operators_and_elementwise_functions.py::test_clip + # No sorting in dask array_api_tests/test_has_names.py::test_has_names[sorting-argsort] array_api_tests/test_has_names.py::test_has_names[sorting-sort] From b8a59daa5b7bf5cf75e5d22b4f09440834aa2e04 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Jul 2024 15:32:10 -0600 Subject: [PATCH 5/5] Wrap clip() for torch --- array_api_compat/common/_aliases.py | 5 +++-- array_api_compat/torch/_aliases.py | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index e82ed5a5..bb53e8e5 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -315,11 +315,12 @@ def _isscalar(a): if min is not None: a = xp.broadcast_to(xp.asarray(min), result_shape) ia = (out < a) | xp.isnan(a) - out[ia] = a[ia] + # torch requires an explicit cast here + out[ia] = wrapped_xp.astype(a[ia], out.dtype) if max is not None: b = xp.broadcast_to(xp.asarray(max), result_shape) ib = (out > b) | xp.isnan(b) - out[ib] = b[ib] + out[ib] = wrapped_xp.astype(b[ib], out.dtype) # Return a scalar for 0-D return out[()] diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 1d8b088c..c2be21fe 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -4,7 +4,7 @@ from builtins import all as _builtin_all, any as _builtin_any from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose, - vecdot as _aliases_vecdot) + vecdot as _aliases_vecdot, clip as _aliases_clip) from .._internal import get_xp import torch @@ -189,6 +189,8 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep return torch.clone(x) return torch.amin(x, axis, keepdims=keepdims) +clip = get_xp(torch)(_aliases_clip) + # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: @@ -706,8 +708,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', - 'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum', - 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', + 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod', + 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',