🧠 Transformer Sort 🔢

Educational project for learning encoder–decoder Transformers through a simple algorithmic task: a small sequence-to-sequence Transformer is trained to map an unsorted integer sequence to its sorted counterpart.

⚡ This Transformer model can be trained in a few minutes on a consumer-grade GPU. Let's give it a try!

🎯 Task

Sort integer sequences using a Transformer.

Input:

[5, 2, 8, 1, 9]

Output:

[BOS, 1, 2, 5, 8, 9, EOS]

πŸ—οΈ Model and Data

  • Vanilla encoder–decoder Transformer (Pre-LN)
  • Learned token embeddings + sinusoidal positional encodings
  • Greedy decoding during evaluation
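
Greedy decoding can be sketched as a short loop. The `next_token_logits` callable below is a hypothetical stand-in for a trained model's forward pass, and the `oracle` function is a toy model used only for illustration; the project's actual API may differ:

```python
# Greedy decoding sketch: repeatedly pick the argmax token until EOS.
# `next_token_logits(src, out)` is a hypothetical stand-in for the model.
BOS, EOS = 1, 2

def greedy_decode(next_token_logits, src, max_len=16):
    out = [BOS]
    for _ in range(max_len):
        logits = next_token_logits(src, out)                   # scores over the vocabulary
        tok = max(range(len(logits)), key=logits.__getitem__)  # argmax token
        out.append(tok)
        if tok == EOS:
            break
    return out

def oracle(src, out):
    """Toy stand-in model: emits src's sorted tokens one by one, then EOS."""
    target = sorted(src) + [EOS]
    scores = [0.0] * 16
    scores[target[len(out) - 1]] = 1.0
    return scores

print(greedy_decode(oracle, [8, 5, 11]))  # [1, 5, 8, 11, 2]
```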

🔢 Data generation and tokenization

  • Integer sequences are generated synthetically at runtime

  • Token IDs are offset by NUM_SPECIAL_TOKENS = 3 to reserve:

    • PAD = 0
    • BOS = 1
    • EOS = 2
  • Integer values are mapped to tokens >= 3
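
In code, this offset scheme amounts to two one-line mappings. A minimal sketch following the constants above (`encode`/`decode` are illustrative names, not necessarily the project's API):

```python
# Token-ID scheme: special tokens occupy 0-2, integer values are shifted up by 3.
PAD, BOS, EOS = 0, 1, 2
NUM_SPECIAL_TOKENS = 3

def encode(values):
    """Map raw integers to token IDs (>= 3)."""
    return [v + NUM_SPECIAL_TOKENS for v in values]

def decode(tokens):
    """Map token IDs back to integers, dropping special tokens."""
    return [t - NUM_SPECIAL_TOKENS for t in tokens if t >= NUM_SPECIAL_TOKENS]

print(encode([5, 2, 8]))             # [8, 5, 11]
print(decode([BOS, 5, 8, 11, EOS]))  # [2, 5, 8]
```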

🔗 Sequence construction

For an unsorted input sequence [5, 2, 8] (sorted: [2, 5, 8]):

  • Encoder input: the unsorted sequence, e.g.

    [5 + NUM_SPECIAL_TOKENS, 2 + NUM_SPECIAL_TOKENS, 8 + NUM_SPECIAL_TOKENS] = [8, 5, 11]
    
  • Decoder input (teacher forcing):

    [BOS] + [2 + NUM_SPECIAL_TOKENS, 5 + NUM_SPECIAL_TOKENS, 8 + NUM_SPECIAL_TOKENS] = [1, 5, 8, 11]
    
  • Training targets (loss labels):

    [2 + NUM_SPECIAL_TOKENS, 5 + NUM_SPECIAL_TOKENS, 8 + NUM_SPECIAL_TOKENS] + [EOS] = [5, 8, 11, 2]
    
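Putting the three pieces together, one training example can be built as follows (a sketch mirroring the construction above; `make_example` is an illustrative name, not the project's API):

```python
# Build (encoder input, decoder input, labels) for one training example,
# using the README's offset scheme: PAD=0, BOS=1, EOS=2, values shifted by 3.
PAD, BOS, EOS = 0, 1, 2
NUM_SPECIAL_TOKENS = 3

def make_example(values):
    src = [v + NUM_SPECIAL_TOKENS for v in values]          # encoder input (unsorted)
    tgt = [v + NUM_SPECIAL_TOKENS for v in sorted(values)]  # sorted token IDs
    decoder_input = [BOS] + tgt                             # teacher forcing
    labels = tgt + [EOS]                                    # targets, shifted by one
    return src, decoder_input, labels

src, dec_in, labels = make_example([5, 2, 8])
print(src)     # [8, 5, 11]
print(dec_in)  # [1, 5, 8, 11]
print(labels)  # [5, 8, 11, 2]
```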

Cross-entropy loss is used, with PAD tokens ignored.
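
In PyTorch, ignoring PAD corresponds to `nn.CrossEntropyLoss(ignore_index=PAD)`. A minimal sketch with random logits (batch size, sequence length, and vocabulary size are illustrative):

```python
import torch
import torch.nn as nn

PAD = 0
criterion = nn.CrossEntropyLoss(ignore_index=PAD)  # PAD positions contribute no loss

vocab_size = 13                          # e.g. 3 special tokens + 10 integer values
logits = torch.randn(2, 4, vocab_size)   # (batch, seq_len, vocab)
labels = torch.tensor([[5, 8, 11, 2],    # full-length target sequence
                       [5, 8, 2, PAD]])  # shorter sequence, padded with PAD
loss = criterion(logits.reshape(-1, vocab_size), labels.reshape(-1))
```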

📦 Installation

Make sure you have uv installed on your system. If you don't have it yet, follow the official uv installation instructions.

uv sync --extra dev
uv pip install -e .

🚀 Quick Start

Train:

uv run scripts/run_training.py

This will train the model and save the checkpoints in the checkpoints/ directory.

You can then run the model interactively:

uv run scripts/play_model.py checkpoints/model_YYYYMMDD_HHMMSS.pt

βš™οΈ Configuration

All settings are defined in config.yaml:

  • data.* (sequence length, value range, dataset size, batch size)
  • model.* (dimensions, heads, layers, dropout)
  • training.* (optimizer, epochs, eval interval)
  • seed (reproducibility)
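
A config.yaml following this layout might look like the fragment below; the field names and values are illustrative, not the project's exact schema:

```yaml
# Illustrative shape only; see the repository's config.yaml for the real keys.
data:
  seq_len: 8          # sequence length
  max_value: 100      # value range
  batch_size: 64
model:
  d_model: 128        # embedding dimension
  num_heads: 4
  num_layers: 2
  dropout: 0.1
training:
  epochs: 10
  eval_interval: 100
seed: 42
```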

πŸ“ Project Layout

.
├── config.yaml                         # Training configuration and random seed
├── checkpoints/                        # Saved model checkpoints
├── logs/                               # Training and evaluation logs
├── scripts/
│   ├── run_training.py                 # Training loop with periodic evaluation
│   └── play_model.py                   # Interactive inference script
├── transformer_sort/
│   ├── models/
│   │   └── *.py                        # Transformer model components
│   └── utils/
│       └── *.py                        # Dataset, masks, evaluation, seeding
├── tests/
│   └── test_padding_equivalence.py     # Padding equivalence tests
├── pyproject.toml                      # Project metadata and tooling
└── README.md                           # This document
