Skip to content

Commit 55a72f7

Browse files
committed
torch: add count_nonzero
1 parent 37634ab commit 55a72f7

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

array_api_compat/torch/_aliases.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,17 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
499499
raise ValueError("nonzero() does not support zero-dimensional arrays")
500500
return torch.nonzero(x, as_tuple=True, **kwargs)
501501

502+
# torch uses `dim` instead of `axis`
503+
def count_nonzero(
504+
x: array,
505+
/,
506+
*,
507+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
508+
keepdims: bool = False,
509+
) -> array:
510+
return torch.count_nonzero(x, dim=axis, keepdims=keepdims)
511+
512+
502513
def where(condition: array, x1: array, x2: array, /) -> array:
503514
x1, x2 = _fix_promotion(x1, x2)
504515
return torch.where(condition, x1, x2)
@@ -736,7 +747,8 @@ def sign(x: array, /) -> array:
736747
__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
737748
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
738749
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
739-
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
750+
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
751+
'divide',
740752
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
741753
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
742754
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',

0 commit comments

Comments
 (0)