Educational project for learning encoder-decoder Transformers through a simple algorithmic task. This small sequence-to-sequence Transformer is trained to map an unsorted integer sequence to its sorted counterpart.
⚡ This Transformer model can be trained in a few minutes on a consumer-grade GPU. Let's give it a try!
Sort integer sequences by using a Transformer.
```
Input:  [5, 2, 8, 1, 9]
Output: [BOS, 1, 2, 5, 8, 9, EOS]
```
- Vanilla encoder-decoder Transformer (Pre-LN)
- Learned token embeddings + sinusoidal positional encodings
- Greedy decoding during evaluation
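Greedy decoding at evaluation time can be sketched as below. This is a minimal illustration: the `encode`/`decode` methods, tensor shapes, and vocabulary size are assumptions for the sketch, not the project's actual model API.

```python
import torch

PAD, BOS, EOS = 0, 1, 2

def greedy_decode(model, src, max_len=16):
    """Autoregressively pick the argmax token at each step until EOS.

    Assumes a model exposing encode(src) -> memory and
    decode(ys, memory) -> logits of shape (batch, cur_len, vocab).
    """
    memory = model.encode(src)                       # encode the unsorted input once
    ys = torch.tensor([[BOS]], dtype=torch.long)     # decoding starts from BOS
    for _ in range(max_len):
        logits = model.decode(ys, memory)            # (1, cur_len, vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy choice
        ys = torch.cat([ys, next_token], dim=1)      # append and continue
        if next_token.item() == EOS:                 # stop once EOS is emitted
            break
    return ys
```

Greedy decoding is sufficient here because sorting has a single correct output; beam search would add complexity without improving accuracy on this task.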
- Integer sequences are generated synthetically at runtime
- Token IDs are offset by `NUM_SPECIAL_TOKENS = 3` to reserve: `PAD = 0`, `BOS = 1`, `EOS = 2`
- Integer values are mapped to tokens `>= 3`
For an unsorted input sequence `[5, 2, 8]`:

- Encoder input: `[5 + NUM_SPECIAL_TOKENS, 2 + NUM_SPECIAL_TOKENS, 8 + NUM_SPECIAL_TOKENS] = [8, 5, 11]`
- Decoder input (teacher forcing): `[BOS] + [2 + NUM_SPECIAL_TOKENS, 5 + NUM_SPECIAL_TOKENS, 8 + NUM_SPECIAL_TOKENS] = [1, 5, 8, 11]`
- Training targets (loss labels): `[2 + NUM_SPECIAL_TOKENS, 5 + NUM_SPECIAL_TOKENS, 8 + NUM_SPECIAL_TOKENS] + [EOS] = [5, 8, 11, 2]`
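The construction above can be written in a few lines of Python. This is a standalone sketch of the token offsetting and teacher-forcing shift; the function name is illustrative and not the project's actual dataset code.

```python
PAD, BOS, EOS = 0, 1, 2
NUM_SPECIAL_TOKENS = 3

def make_example(values):
    """Build encoder input, decoder input, and loss targets for one sequence."""
    src = [v + NUM_SPECIAL_TOKENS for v in values]              # unsorted, offset
    sorted_tokens = [v + NUM_SPECIAL_TOKENS for v in sorted(values)]
    decoder_input = [BOS] + sorted_tokens                       # teacher forcing
    targets = sorted_tokens + [EOS]                             # shifted by one step
    return src, decoder_input, targets

src, dec_in, tgt = make_example([5, 2, 8])
# src    -> [8, 5, 11]
# dec_in -> [1, 5, 8, 11]
# tgt    -> [5, 8, 11, 2]
```

Note that the decoder input and the targets are the same sorted sequence shifted by one position: at each step the decoder sees all previous correct tokens and must predict the next one.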
Cross-entropy loss is used, with PAD tokens ignored.
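In PyTorch, ignoring PAD positions corresponds to `nn.CrossEntropyLoss` with `ignore_index` set to the PAD id. A sketch, assuming logits of shape `(batch, seq_len, vocab)`; the shapes and vocabulary size are illustrative:

```python
import torch
import torch.nn as nn

PAD = 0
criterion = nn.CrossEntropyLoss(ignore_index=PAD)  # PAD targets contribute no loss

# Dummy shapes: batch of 2, target length 4, vocabulary of 13 tokens.
logits = torch.randn(2, 4, 13)             # model outputs
targets = torch.tensor([[5, 8, 11, 2],
                        [6, 9, 2, PAD]])   # second sequence is padded

# CrossEntropyLoss expects the class dimension second, i.e. (N, C, ...),
# so move the vocabulary dimension from position 2 to position 1.
loss = criterion(logits.transpose(1, 2), targets)
```

Without `ignore_index`, the model would be rewarded for predicting PAD at padded positions, skewing training on batches with mixed sequence lengths.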
Make sure you have uv installed on your system. If you don't have it yet, you can install it by following the instructions here.
```
uv sync --extra dev
uv pip install -e .
```

Train:

```
uv run scripts/run_training.py
```

This will train the model and save the checkpoints in the `checkpoints/` directory.
You can then run the model interactively:
```
uv run scripts/play_model.py checkpoints/model_YYYYMMDD_HHMMSS.pt
```

All settings are defined in `config.yaml`:
- `data.*` (sequence length, value range, dataset size, batch size)
- `model.*` (dimensions, heads, layers, dropout)
- `training.*` (optimizer, epochs, eval interval)
- `seed` (reproducibility)
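A `config.yaml` following this layout might look like the sketch below. The key names inside each section and all values are illustrative assumptions; consult the shipped file for the actual settings.

```yaml
seed: 42                # reproducibility

data:
  seq_len: 8            # sequence length
  max_value: 99         # value range
  dataset_size: 100000
  batch_size: 128

model:
  d_model: 128          # dimensions
  num_heads: 4
  num_layers: 2
  dropout: 0.1

training:
  optimizer: adamw
  epochs: 10
  eval_interval: 500
```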
```
.
├── config.yaml                      # Training configuration and random seed
├── checkpoints/                     # Saved model checkpoints
├── logs/                            # Training and evaluation logs
├── scripts/
│   ├── run_training.py              # Training loop with periodic evaluation
│   └── play_model.py                # Interactive inference script
├── transformer_sort/
│   ├── models/
│   │   └── *.py                     # Transformer model components
│   └── utils/
│       └── *.py                     # Dataset, masks, evaluation, seeding
├── tests/
│   └── test_padding_equivalence.py  # Padding equivalence tests
├── pyproject.toml                   # Project metadata and tooling
└── README.md                        # This document
```