Fast, differentiable sorting and ranking in PyTorch.
Pure PyTorch implementation of Fast Differentiable Sorting and Ranking (Blondel et al.). Much of the code is copied from the original NumPy implementation at google-research/fast-soft-sort, with the isotonic regression solver rewritten as a PyTorch C++ and CUDA extension.
pip install torchsort

To build the CUDA extension you will need the CUDA toolchain installed. If you
want to build in an environment without a CUDA runtime (e.g. Docker), you will
need to export the environment variable
TORCH_CUDA_ARCH_LIST="Pascal;Volta;Turing;Ampere" before installing.
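If the install is driven from a Python script (for example, a container build step), the same variable can be passed to pip through the subprocess environment. The following is a minimal sketch: the helper name `cuda_build_command` is hypothetical, the architecture list is the one quoted above, and the function only prepares the command without running it.

```python
import os
import sys


def cuda_build_command(arch_list="Pascal;Volta;Turing;Ampere"):
    """Return (command, env) for installing torchsort without a CUDA runtime.

    Hypothetical helper: it prepares the pip invocation but does not run it.
    """
    env = dict(os.environ)
    # Tell the PyTorch extension builder which GPU architectures to compile for.
    env["TORCH_CUDA_ARCH_LIST"] = arch_list
    command = [sys.executable, "-m", "pip", "install", "torchsort"]
    return command, env


command, env = cuda_build_command()
# To actually install: subprocess.run(command, env=env, check=True)
```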
Conda Installation
On some systems the package may not compile with `pip` install in conda environments. If this happens you may need to:
- Install g++ with conda install -c conda-forge gxx_linux-64=9.40
- Run export CXX=/path/to/miniconda3/envs/env_name/bin/x86_64-conda_cos6-linux-gnu-g++
- Run export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/miniconda3/lib
- Run pip install --force-reinstall --no-cache-dir --no-deps torchsort
Thanks to @levnikmyskin, @sachit-menon for pointing this out!
Pre-built wheels are currently available on Linux for recent Python/PyTorch/CUDA combinations:
# torchsort version, supports >= 0.1.10
export TORCHSORT=0.1.10
# PyTorch version, supports pt26, pt25, pt24, pt21, pt20, and pt113 for versions
# 2.6, 2.5, 2.4, 2.1, 2.0, and 1.13 respectively
export TORCH=pt26
# CUDA version, supports cpu, cu113, cu117, cu118, cu121, cu124, and cu126 for
# CPU-only, CUDA 11.3, CUDA 11.7, CUDA 11.8, CUDA 12.1, CUDA 12.4, and CUDA 12.6
# respectively
export CUDA=cu126
# Python version, supports cp310, cp311, and cp312 for versions 3.10, 3.11, and
# 3.12 respectively
export PYTHON=cp312
pip install https://github.com/teddykoker/torchsort/releases/download/v${TORCHSORT}/torchsort-${TORCHSORT}+${TORCH}${CUDA}-${PYTHON}-${PYTHON}-linux_x86_64.whl

Thanks to siddharthab for the help creating the build action! See the latest release for a list of supported combinations in Assets.
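The wheel URL is assembled from the four tags above, so it can also be built programmatically. The sketch below mirrors the shell template; `wheel_url` is a hypothetical helper, and it does not check that a wheel for the requested combination was actually published.

```python
def wheel_url(torchsort="0.1.10", torch="pt26", cuda="cu126", python="cp312"):
    """Build the release wheel URL from the four version tags (hypothetical helper)."""
    return (
        "https://github.com/teddykoker/torchsort/releases/download/"
        f"v{torchsort}/torchsort-{torchsort}+{torch}{cuda}-{python}-{python}"
        "-linux_x86_64.whl"
    )


# Same combination as the shell example above.
print(wheel_url())
# CPU-only build for Python 3.11, assuming that combination exists in the release assets.
print(wheel_url(cuda="cpu", python="cp311"))
```

The resulting string can be passed directly to pip install.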
torchsort exposes two functions: soft_rank and soft_sort, each with
parameters regularization ("l2" or "kl") and regularization_strength (a
scalar value). Each will rank/sort the last dimension of a 2-d tensor, with an
accuracy dependent upon the regularization strength:
import torch
import torchsort
x = torch.tensor([[8, 0, 5, 3, 2, 1, 6, 7, 9]])
torchsort.soft_sort(x, regularization_strength=1.0)
# tensor([[0.5556, 1.5556, 2.5556, 3.5556, 4.5556, 5.5556, 6.5556, 7.5556, 8.5556]])
torchsort.soft_sort(x, regularization_strength=0.1)
# tensor([[-0., 1., 2., 3., 5., 6., 7., 8., 9.]])
torchsort.soft_rank(x)
# tensor([[8., 1., 5., 4., 3., 2., 6., 7., 9.]])

Both operations are fully differentiable, on CPU or GPU:
x = torch.tensor([[8., 0., 5., 3., 2., 1., 6., 7., 9.]], requires_grad=True).cuda()
y = torchsort.soft_sort(x)
torch.autograd.grad(y[0, 0], x)
# (tensor([[0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111]],
#         device='cuda:0'),)

Spearman's rank coefficient measures how monotonically related two variables are. We can use torchsort to create a differentiable Spearman's rank coefficient function, so that a model can be optimized directly for this metric:
import torch
import torchsort
def spearmanr(pred, target, **kw):
    pred = torchsort.soft_rank(pred, **kw)
    target = torchsort.soft_rank(target, **kw)
    pred = pred - pred.mean()
    pred = pred / pred.norm()
    target = target - target.mean()
    target = target / target.norm()
    return (pred * target).sum()
pred = torch.tensor([[1., 2., 3., 4., 5.]], requires_grad=True)
target = torch.tensor([[5., 6., 7., 8., 7.]])
spearman = spearmanr(pred, target)
# tensor(0.8321)
torch.autograd.grad(spearman, pred)
# (tensor([[-5.5470e-02,  2.9802e-09,  5.5470e-02,  1.1094e-01, -1.1094e-01]]),)

torchsort and fast_soft_sort each operate with a time complexity of O(n log
n), each with some additional overhead when compared to the built-in
torch.sort. With a batch size of 1 (see left), the Numba JIT'd forward pass of
fast_soft_sort performs about on par with the torchsort CPU kernel. However,
its backward pass still relies on some Python code, which greatly penalizes its
performance.
Furthermore, the torchsort kernel supports batches, and yields much better
performance than fast_soft_sort as the batch size increases.
The torchsort CUDA kernel performs quite well with sequence lengths under
~2000, and scales to extremely large batch sizes. In the future, the
CUDA kernel can likely be optimized further to bring its performance closer to
that of the built-in torch.sort.
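To illustrate where the O(n log n) cost comes from, the L2-regularized operators from Blondel et al. can be sketched in pure Python as one sort followed by a linear-time isotonic regression (pool adjacent violators). This is an illustrative reference only, assuming ascending order, L2 regularization, and a single 1-d list as input. The function names are hypothetical, and this is not torchsort's C++/CUDA code.

```python
def isotonic_decreasing(y):
    """L2 isotonic regression with a non-increasing constraint (pool adjacent violators)."""
    blocks = []  # each block stores [sum of values, number of values]
    for value in y:
        blocks.append([value, 1])
        # Merge while the previous block's mean is smaller than the newest block's mean.
        while len(blocks) > 1 and (
            blocks[-2][0] / blocks[-2][1] < blocks[-1][0] / blocks[-1][1]
        ):
            total, count = blocks.pop()
            blocks[-1][0] += total
            blocks[-1][1] += count
    out = []
    for total, count in blocks:
        out.extend([total / count] * count)
    return out


def project(z, w):
    """Euclidean projection of z onto the permutahedron of w (w sorted descending)."""
    order = sorted(range(len(z)), key=lambda i: -z[i])  # the O(n log n) step
    s = [z[i] for i in order]
    v = isotonic_decreasing([si - wi for si, wi in zip(s, w)])  # the O(n) step
    out = [0.0] * len(z)
    for pos, i in enumerate(order):
        out[i] = s[pos] - v[pos]  # undo the sort
    return out


def soft_rank_l2(theta, strength=1.0):
    """Ascending soft ranks of a list of numbers (hypothetical reference function)."""
    rho = list(range(len(theta), 0, -1))  # n, n-1, ..., 1
    return project([t / strength for t in theta], rho)


def soft_sort_l2(theta, strength=1.0):
    """Ascending soft sort of a list of numbers (hypothetical reference function)."""
    rho = list(range(len(theta), 0, -1))
    result = project([r / strength for r in rho], sorted(theta, reverse=True))
    return result[::-1]  # projection comes out descending; flip to ascending


x = [8, 0, 5, 3, 2, 1, 6, 7, 9]
print(soft_rank_l2(x))  # ranks 8, 1, 5, 4, 3, 2, 6, 7, 9
print([round(v, 4) for v in soft_sort_l2(x, 1.0)])  # 0.5556, 1.5556, ..., 8.5556
```

On this input the sketch reproduces the soft_rank and soft_sort values shown in the usage example. With a smaller strength such as 0.1, no blocks are pooled and the result is the exact sort.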
@inproceedings{blondel2020fast,
  title={Fast differentiable sorting and ranking},
  author={Blondel, Mathieu and Teboul, Olivier and Berthet, Quentin and Djolonga, Josip},
  booktitle={International Conference on Machine Learning},
  pages={950--959},
  year={2020},
  organization={PMLR}
}
