Skip to content

Commit ea068a0

Browse files
committed
BUG: dask: fix count_nonzero(... keepdims=True)
1 parent f85ac17 commit ea068a0

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

array_api_compat/dask/array/_aliases.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,21 @@ def argsort(
335335
return restore(x)
336336

337337

338+
# dask.array.count_nonzero does not have keepdims
339+
def count_nonzero(
340+
x: Array,
341+
axis=None,
342+
keepdims=False
343+
) -> Array:
344+
result = da.count_nonzero(x, axis)
345+
if keepdims:
346+
if axis is None:
347+
return da.reshape(result, [1]*x.ndim)
348+
return da.expand_dims(result, axis)
349+
return result
350+
351+
352+
338353
__all__ = _aliases.__all__ + [
339354
'__array_namespace_info__', 'asarray', 'astype', 'acos',
340355
'acosh', 'asin', 'asinh', 'atan', 'atan2',
@@ -343,6 +358,6 @@ def argsort(
343358
'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
344359
'uint8', 'uint16', 'uint32', 'uint64',
345360
'complex64', 'complex128', 'iinfo', 'finfo',
346-
'can_cast', 'result_type']
361+
'can_cast', 'count_nonzero', 'result_type']
347362

348363
_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]

0 commit comments

Comments
 (0)