@@ -761,6 +761,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
761761 axis = 0
762762 return torch .index_select (x , axis , indices , ** kwargs )
763763
764+
765+ def take_along_axis (x : array , indices : array , / , * , axis : int = - 1 ) -> array :
766+ return torch .take_along_dim (x , indices , dim = axis )
767+
768+
764769def sign (x : array , / ) -> array :
765770 # torch sign() does not support complex numbers and does not propagate
766771 # nans. See https://github.com/data-apis/array-api-compat/issues/136
@@ -784,14 +789,14 @@ def sign(x: array, /) -> array:
784789 'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
785790 'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
786791 'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
787- 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'sort' , 'prod' , 'sum' ,
792+ 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'cumulative_prod' , ' sort' , 'prod' , 'sum' ,
788793 'any' , 'all' , 'mean' , 'std' , 'var' , 'concat' , 'squeeze' ,
789794 'broadcast_to' , 'flip' , 'roll' , 'nonzero' , 'where' , 'reshape' ,
790795 'arange' , 'eye' , 'linspace' , 'full' , 'ones' , 'zeros' , 'empty' ,
791796 'tril' , 'triu' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
792797 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
793798 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
794799 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
795- 'take' , 'sign' ]
800+ 'take' , 'take_along_axis' , ' sign' ]
796801
797802_all_ignore = ['torch' , 'get_xp' ]
0 commit comments