This repo trains a micro diffusion LLM (~12.23M parameters) based on Mercury's training and inference approach, using BPE tokens, and generates text by iterative denoising.
- Forward corruption with `[MASK]` tokens over random timesteps `t in [1..T]`
- Time-conditioned Transformer denoiser (`timestep_emb`)
- Full-sequence denoising objective (predict clean `x0` from noisy `x_t`)
- Reverse denoising at inference (`t = T -> 1`, plus a final `t = 0` pass)
- Confidence-based remasking for iterative refinement
- MP4 trace output for denoising steps
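The forward corruption step can be sketched in a few lines of plain Python. This is an illustrative toy, not the repo's actual code: `MASK` and the function name are made up, and the real implementation operates on token tensors.

```python
import math
import random

MASK = -1  # illustrative [MASK] id; the repo uses a real tokenizer id

def corrupt(x0, t, T, rng=random.random):
    """Forward masking: each token survives with probability
    a_t = cos((t / T) * pi / 2); the rest become [MASK]."""
    a_t = math.cos((t / T) * math.pi / 2)
    return [tok if rng() < a_t else MASK for tok in x0]

# At t = T the survival probability is cos(pi/2) = 0: everything is masked.
assert corrupt([1, 2, 3, 4], t=100, T=100) == [MASK] * 4
# At t = 0 the survival probability is 1: nothing is masked.
assert corrupt([1, 2, 3, 4], t=0, T=100) == [1, 2, 3, 4]
```

The denoiser is then trained to recover the original tokens at the masked positions, conditioned on `t`.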
- Tokenization: BPE tokenizer from `data/stories.txt` + `[MASK]` token
- Context length: `block_size = 256` tokens
- Diffusion steps: `T = 100`
- Layers: `n_layer = 6`
- Attention heads: `n_head = 6`
- Embedding dimension: `n_embd = 384`
- Head dimension: `head_dim = 64`
- Parameters: `10,706,304` (~10.71M, with current `vocab_size = 66`)
- Attention type: bidirectional (`is_causal=False`)
- Positional scheme: RoPE (precomputed rotary cos/sin buffers)
- Normalization: RMSNorm via `F.rms_norm`
- MLP: expansion `4 * n_embd` with `relu(x)^2` nonlinearity
- Timestep conditioning: learned embedding `Embedding(T + 1, n_embd)`
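The reported parameter count can be reproduced from the config above, assuming no biases, a parameter-free RMSNorm, and an untied output head. This is a back-of-the-envelope check, not the repo's actual code:

```python
# Assumed breakdown: no biases, parameter-free RMSNorm, untied output head.
n_layer, n_head, n_embd, vocab_size, T = 6, 6, 384, 66, 100

attn = 4 * n_embd * n_embd        # q, k, v, o projections
mlp = 2 * (4 * n_embd) * n_embd   # up + down projections (4x expansion)
tok = vocab_size * n_embd         # token embedding
time = (T + 1) * n_embd           # Embedding(T + 1, n_embd)
head = vocab_size * n_embd        # untied output projection

total = n_layer * (attn + mlp) + tok + time + head
assert total == 10_706_304        # matches the count reported above
```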
Training setup:
- Batch size: `64`
- Max iterations: `5000`
- Optimizer: AdamW
- Learning rate: `3e-4`
- Forward noising: cosine survival schedule `a_t = cos((t / T) * pi / 2)`
- Objective: predict clean sequence `x0` from noisy `x_t` (cross-entropy on masked positions)
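A minimal sketch of the masked-positions objective, in pure Python for clarity. The function name and list-based interface are illustrative; the repo computes this with tensor ops over batches:

```python
import math

def masked_ce(logits, x0, corrupted):
    """Cross-entropy between predicted x0 and the clean tokens,
    averaged over corrupted positions only (illustrative, pure Python)."""
    losses = []
    for logit_row, target, was_masked in zip(logits, x0, corrupted):
        if not was_masked:
            continue  # loss is only taken where the input was [MASK]
        z = max(logit_row)  # stabilize the log-sum-exp
        log_norm = z + math.log(sum(math.exp(l - z) for l in logit_row))
        losses.append(log_norm - logit_row[target])
    return sum(losses) / len(losses)

# Two positions, only the second was corrupted; the model is confident
# and correct there, so the loss is near zero.
loss = masked_ce([[0.0, 0.0, 0.0], [0.0, 10.0, 0.0]],
                 x0=[0, 1], corrupted=[False, True])
assert loss < 0.01
```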
Inference setup:
- Reverse denoising loop: `t = T -> 1`
- Per-step decoding: multinomial sampling (`temperature > 0`) or greedy (`temperature = 0`)
- Iterative refinement: low-confidence generated tokens are re-masked each step
- Final explicit denoise at `t = 0` with greedy selection
- Visualization outputs: MP4 timeline of decoding steps
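The confidence-based remasking step can be sketched as follows. All names here are hypothetical, and `keep_frac` (the fraction of generated tokens kept each step) stands in for whatever schedule the repo actually uses:

```python
def remask_low_confidence(tokens, confidences, prompt_len, keep_frac, mask_id=-1):
    """One refinement step: re-mask the lowest-confidence generated tokens
    so the next reverse step can re-predict them; prompt positions are kept.
    Illustrative names, not the repo's actual API."""
    gen = list(range(prompt_len, len(tokens)))          # generated positions only
    n_keep = int(keep_frac * len(gen))
    keep = set(sorted(gen, key=lambda i: confidences[i], reverse=True)[:n_keep])
    return [tok if i < prompt_len or i in keep else mask_id
            for i, tok in enumerate(tokens)]

tokens = [7, 8, 3, 5, 9]          # first two positions are the prompt
conf   = [1.0, 1.0, 0.9, 0.2, 0.7]
# keep_frac=0.5 keeps the single most confident generated token (index 2).
assert remask_low_confidence(tokens, conf, prompt_len=2, keep_frac=0.5) \
    == [7, 8, 3, -1, -1]
```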
At the end of training, `train.py` also prints a final validation metrics block with:
- Perplexity (derived from masked validation cross-entropy)
- Masked reconstruction accuracy (accuracy on corrupted positions only)
- Entropy per timestep (masked-token predictive entropy across diffusion timesteps)
- Reverse-step token change rate (fraction of generated tokens that change between reverse steps)
- Distinct-2 diversity (unique generated bigrams / total generated bigrams, prompt excluded)
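For reference, the Distinct-2 metric reduces to a one-liner; this helper (hypothetical name, prompt exclusion assumed to be done by the caller) shows the computation:

```python
def distinct_2(tokens):
    """Distinct-2: unique bigrams / total bigrams over generated tokens
    (prompt already excluded by the caller). Illustrative helper."""
    bigrams = list(zip(tokens, tokens[1:]))
    return len(set(bigrams)) / len(bigrams) if bigrams else 0.0

assert distinct_2([1, 2, 1, 2, 1]) == 0.5   # bigrams: (1,2),(2,1),(1,2),(2,1)
assert distinct_2([1, 2, 3, 4]) == 1.0      # all three bigrams unique
```

Values near 1.0 indicate diverse output; values near 0.0 indicate repetitive loops.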
- Root: `train.py` (model + training), `inference.py` (checkpoint loading + generation/trace export), `README.md`, `learning.md`, `requirements.txt`
- `scripts/`: dataset/tokenizer preparation scripts
- `data/`: local corpus files (`stories.txt`, etc.)
- `utils/`: shared utility modules
- `artifacts/`: checkpoints, tokenizer JSON, plots, and media outputs
Requirements:
- Python 3
- `torch`
- `Pillow` (for frame rendering)
- `ffmpeg` (for MP4 export)
- `tokenizers` (for BPE tokenization)
```bash
python3 scripts/data.py --num-stories 10000 --output data/stories.txt
python3 scripts/train_tokenizer.py --input data/stories.txt --output artifacts/tokenizer/tokenizer.json
python3 train.py
```

Checkpoints are saved to `artifacts/models/` during training and at the end.
```bash
python3 inference.py \
  --checkpoint artifacts/models/model_stories_10k_bpe_256.pt \
  --prompt "Once upon a time" \
  --gen-len 256 \
  --temperature 0.0 \
  --viz-video artifacts/media/diffusion_trace.mp4 \
  --trace-every 1 \
  --gif-frame-ms 180
```

Use `scripts/data.py` to stream TinyStories from Hugging Face and build a small local subset for laptop training:
```bash
python3 scripts/data.py \
  --dataset roneneldan/TinyStories \
  --split train \
  --num-stories 10000 \
  --seed 1337 \
  --output data/stories.txt
```

Notes:
- `--num-stories 100` or `--num-stories 200` is a good range for quick local runs.
- Sampling uses streaming plus a shuffle buffer, so it does not download the full dataset at once.
`train.py` and `inference.py` use `data/stories.txt` and `artifacts/tokenizer/tokenizer.json`.
- Add resume training from checkpoint (`--resume model.pt`)
- Train on 100-200 TinyStories
- Train on 2k+ stories
- Run loss-curve ablations with a GPT-2 config, autoregressive vs. diffusion, on 100-200 TinyStories
- Muon optimizer ablations
- Train on SynTH/fineweb-edu
- Try speed-running a 200M+ parameter model
- Add training with block-based masking
- Try uniform diffusion instead of masked diffusion
