@@ -505,6 +505,17 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
505505 raise ValueError ("nonzero() does not support zero-dimensional arrays" )
506506 return torch .nonzero (x , as_tuple = True , ** kwargs )
507507
508+ # torch uses `dim` instead of `axis`
509+ def count_nonzero (
510+ x : array ,
511+ / ,
512+ * ,
513+ axis : Optional [Union [int , Tuple [int , ...]]] = None ,
514+ keepdims : bool = False ,
515+ ) -> array :
516+ return torch .count_nonzero (x , dim = axis , keepdims = keepdims )
517+
518+
508519def where (condition : array , x1 : array , x2 : array , / ) -> array :
509520 x1 , x2 = _fix_promotion (x1 , x2 )
510521 return torch .where (condition , x1 , x2 )
@@ -753,7 +764,8 @@ def sign(x: array, /) -> array:
753764__all__ = ['__array_namespace_info__' , 'result_type' , 'can_cast' ,
754765 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
755766 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
756- 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'divide' ,
767+ 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'count_nonzero' ,
768+ 'divide' ,
757769 'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
758770 'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
759771 'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
0 commit comments