88 clip as _aliases_clip ,
99 unstack as _aliases_unstack ,
1010 cumulative_sum as _aliases_cumulative_sum ,
11+ cumulative_prod as _aliases_cumulative_prod ,
1112 )
1213from .._internal import get_xp
1314
@@ -124,7 +125,11 @@ def _fix_promotion(x1, x2, only_scalar=True):
124125 x1 = x1 .to (dtype )
125126 return x1 , x2
126127
127- def result_type (* arrays_and_dtypes : Union [array , Dtype ]) -> Dtype :
128+
129+ _py_scalars = (bool , int , float , complex )
130+
131+
132+ def result_type (* arrays_and_dtypes : Union [array , Dtype , bool , int , float , complex ]) -> Dtype :
128133 if len (arrays_and_dtypes ) == 0 :
129134 raise TypeError ("At least one array or dtype must be provided" )
130135 if len (arrays_and_dtypes ) == 1 :
@@ -136,6 +141,9 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
136141 return result_type (arrays_and_dtypes [0 ], result_type (* arrays_and_dtypes [1 :]))
137142
138143 x , y = arrays_and_dtypes
144+ if isinstance (x , _py_scalars ) or isinstance (y , _py_scalars ):
145+ return torch .result_type (x , y )
146+
139147 xdt = x .dtype if not isinstance (x , torch .dtype ) else x
140148 ydt = y .dtype if not isinstance (y , torch .dtype ) else y
141149
@@ -210,6 +218,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
# Torch-specialised versions of the shared alias implementations:
# get_xp(torch) binds ``torch`` as the ``xp`` namespace argument of each
# generic helper from the aliases module, producing a plain function with
# the array-API signature.
clip = get_xp(torch)(_aliases_clip)
unstack = get_xp(torch)(_aliases_unstack)
cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
213222
214223# torch.sort also returns a tuple
215224# https://github.com/pytorch/pytorch/issues/70921
@@ -504,6 +513,31 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
504513 raise ValueError ("nonzero() does not support zero-dimensional arrays" )
505514 return torch .nonzero (x , as_tuple = True , ** kwargs )
506515
516+
# torch uses `dim` instead of `axis`
def diff(
    x: array,
    /,
    *,
    axis: int = -1,
    n: int = 1,
    prepend: Optional[array] = None,
    append: Optional[array] = None,
) -> array:
    """Return the ``n``-th forward difference of ``x`` along ``axis``.

    Thin array-API wrapper over ``torch.diff``; the only translation
    required is the spec's ``axis`` keyword -> torch's ``dim``.
    ``prepend``/``append`` are forwarded unchanged (``torch.diff``
    accepts ``None`` for both).
    """
    torch_kwargs = {"n": n, "dim": axis, "prepend": prepend, "append": append}
    return torch.diff(x, **torch_kwargs)
528+
529+
# torch uses `dim` instead of `axis`, and torch.count_nonzero has no
# `keepdims` argument at all, so that part of the spec is emulated here.
def count_nonzero(
    x: array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdims: bool = False,
) -> array:
    """Count the non-zero elements of ``x``, optionally along ``axis``.

    ``torch.count_nonzero(input, dim=...)`` always squeezes the reduced
    dimensions and does not accept ``keepdims`` (passing it raises
    ``TypeError``), so when ``keepdims=True`` the size-1 dimensions are
    reinserted manually.
    """
    result = torch.count_nonzero(x, dim=axis)
    if keepdims:
        if axis is None:
            # Full reduction yields a 0-d tensor; restore one size-1
            # dimension per input dimension.
            return result.reshape([1] * x.ndim)
        # `axis` may be a single int or a tuple of ints (torch accepts
        # both for `dim`). Normalize negatives and re-insert the reduced
        # dimensions smallest-first so positions stay valid.
        axes = (axis,) if isinstance(axis, int) else axis
        for ax in sorted(a % x.ndim for a in axes):
            result = result.unsqueeze(ax)
    return result
539+
540+
def where(condition: array, x1: array, x2: array, /) -> array:
    """Elementwise select ``x1`` where ``condition`` is true, else ``x2``.

    The two value arrays go through ``_fix_promotion`` first so the
    result dtype follows the array-API promotion rules before
    delegating to ``torch.where``.
    """
    promoted = _fix_promotion(x1, x2)
    return torch.where(condition, *promoted)
@@ -734,6 +768,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
734768 axis = 0
735769 return torch .index_select (x , axis , indices , ** kwargs )
736770
771+
def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
    """Gather values from ``x`` along ``axis`` at positions ``indices``.

    Array-API name for ``torch.take_along_dim``; only the ``axis`` ->
    ``dim`` keyword rename is needed.
    """
    gathered = torch.take_along_dim(x, indices, dim=axis)
    return gathered
774+
775+
737776def sign (x : array , / ) -> array :
738777 # torch sign() does not support complex numbers and does not propagate
739778 # nans. See https://github.com/data-apis/array-api-compat/issues/136
@@ -752,18 +791,19 @@ def sign(x: array, /) -> array:
752791__all__ = ['__array_namespace_info__' , 'result_type' , 'can_cast' ,
753792 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
754793 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
755- 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'divide' ,
794+ 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'count_nonzero' ,
795+ 'diff' , 'divide' ,
756796 'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
757797 'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
758798 'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
759- 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'sort' , 'prod' , 'sum' ,
799+ 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'cumulative_prod' , ' sort' , 'prod' , 'sum' ,
760800 'any' , 'all' , 'mean' , 'std' , 'var' , 'concat' , 'squeeze' ,
761801 'broadcast_to' , 'flip' , 'roll' , 'nonzero' , 'where' , 'reshape' ,
762802 'arange' , 'eye' , 'linspace' , 'full' , 'ones' , 'zeros' , 'empty' ,
763803 'tril' , 'triu' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
764804 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
765805 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
766806 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
767- 'take' , 'sign' ]
807+ 'take' , 'take_along_axis' , ' sign' ]
768808
769809_all_ignore = ['torch' , 'get_xp' ]
0 commit comments