@@ -241,6 +241,20 @@ def sort(
241241) -> Array :
242242 return torch .sort (x , dim = axis , descending = descending , stable = stable , ** kwargs ).values
243243
244+
245+ # Wrap torch.argsort to set stable=True by default
246+ def argsort (
247+ x : Array ,
248+ / ,
249+ * ,
250+ axis : int = - 1 ,
251+ descending : bool = False ,
252+ stable : bool = True ,
253+ ** kwargs : object ,
254+ ) -> Array :
255+ return torch .argsort (x , dim = axis , descending = descending , stable = stable , ** kwargs )
256+
257+
244258def _normalize_axes (axis , ndim ):
245259 axes = []
246260 if ndim == 0 and axis :
@@ -837,9 +851,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array
837851 'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
838852 'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
839853 'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
840- 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'cumulative_prod' , 'sort' , 'prod' , 'sum' ,
841- 'any' , 'all' , 'mean' , 'std' , 'var' , 'concat' , 'squeeze ' ,
842- 'broadcast_to' , 'flip' , 'roll' , 'nonzero' , 'where' , 'reshape' ,
854+ 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'cumulative_prod' , 'sort' ,
855+ 'argsort' , 'prod' , 'sum' , ' any' , 'all' , 'mean' , 'std' , 'var' , 'concat' ,
856+ 'squeeze' , ' broadcast_to' , 'flip' , 'roll' , 'nonzero' , 'where' , 'reshape' ,
843857 'arange' , 'eye' , 'linspace' , 'full' , 'ones' , 'zeros' , 'empty' ,
844858 'tril' , 'triu' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
845859 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
0 commit comments