From 5b8c55e689e2cca7c0dc4e72be3ed0d38b993cf0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 3 Mar 2025 17:03:39 +0000 Subject: [PATCH 1/4] BUG: torch: fix count_nonzero(... keepdims=True) --- array_api_compat/torch/_aliases.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a6e833f9..b4786320 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -521,7 +521,7 @@ def diff( return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) -# torch uses `dim` instead of `axis` +# torch uses `dim` instead of `axis`, does not have keepdims def count_nonzero( x: array, /, @@ -529,7 +529,14 @@ def count_nonzero( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> array: - return torch.count_nonzero(x, dim=axis, keepdims=keepdims) + result = torch.count_nonzero(x, dim=axis) + if keepdims: + if axis is not None: + return result.unsqueeze(axis) + return _axis_none_keepdims(result, x.ndim, keepdims) + else: + return result + def where(condition: array, x1: array, x2: array, /) -> array: From 2ca6d738d87eb4c70bcb3c5f22d34fb2575ddfd8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 3 Mar 2025 19:40:24 +0000 Subject: [PATCH 2/4] BUG: cupy: fix count_nonzero(... keepdims=True) --- array_api_compat/cupy/_aliases.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 50331fa0..30d9fe48 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -125,6 +125,20 @@ def astype( return out.copy() if copy and out is x else out +# cupy.count_nonzero does not have keepdims +def count_nonzero( + x: ndarray, + axis=None, + keepdims=False +) -> ndarray: + result = cp.count_nonzero(x, axis) + if keepdims: + if axis is None: + return cp.reshape(result, [1]*x.ndim) + return cp.expand_dims(result, axis) + return result + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -146,6 +160,6 @@ def astype( 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'pow', 'sign'] + 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] _all_ignore = ['cp', 'get_xp'] From f85ac1779fee0216b45dc07893b07988585bac3c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 3 Mar 2025 20:18:11 +0000 Subject: [PATCH 3/4] BUG: numpy: fix count_nonzero it returns a python int for `axis=None and not keepdims` --- array_api_compat/numpy/_aliases.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 98eec121..a47f7121 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -127,6 +127,19 @@ def astype( return x.astype(dtype=dtype, copy=copy) +# count_nonzero returns a python int for axis=None and keepdims=False +# https://github.com/numpy/numpy/issues/17562 +def count_nonzero( + x : ndarray, + axis=None, + keepdims=False +) -> ndarray: + result = np.count_nonzero(x, axis=axis, keepdims=keepdims) + if axis is None and not keepdims: + return np.asarray(result) + return result + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, 'vecdot'): @@ -148,6 +161,6 @@ def astype( 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'pow'] + 'bool', 'concat', 'count_nonzero', 'pow'] _all_ignore = ['np', 'get_xp'] From ea068a013886670475a4e056db8f252080e81e9a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 3 Mar 2025 20:31:55 +0000 Subject: [PATCH 4/4] BUG: dask: fix count_nonzero(... keepdims=True) --- array_api_compat/dask/array/_aliases.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 4e2d26f9..80d66281 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -335,6 +335,21 @@ def argsort( return restore(x) +# dask.array.count_nonzero does not have keepdims +def count_nonzero( + x: Array, + axis=None, + keepdims=False +) -> Array: + result = da.count_nonzero(x, axis) + if keepdims: + if axis is None: + return da.reshape(result, [1]*x.ndim) + return da.expand_dims(result, axis) + return result + + + __all__ = _aliases.__all__ + [ '__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', @@ -343,6 +358,6 @@ def argsort( 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', 'iinfo', 'finfo', - 'can_cast', 'result_type'] + 'can_cast', 'count_nonzero', 'result_type'] _all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]