@@ -499,6 +499,17 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
499
499
raise ValueError ("nonzero() does not support zero-dimensional arrays" )
500
500
return torch .nonzero (x , as_tuple = True , ** kwargs )
501
501
502
+ # torch uses `dim` instead of `axis`
503
+ def count_nonzero (
504
+ x : array ,
505
+ / ,
506
+ * ,
507
+ axis : Optional [Union [int , Tuple [int , ...]]] = None ,
508
+ keepdims : bool = False ,
509
+ ) -> array :
510
+ return torch .count_nonzero (x , dim = axis , keepdims = keepdims )
511
+
512
+
502
513
def where (condition : array , x1 : array , x2 : array , / ) -> array :
503
514
x1 , x2 = _fix_promotion (x1 , x2 )
504
515
return torch .where (condition , x1 , x2 )
@@ -736,7 +747,8 @@ def sign(x: array, /) -> array:
736
747
__all__ = ['__array_namespace_info__' , 'result_type' , 'can_cast' ,
737
748
'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
738
749
'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
739
- 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'divide' ,
750
+ 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'count_nonzero' ,
751
+ 'divide' ,
740
752
'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
741
753
'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
742
754
'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
0 commit comments