Skip to content

Commit f85ac17

Browse files
committed
BUG: numpy: fix count_nonzero
it returns a python int for `axis=None and not keepdims`
1 parent 2ca6d73 commit f85ac17

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

Diff for: array_api_compat/numpy/_aliases.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,19 @@ def astype(
127127
return x.astype(dtype=dtype, copy=copy)
128128

129129

130+
# count_nonzero returns a python int for axis=None and keepdims=False
131+
# https://github.com/numpy/numpy/issues/17562
132+
def count_nonzero(
133+
x : ndarray,
134+
axis=None,
135+
keepdims=False
136+
) -> ndarray:
137+
result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
138+
if axis is None and not keepdims:
139+
return np.asarray(result)
140+
return result
141+
142+
130143
# These functions are completely new here. If the library already has them
131144
# (i.e., numpy 2.0), use the library version instead of our wrapper.
132145
if hasattr(np, 'vecdot'):
@@ -148,6 +161,6 @@ def astype(
148161
'acos', 'acosh', 'asin', 'asinh', 'atan',
149162
'atan2', 'atanh', 'bitwise_left_shift',
150163
'bitwise_invert', 'bitwise_right_shift',
151-
'bool', 'concat', 'pow']
164+
'bool', 'concat', 'count_nonzero', 'pow']
152165

153166
_all_ignore = ['np', 'get_xp']

0 commit comments

Comments
 (0)