Skip to content

Commit 85cf228

Browse files
committed
FIX: Wrap torch.argsort to set stable=True by default
1 parent 5938c3f commit 85cf228

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,21 @@ def sort(
241241
) -> Array:
242242
return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
243243

244+
245+
# Wrap torch.argsort to set stable=True by default
246+
def argsort(
247+
x: Array,
248+
/,
249+
*,
250+
axis: int = -1,
251+
descending: bool = False,
252+
stable: bool = True,
253+
**kwargs: object,
254+
) -> Array:
255+
256+
return torch.argsort(x, dim=axis, descending=descending, stable=stable, **kwargs)
257+
258+
244259
def _normalize_axes(axis, ndim):
245260
axes = []
246261
if ndim == 0 and axis:

0 commit comments

Comments
 (0)