This directory contains simple examples demonstrating the poet_torch API.
```python
from poet_torch import POETConfig, POETModel, get_poet_optimizer

# 1. Create config
config = POETConfig(
    block_size=256,      # POET block size
    merge_interval=200,  # Steps between merge-then-reinitialize
)

# 2. Wrap your model
model = POETModel(your_model, config)

# 3. Create optimizer
optimizer = get_poet_optimizer(model, config)

# 4. Training loop
for step, batch in enumerate(dataloader):
    loss = model(**batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.merge_if_needed(step)  # Automatic merge
```

This quick-start example demonstrates:

- Using `POETConfig` for configuration
- Wrapping a model with `POETModel`
- Using `get_poet_optimizer()` for automatic optimizer setup
- Automatic merge with `model.merge_if_needed(step)`
```python
from poet_torch import POETConfig, QPOETConfig

# Standard POET
config = POETConfig(
    block_size=256,            # Block size for transformations
    merge_interval=200,        # Steps between merge operations
    poet_lr=5e-4,              # Learning rate for POET params
    base_lr=1e-3,              # Learning rate for base params
    mem_efficient_mode=False,  # Memory-efficient mode
)

# Quantized POET
qconfig = QPOETConfig(
    block_size=256,
    merge_interval=200,
    weight_bits=8,             # INT8 quantization
    weight_group_size=256,
)
```

```python
from poet_torch import POETModel
# Wrap any PyTorch model
model = POETModel(base_model, config)
# Access model info
model.print_model_info()
info = model.get_model_info()
# Get effective weights
weights = model.get_effective_weights()
```

```python
from poet_torch import get_poet_optimizer, get_poet_params, POETAdamW
# Automatic optimizer setup
optimizer_auto = get_poet_optimizer(model, config)
# Manual optimizer setup
poet_params = get_poet_params(model)
base_params = ...  # collect the remaining (non-POET) parameters here
param_groups = [
    dict(params=base_params, weight_decay=0.0, lr=1e-3, use_poet=False),
    dict(params=poet_params, weight_decay=0.0, lr=5e-4,
         use_poet=True, poet_merge_interval=200, poet_scale=0.5),
]
optimizer_manual = POETAdamW(param_groups, lr=1e-3)
```

```python
# Automatic merge check
merged = model.merge_if_needed(step)
# Force immediate merge
model.merge()
# Manual merge (for custom training loops)
from poet_torch import merge_and_reinitialize
merged = merge_and_reinitialize(model, step, merge_interval=200)
```

The examples also cover more advanced usage:

- Manual layer replacement with `replace_linear_with_poet()` (a rough sketch follows this list)
- Custom parameter groups
- Memory-efficient mode (`mem_efficient_mode=True`)
- Targeted module replacement
- Manual merge with `merge_and_reinitialize()`
- ...
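For the first and third items above, a minimal sketch is shown below: it enables `mem_efficient_mode` and replaces linear layers in place. The call signature of `replace_linear_with_poet()` is an assumption here; check its docstring for the actual arguments (including how to target specific modules).

```python
from poet_torch import POETConfig, replace_linear_with_poet

# Memory-efficient config for large models.
config = POETConfig(
    block_size=256,
    merge_interval=200,
    mem_efficient_mode=True,
)

# Assumed signature: swap nn.Linear layers in `base_model` for POET layers in place.
# Consult the function's docstring for the real arguments (e.g., per-module targeting).
replace_linear_with_poet(base_model, config)
```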
A minimal single-GPU training example using the `POETModel` wrapper:
```sh
cd examples
python 1_toy.py --model_config configs/llama_250m.json
```

`torchrun_main.py` is a full-featured pretraining script that supports distributed training with DDP and multiple optimizers, including POET and QPOET.
Prerequisites:

- Download and prepare the C4 dataset in the `./c4/en/` directory (one way to fetch it is sketched after this list)
- Install dependencies: `pip install torch transformers datasets loguru tqdm`
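A minimal sketch for pulling a few raw C4 shards into `./c4/en/` is shown below. It assumes the script reads the `allenai/c4` `json.gz` shards directly from that directory and that `huggingface_hub` is available (it is installed alongside `transformers`); widen the shard pattern for a real pretraining run.

```python
from huggingface_hub import snapshot_download

# Fetch a small subset of English C4 training shards plus validation shards into ./c4/en/.
# The glob below is illustrative only; full pretraining needs many more shards.
snapshot_download(
    repo_id="allenai/c4",
    repo_type="dataset",
    local_dir="./c4",
    allow_patterns=[
        "en/c4-train.0000*-of-01024.json.gz",
        "en/c4-validation.*.json.gz",
    ],
)
```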
```sh
cd examples
torchrun --standalone --nproc_per_node 8 torchrun_main.py \
    --model_config configs/llama_3b.json \
    --lr 0.001 \
    --batch_size 64 \
    --total_batch_size 512 \
    --max_length 256 \
    --num_training_steps 600000 \
    --warmup_steps 5000 \
    --min_lr_ratio 0.01 \
    --weight_decay 0.0 \
    --grad_clipping 0.1 \
    --dtype bfloat16 \
    --eval_every 20000 \
    --save_every 10000000 \
    --optimizer poet \
    --poet_lr 0.0005 \
    --poet_merge_interval 200 \
    --poet_block_size 512
```

To train with QPOET (the quantized variant) instead:

```sh
cd examples
torchrun --standalone --nproc_per_node 8 torchrun_main.py \
    --model_config configs/llama_3b.json \
    --lr 0.001 \
    --batch_size 64 \
    --total_batch_size 256 \
    --max_length 256 \
    --num_training_steps 100000 \
    --warmup_steps 5000 \
    --min_lr_ratio 0.01 \
    --weight_decay 0.01 \
    --grad_clipping 1.0 \
    --dtype bfloat16 \
    --eval_every 5000 \
    --save_every 10000000 \
    --optimizer q_poet \
    --poet_lr 0.0005 \
    --poet_merge_interval 400 \
    --poet_block_size 256 \
    --gd_warmup_steps 10000
```

Notes:

- POET requires dimensions to be divisible by `block_size` (a simple sanity check is sketched after this list)
- The `lm_head` layer is excluded by default
- Use `mem_efficient_mode=True` for large models to save memory
- QPOET uses INT8 quantization for additional memory savings
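For the first note, a quick pre-flight check along these lines (not part of the library, and assuming a standard `nn.Linear`-based model named `base_model`) can flag layers that POET cannot partition evenly:

```python
import torch.nn as nn

block_size = 256  # must match the block_size passed to POETConfig

# Report nn.Linear layers whose dimensions are not divisible by block_size.
# lm_head is skipped since it is excluded from POET by default.
for name, module in base_model.named_modules():
    if isinstance(module, nn.Linear) and "lm_head" not in name:
        if module.in_features % block_size or module.out_features % block_size:
            print(f"{name}: {module.out_features}x{module.in_features} "
                  f"not divisible by block_size={block_size}")
```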