@@ -127,6 +127,19 @@ def astype(
127
127
return x .astype (dtype = dtype , copy = copy )
128
128
129
129
130
+ # count_nonzero returns a python int for axis=None and keepdims=False
131
+ # https://github.com/numpy/numpy/issues/17562
132
+ def count_nonzero (
133
+ x : ndarray ,
134
+ axis = None ,
135
+ keepdims = False
136
+ ) -> ndarray :
137
+ result = np .count_nonzero (x , axis = axis , keepdims = keepdims )
138
+ if axis is None and not keepdims :
139
+ return np .asarray (result )
140
+ return result
141
+
142
+
130
143
# These functions are completely new here. If the library already has them
131
144
# (i.e., numpy 2.0), use the library version instead of our wrapper.
132
145
if hasattr (np , 'vecdot' ):
@@ -148,6 +161,6 @@ def astype(
148
161
'acos' , 'acosh' , 'asin' , 'asinh' , 'atan' ,
149
162
'atan2' , 'atanh' , 'bitwise_left_shift' ,
150
163
'bitwise_invert' , 'bitwise_right_shift' ,
151
- 'bool' , 'concat' , 'pow' ]
164
+ 'bool' , 'concat' , 'count_nonzero' , ' pow' ]
152
165
153
166
_all_ignore = ['np' , 'get_xp' ]
0 commit comments