@@ -499,6 +499,17 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
499499 raise ValueError ("nonzero() does not support zero-dimensional arrays" )
500500 return torch .nonzero (x , as_tuple = True , ** kwargs )
501501
502+ # torch uses `dim` instead of `axis`
503+ def count_nonzero (
504+ x : array ,
505+ / ,
506+ * ,
507+ axis : Optional [Union [int , Tuple [int , ...]]] = None ,
508+ keepdims : bool = False ,
509+ ) -> array :
510+ return torch .count_nonzero (x , dim = axis , keepdims = keepdims )
511+
512+
502513def where (condition : array , x1 : array , x2 : array , / ) -> array :
503514 x1 , x2 = _fix_promotion (x1 , x2 )
504515 return torch .where (condition , x1 , x2 )
@@ -736,7 +747,8 @@ def sign(x: array, /) -> array:
736747__all__ = ['__array_namespace_info__' , 'result_type' , 'can_cast' ,
737748 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
738749 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
739- 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'divide' ,
750+ 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'count_nonzero' ,
751+ 'divide' ,
740752 'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
741753 'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
742754 'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
0 commit comments