A simple transformer implementation from scratch in PyTorch. This repository is forked from https://codeberg.org/pbm/former; the intent is to experiment with the transformer architecture in PyTorch and compare it against an LSTM-with-attention baseline on text generation tasks.
See http://peterbloem.nl/blog/transformers for an in-depth explanation.
All models in the repository consist of a single stack of transformer blocks (that is, no encoder/decoder structures). It turns out that this simple configuration often works best.
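As a rough illustration of what one such block looks like (a minimal sketch using PyTorch's built-in attention layer, not the repository's actual code; names and defaults are illustrative):

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """One block: multi-head self-attention followed by a feedforward
    layer, each wrapped in a residual connection plus layer norm."""

    def __init__(self, emb, heads, ff_mult=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(emb, heads, batch_first=True)
        self.norm1 = nn.LayerNorm(emb)
        self.norm2 = nn.LayerNorm(emb)
        self.ff = nn.Sequential(
            nn.Linear(emb, ff_mult * emb),
            nn.ReLU(),
            nn.Linear(ff_mult * emb, emb),
        )

    def forward(self, x):
        # x has shape (batch, sequence, emb)
        attended, _ = self.attention(x, x, x)
        x = self.norm1(attended + x)
        return self.norm2(self.ff(x) + x)
```

A full model stacks several of these blocks between an embedding layer and an output projection.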
Install uv (if not already installed):

```shell
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
```

After cloning the repository, install the dependencies like this:

```shell
uv sync --extra ml
```

Run tests to verify installation:

```shell
uv run python -m unittest discover tests
```

Classification (IMDb sentiment analysis):

```shell
uv run python experiments/classify.py
```

Text Generation (enwik8, character-level):

```shell
uv run python experiments/generate.py
```

Transformer with BPE (C4 dataset):

```shell
# Quick test with small config
uv run python -m experiments.seqmodels.generate_custom --config experiments/seqmodels/configs/transformer_test.yaml

# Full training
uv run python -m experiments.seqmodels.generate_custom --config experiments/seqmodels/configs/transformer.yaml
```

LSTM with BPE (C4 dataset):

```shell
# Quick test
uv run python -m experiments.seqmodels.generate_lstm --config experiments/seqmodels/configs/lstm_test.yaml

# Full training
uv run python -m experiments.seqmodels.generate_lstm --config experiments/seqmodels/configs/lstm.yaml
```

Use --help to see available options for any script.
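The BPE (byte-pair encoding) tokenizer used above builds its vocabulary by repeatedly merging the most frequent adjacent pair of symbols in the corpus. A toy sketch of one merge step (illustrative only, not the repository's tokenizer):

```python
from collections import Counter

def most_frequent_pair(words):
    """Find the most frequent adjacent symbol pair.

    `words` maps a tuple of symbols to its corpus frequency."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return max(pairs, key=pairs.get)

def merge_pair(words, pair):
    """Replace every occurrence of `pair` with one merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = freq
    return merged
```

Training repeats these two steps until the vocabulary reaches the desired size; the learned merges are then applied in order to tokenize new text.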
Compare Model Performance:

```shell
# Evaluate both models on the same test samples
uv run python -m experiments.seqmodels.evaluate \
    --transformer transformer_checkpoints/run_XXXXXXXX_XXXXXX/ \
    --lstm lstm_checkpoints/run_XXXXXXXX_XXXXXX/ \
    --batch-size 256 \
    --num-batches 100 \
    --transformer-checkpoint checkpoint_batch_6000.pt \
    --lstm-checkpoint checkpoint_batch_6000.pt
```

Generate Text from Checkpoint:
```shell
# Generate with default settings
uv run python -m experiments.seqmodels.generate_from_checkpoint \
    transformer_checkpoints/run_XXXXXXXX_XXXXXX/

# Custom prompt and parameters
uv run python -m experiments.seqmodels.generate_from_checkpoint \
    transformer_checkpoints/run_XXXXXXXX_XXXXXX/ \
    --prompt "Once upon a time" \
    --length 200 \
    --temperature 0.9 \
    --checkpoint checkpoint_batch_6000.pt
```
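The --temperature parameter rescales the logits before sampling: values below 1 sharpen the distribution toward the most likely tokens, values above 1 flatten it. A sketch of this standard technique (not necessarily the script's exact code):

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0):
    """Sample a token id from a 1-D logits tensor after temperature scaling."""
    if temperature <= 0:  # treat 0 as greedy decoding
        return int(logits.argmax())
    probs = F.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))
```

Generation repeatedly feeds the prompt plus sampled tokens back through the model, sampling one next token per step.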