Production-ready GPU optimization toolkit for the complete ML lifecycle: Pre-Training → Post-Training → Inference → Evaluation
┌─────────────────────────────────────────────────────────────────────────────┐
│ 🎯 Train 70B models on 24GB GPU │ ⚡ 3x faster inference │ 📉 4x memory │
└─────────────────────────────────────────────────────────────────────────────┘
| Optimization | Speedup | Memory | Best For |
|---|---|---|---|
| DDP/FSDP | Linear scaling | 3× reduction | Multi-GPU training |
| torch.compile | 1.5-2× | - | Any PyTorch model |
| Mixed Precision (AMP) | 2× | 50% less | Training |
| LoRA/QLoRA | - | 99.9% fewer params | Fine-tuning 70B+ |
| GaLore | - | 4× savings | Full-param training |
| NEFTune | - | - | +5-10% instruction following |
| ORPO/SimPO | - | 50% (no ref model) | Preference tuning |
| Model Merging | - | - | Combine model capabilities |
| Speculative Decoding | 2-3× | - | LLM inference |
| GPTQ/AWQ (4-bit) | - | 4× reduction | Production serving |
```bash
git clone https://github.com/your-org/gpu_optimize_demo.git
cd gpu_optimize_demo
pip install -e ".[all]"
```

```python
# ═══════════════════════════════════════════════════════════════════════════
# 🔥 DISTRIBUTED TRAINING — Train on multiple GPUs
# ═══════════════════════════════════════════════════════════════════════════
from src.training import DDPWrapper, DDPConfig, setup_distributed

setup_distributed()  # Initialize the distributed environment

# DDP: simple multi-GPU (model fits on a single GPU)
wrapper = DDPWrapper(DDPConfig(mixed_precision=True, precision="bf16"))
model = wrapper.wrap(model)

for batch in dataloader:
    with wrapper.autocast():
        loss = model(batch)
    wrapper.backward(loss, optimizer, model)

# FSDP: large models (7B+ parameters)
from src.training import FSDPWrapper, FSDPConfig

fsdp_model = FSDPWrapper(FSDPConfig(sharding_strategy="full_shard")).wrap(model)

# ═══════════════════════════════════════════════════════════════════════════
# 🎯 FINE-TUNING — Train 70B on a single 24GB GPU
# ═══════════════════════════════════════════════════════════════════════════
from src.post_training import apply_lora, LoRAConfig

config = LoRAConfig(r=16, target_modules=["q_proj", "v_proj"])
model = apply_lora(model, config)  # Only 0.1% of parameters are trainable!

# ═══════════════════════════════════════════════════════════════════════════
# ⚡ INFERENCE — 3× faster generation
# ═══════════════════════════════════════════════════════════════════════════
from src.inference import SpeculativeDecoder, quantize_model, QuantizationConfig

# Speculative decoding: 2-3× speedup
decoder = SpeculativeDecoder(llama_70b, llama_7b)
output = decoder.generate(input_ids)

# 4-bit quantization: 4× memory reduction
model = quantize_model(model, QuantizationConfig(bits=4, method="gptq"))

# ═══════════════════════════════════════════════════════════════════════════
# 🌐 HIGH-THROUGHPUT SERVING — Continuous Batching + FP8
# ═══════════════════════════════════════════════════════════════════════════
from src.serving import create_server
from src.inference import FP8KVCache, ChunkedPrefillScheduler

# Launch an OpenAI-compatible server with optimizations:
# - Continuous batching: interleave decode/prefill
# - FP8 cache: 2× KV capacity
# - TP: 2-GPU tensor parallelism
create_server(
    model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    tp_size=2,
)
```

```mermaid
graph TD
    %% Frontend Layer
    subgraph "Frontend & API"
        UI[Streamlit Chat UI] --> |HTTP| API[FastAPI Server]
        API --> |OpenAI Protocol| Sched[Scheduler]
    end

    %% Serving Layer
    subgraph "Serving Engine (Continuous Batching)"
        Sched --> |Schedule| Batcher[Continuous Batcher]
        Batcher --> |Chunked Prefill| Chunking[Chunked Prefill]
        Batcher --> |Request| TP_Driver[TP Driver]
    end

    %% Inference Orchestration
    subgraph "Inference Optimizations"
        TP_Driver --> |Broadcast| W1[Worker 1] & W2[Worker 2]
        subgraph "Worker Node"
            Spec[Speculative Decoding] --> |Draft| DraftModel
            Spec --> |Verify| TargetModel
            subgraph "Memory Management"
                Radix[Radix Attention / Prefix Cache]
                PagedKV[Paged KV Cache]
                FP8KV[FP8 KV Cache]
            end
            TargetModel --> Radix
            TargetModel --> PagedKV
        end
    end

    %% Kernel Layer
    subgraph "High-Perf Kernels"
        Triton[Triton Kernels]
        Triton --> |FP8 Attention| FlashDec[Flash Decoding]
        Triton --> |FP8 GEMM| GEMM[FP8 MatMul]
    end
    TargetModel --> Triton

    %% Training Side (Context)
    subgraph "Distributed Training"
        FSDP[FSDP / DDP]
        Par3D[3D Parallelism]
        ZeRO[ZeRO++]
    end
```
| Layer | Component | Key Features |
|---|---|---|
| Serving | OpenAI Server | FastAPI, SSE Streaming, Async Architecture |
| Scheduling | Continuous Batcher | Dynamic Batching, Chunked Prefill (Split-wise), Priority Queues |
| Inference | Speculative Decoding | Draft/Verify Logic, Eagle/Medusa support ready |
| Memory | Advanced KV Cache | Radix Attention (Prefix Caching), FP8 Compression, Paged Allocation |
| Distributed | Tensor Parallel | SPMD Worker Architecture, Distributed Broadcast |
| Kernels | Triton Kernels | FP8 Flash Decoding, FlashInfer Integration, Custom Fused Ops |
| Training | Advanced Training | 3D Parallelism (TP+SP+PP), ZeRO++ (Quantized Comm), Flash Attn v3 |
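The memory row above relies on prefix reuse: requests that share a prompt prefix skip recomputing KV entries for it. A minimal, hypothetical sketch of the idea behind radix/prefix caching — a trie keyed by token IDs — not the repo's actual `RadixAttention` API:

```python
# Toy prefix cache: requests that share a prompt prefix reuse the KV entries
# already computed for it. Illustrative only — class names are hypothetical.

class PrefixCacheNode:
    def __init__(self):
        self.children = {}   # token id -> child node
        self.kv = None       # stand-in for the cached KV tensors


class PrefixCache:
    def __init__(self):
        self.root = PrefixCacheNode()

    def insert(self, tokens):
        """Store a token sequence so later requests can reuse its prefix."""
        node = self.root
        for t in tokens:
            node = node.children.setdefault(t, PrefixCacheNode())
            node.kv = f"kv[{t}]"  # placeholder for real KV state

    def longest_prefix(self, tokens):
        """Return how many leading tokens already have cached KV entries."""
        node, hits = self.root, 0
        for t in tokens:
            if t not in node.children:
                break
            node = node.children[t]
            hits += 1
        return hits


cache = PrefixCache()
cache.insert([1, 2, 3, 4])                 # first request: full prefill
print(cache.longest_prefix([1, 2, 3, 9]))  # → 3 (only token 9 needs prefill)
```

A production implementation stores real KV blocks at the nodes and evicts least-recently-used branches, but the lookup logic is the same.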
📝 Text version (for accessibility)
```text
PRE-TRAINING: torch.compile → AMP (FP16/BF16) → DDP/FSDP → Memory Pool → Gradient Checkpointing
    ↓
POST-TRAINING:
    ├─ PEFT:      LoRA/QLoRA, AdaLoRA, VeRA, GaLore
    ├─ Alignment: NEFTune, ORPO/SimPO, DPO/RLHF
    └─ Advanced:  Model Merging (TIES/DARE), Continual Learning (EWC)
    ↓
INFERENCE: Speculative Decoding → KV-Cache → Quantization (GPTQ/AWQ) → Continuous Batching
    ↓
SERVING & EVALUATION: SGLang Server → Benchmarks (MMLU) → Safety → Metrics → Latency Monitoring
```
Choose the right strategy for your model:
| Model Size | Strategy | Memory Savings | Use Case |
|---|---|---|---|
| < 7B | DDP | 1× (no sharding) | Simple multi-GPU |
| 7B - 70B | FSDP | ~3× | Single-node multi-GPU |
| > 70B | DeepSpeed | ~8×+ | Multi-node + CPU offload |
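Cutoffs like these fall out of simple accounting. A back-of-envelope sketch, assuming ~16 bytes per parameter for mixed-precision Adam (fp16 weights + gradients, fp32 master copy and two moments) and ignoring activations — a hypothetical helper, not the repo's `auto_select_strategy`:

```python
# Rough strategy picker. 16 bytes/param is a common mixed-precision Adam
# estimate; real requirements also depend on activations and batch size.

def pick_strategy(params_billion: float, gpu_memory_gb: float = 80.0,
                  num_gpus: int = 8) -> str:
    total_gb = params_billion * 16  # 1e9 params × 16 bytes ≈ 16 GB per billion
    if total_gb <= gpu_memory_gb:
        return "ddp"        # a full replica fits on one GPU
    if total_gb / num_gpus <= gpu_memory_gb:
        return "fsdp"       # fits once sharded across the node's GPUs
    return "deepspeed"      # needs multi-node sharding / CPU offload

print(pick_strategy(3))     # → ddp       (48 GB replica fits on one 80 GB GPU)
print(pick_strategy(13))    # → fsdp      (208 GB shards to 26 GB per GPU)
print(pick_strategy(175))   # → deepspeed (2.8 TB exceeds a single node)
```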
📖 DDP — Distributed Data Parallel
```python
import torch

from src.training import DDPWrapper, DDPConfig, setup_distributed, cleanup_distributed

# Initialize
setup_distributed()

# Configure DDP
config = DDPConfig(
    mixed_precision=True,            # Enable AMP
    precision="bf16",                # BF16 is more stable than FP16
    gradient_clipping=1.0,           # Prevent exploding gradients
    gradient_accumulation_steps=4,   # Effective batch = 4 × batch_size
    static_graph=True,               # Faster for fixed architectures
)

# Wrap the model
wrapper = DDPWrapper(config)
model = wrapper.wrap(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        with wrapper.autocast():
            loss = model(batch)
        wrapper.backward(loss, optimizer, model)  # Handles scaling & clipping

# Save a checkpoint (rank 0 only)
wrapper.save_checkpoint(model, optimizer, epoch, "checkpoint.pt")
cleanup_distributed()
```

Run:

```bash
torchrun --nproc_per_node=4 train.py
```

📖 FSDP — Fully Sharded Data Parallel
```python
from src.training import (
    DistributedTrainer,
    DistributedTrainerConfig,
    FSDPConfig,
    FSDPWrapper,
)

# FSDP shards parameters, gradients, and optimizer states across GPUs
config = FSDPConfig(
    sharding_strategy="full_shard",  # Maximum memory savings (~3×)
    mixed_precision=True,
    precision="bf16",
    activation_checkpointing=True,   # Further 40-60% memory savings
    auto_wrap_policy="transformer",  # Auto-wrap transformer blocks
    cpu_offload=False,               # Enable under extreme memory pressure
)

# Wrap the model
wrapper = FSDPWrapper(config)
model = wrapper.wrap(model)

# Or use the unified trainer
trainer = DistributedTrainer(model, DistributedTrainerConfig(strategy="fsdp"))
```

Run:

```bash
torchrun --nproc_per_node=8 train.py  # 8 GPUs
```

📖 Auto Strategy Selection
```python
import torch

from src.training import auto_select_strategy, estimate_memory_usage

# Estimate memory requirements
estimates = estimate_memory_usage(model, batch_size=8)
print(f"Total memory needed: {estimates['total_gb']:.1f} GB")

# Auto-select the best strategy
strategy = auto_select_strategy(
    model,
    available_gpus=torch.cuda.device_count(),
    gpu_memory_gb=24.0,
)
print(f"Recommended strategy: {strategy}")  # "ddp", "fsdp", or "deepspeed"
```

🔧 torch.compile & CUDA Graphs
```python
from src.compile import compile_model, CUDAGraphWrapper

# torch.compile — 30-200% speedup with one line
model = compile_model(model, mode="max-autotune")

# CUDA Graphs — 10-30% latency reduction for inference
wrapper = CUDAGraphWrapper(model, example_input)
output = wrapper(example_input)  # Graph replay
```

📉 Mixed Precision (AMP)
```python
from src.training import AMPTrainer, AMPConfig

config = AMPConfig(dtype="bfloat16", enabled=True)
trainer = AMPTrainer(config)

for batch in dataloader:
    with trainer.autocast():
        loss = model(batch)
    trainer.backward(loss)
    trainer.step(optimizer)
```

⚡ Flash Attention (v2/v3)
```python
from src.training import FlashAttention, FlashAttentionConfig

# Drop-in replacement with 2-3× speedup and O(N) memory
config = FlashAttentionConfig(
    use_flash_attn=True,
    version="v3",  # or "auto"
)
attn_layer = FlashAttention(config, hidden_size=4096)
```

📦 Gradient Compression
```python
from src.training import create_gradient_compressor

# Reduce communication bandwidth by 10-100×
compressor = create_gradient_compressor(
    method="topk",  # "topk", "random", "quantize", "powersgd"
    ratio=0.01,     # Keep the top 1% of gradients
)

# In the training loop:
grad = compressor.compress_and_allreduce(param.grad)
```

🚀 Optimized DataLoader
```python
from src.training import create_prefetched_loader, DataLoaderConfig

# Prefetch data to the GPU asynchronously (overlaps I/O with compute)
config = DataLoaderConfig(
    batch_size=64,
    prefetch_factor=4,
    persistent_workers=True,
)
loader = create_prefetched_loader(dataset, config, device="cuda")

for batch in loader:
    # batch is already on the GPU
    output = model(batch)
```

🔥 Fused Operations
```python
from src.training import FusedAdamW, FusedLayerNorm, FusedSwiGLU

# Faster kernels with lower memory-bandwidth usage
optimizer = FusedAdamW(model.parameters(), lr=1e-4)
norm = FusedLayerNorm(4096)
mlp = FusedSwiGLU(4096, 11008)
```

💾 Memory Optimization
```python
import torch

from src.memory import MemoryPool, CPUOffloader, MemoryProfiler

# Memory pool — reuse tensor allocations
pool = MemoryPool()
tensor = pool.allocate((1024, 768), dtype=torch.float16)

# CPU offload — 2× model capacity
offloader = CPUOffloader(model, optimizer)

# Profiling
profiler = MemoryProfiler()
profiler.snapshot("before")
output = model(x)
profiler.snapshot("after")
print(profiler.delta("before", "after"))
```

🎯 LoRA / QLoRA / Advanced Variants
```python
from src.post_training import LoRAModel, LoRAConfig, merge_lora

# LoRA reduces trainable params by ~10,000×
config = LoRAConfig(
    r=16,      # Rank
    alpha=32,  # Scaling factor
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    dropout=0.05,
)
model = LoRAModel(base_model, config)
model.print_trainable_params()  # "Trainable: 4,194,304 (0.01%)"

# Merge for deployment (zero-overhead inference)
merged_model = merge_lora(model)
```

Advanced variants:
```python
from src.post_training import (
    VeRALayer,                   # 10× fewer params than LoRA
    AdaLoRATrainer,              # Adaptive rank allocation
    LoRAXSLayer,                 # Ultra-low rank (r=1-2)
    create_lora_plus_optimizer,  # Different lr for the A/B matrices
)

# VeRA: shared random projections (most parameter-efficient)
vera_layer = VeRALayer(base_linear, r=256)

# LoRA+: 16× higher lr for the B matrix → faster convergence
optimizer = create_lora_plus_optimizer(model, lr=1e-4, lr_ratio=16.0)
```

🤝 ORPO / SimPO — No Reference Model Needed
```python
from src.post_training import ORPOTrainer, SimPOTrainer, ORPOConfig, SimPOConfig

# ORPO: combines SFT + preference in a single objective
orpo = ORPOTrainer(model, ORPOConfig(lambda_orpo=0.1))
for batch in data:
    stats = orpo.step(batch["chosen"], batch["rejected"])

# SimPO: length-normalized, no reference model
simpo = SimPOTrainer(model, SimPOConfig(beta=2.0, gamma=0.5))
```

When to use:
| Method | Reference Model | Best For |
|---|---|---|
| DPO | ✅ Required | Standard preference tuning |
| ORPO | ❌ Not needed | Combined SFT + preference |
| SimPO | ❌ Not needed | High-quality preference data |
| IPO | ✅ Required | Noisy preference data |
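For intuition on why SimPO needs no reference model: its implicit reward is simply the policy's length-normalized log-probability, scaled by β, with a target margin γ. A toy scalar sketch of the objective (illustrative only, not the `SimPOTrainer` internals):

```python
# SimPO loss on scalar sequence log-probabilities. The only model involved
# is the policy itself — no frozen reference model appears in the formula.
import math

def simpo_loss(logp_chosen: float, len_chosen: int,
               logp_rejected: float, len_rejected: int,
               beta: float = 2.0, gamma: float = 0.5) -> float:
    reward_chosen = beta * logp_chosen / len_chosen      # length-normalized
    reward_rejected = beta * logp_rejected / len_rejected
    margin = reward_chosen - reward_rejected - gamma     # target reward margin
    return -math.log(1.0 / (1.0 + math.exp(-margin)))    # -log sigmoid(margin)

# The loss shrinks as the chosen response out-scores the rejected one:
print(simpo_loss(-10.0, 10, -30.0, 10) < simpo_loss(-10.0, 10, -12.0, 10))  # → True
```

Length normalization (dividing by the response length) is what keeps SimPO from simply preferring longer responses.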
📉 GaLore — Full-Parameter Training with LoRA Memory
```python
from src.post_training import create_galore_optimizer, estimate_galore_memory_savings

# Estimate the savings
savings = estimate_galore_memory_savings(model, rank=128)
print(f"Memory savings: {savings['memory_savings_ratio']:.1f}×")

# Create an optimizer with gradient projection
optimizer = create_galore_optimizer(
    model,
    lr=1e-4,
    rank=128,             # Projection rank
    update_proj_gap=200,  # Steps between projection updates
    target_modules=["q_proj", "k_proj", "v_proj"],
)
```

GaLore vs LoRA:
- LoRA freezes base model, adds low-rank adapters
- GaLore trains full model but projects gradients to low-rank
- Result: Full model updates with LoRA-level memory
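The savings come from where Adam's moments live: under GaLore they are stored for the rank-r projected gradient (r × n) rather than the full gradient (m × n). A back-of-envelope sketch with hypothetical helper names (not the repo's `estimate_galore_memory_savings`), assuming fp32 optimizer state:

```python
# Optimizer-state accounting for one m × n weight matrix.

def adam_state_floats(m: int, n: int) -> int:
    return 2 * m * n            # first + second moment for the full gradient

def galore_state_floats(m: int, n: int, r: int) -> int:
    return 2 * r * n + m * r    # projected moments + the projection matrix

m = n = 4096                    # e.g. one attention projection (q_proj)
full = adam_state_floats(m, n)
low_rank = galore_state_floats(m, n, r=128)
print(f"optimizer-state reduction: {full / low_rank:.1f}x")  # → 21.3x
```

The weights and gradients themselves stay full-rank, which is why GaLore still performs full-parameter updates.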
🔊 NEFTune — Noisy Embeddings
```python
from src.post_training import apply_neftune

# Simple idea: add noise to embeddings during training.
# Improves instruction following by 5-10%.
trainer = apply_neftune(model, noise_alpha=5.0)

# Training loop (noise is added automatically)
for batch in dataloader:
    loss = model(batch)
    loss.backward()
    optimizer.step()

# Disable for evaluation
trainer.disable()
```

🔀 Model Merging (TIES, DARE, SLERP)
```python
from src.post_training import (
    MergeConfig,
    ModelMerger,
    dare_merge,
    slerp_merge,
    ties_merge,
)

# SLERP: smooth interpolation between two models
merged = slerp_merge(model_a, model_b, t=0.5)

# TIES: prune conflicts, keep the important deltas
merged = ties_merge(base_model, [model_a, model_b], threshold=0.2)

# DARE: randomly drop 90% of the deltas → surprisingly effective
merged = dare_merge(base_model, [model_a, model_b], drop_rate=0.9)

# Full control
merger = ModelMerger(base_model)
merged = merger.merge_models(
    [model_a, model_b, model_c],
    config=MergeConfig(method="ties", weights=[0.5, 0.3, 0.2]),
)
```

📚 Continual Learning (Anti-Forgetting)
```python
from src.post_training import EWCRegularizer, ReplayBuffer, ContinualLearner

# EWC: Elastic Weight Consolidation
ewc = EWCRegularizer(model, lambda_=1000.0)

# Train task 1, then consolidate what was learned
ewc.compute_fisher(task1_dataloader)
ewc.consolidate()

# Train task 2 with the EWC penalty
for batch in task2_dataloader:
    loss = criterion(model(batch)) + ewc.penalty()
    loss.backward()

# Replay buffer: mix old samples with new ones
buffer = ReplayBuffer(max_size=10000)
buffer.add(task1_samples)
replay = buffer.sample(batch_size=32)  # Mix with task 2 batches
```

⚡ Speculative Decoding
```python
from src.inference import SpeculativeDecoder, SpeculativeConfig

# Use a small draft model to propose tokens and the large model to verify
decoder = SpeculativeDecoder(
    target_model=llama_70b,
    draft_model=llama_7b,
    config=SpeculativeConfig(num_speculative_tokens=5),
)
output = decoder.generate(input_ids, max_new_tokens=256)
decoder.print_stats()
# Accept rate: 78%
# Speedup: 2.3×
```

📉 Quantization (4-bit/8-bit)
```python
from src.inference import quantize_model, QuantizationConfig

# GPTQ or AWQ 4-bit quantization
config = QuantizationConfig(
    bits=4,
    method="gptq",  # or "awq"
    group_size=128,
)
quantized = quantize_model(model, config, calibration_data)
# Memory: 70B model → ~35GB → ~9GB
```

📦 Continuous Batching & KV-Cache
```python
from src.inference import ContinuousBatcher, PagedKVCache

# Continuous batching for high throughput
batcher = ContinuousBatcher(model, tokenizer, max_batch_size=32)
batcher.start()

# Paged KV-cache (vLLM-style)
cache = PagedKVCache(config)
```

🤖 SGLang Server
```python
from src.sglang_inference import SGLangServer, ServerConfig

server = SGLangServer(ServerConfig(
    model_path="meta-llama/Llama-2-70b-chat-hf",
    tp_size=4,  # Tensor parallel across 4 GPUs
    port=30000,
)).start()
```

📊 Benchmarks & Safety
```python
from src.evaluation import run_benchmark, SafetyEvaluator

# MMLU benchmark
result = run_benchmark(model, tokenizer, "mmlu")
print(f"MMLU Accuracy: {result.accuracy:.2%}")

# Safety evaluation
evaluator = SafetyEvaluator()
report = evaluator.evaluate_model(model, tokenizer, test_prompts)
```

```text
gpu_optimize_demo/
├── src/
│   ├── compile/           # torch.compile, CUDA Graphs, TensorRT
│   ├── training/          # AMP, DDP, FSDP, DeepSpeed, Gradients
│   ├── memory/            # Memory Pool, Offloading, Profiling
│   ├── post_training/     # LoRA, PEFT, RLHF, DPO
│   ├── inference/         # Speculative, Batching, PagedKV, TP-Worker
│   ├── serving/           # OpenAI Server, Request Scheduler
│   ├── evaluation/        # Benchmarks (MMLU), Safety, Metrics
│   ├── triton_kernels/    # Custom Triton Kernels (FP8, FlashDec)
│   ├── sglang_inference/  # SGLang Integration
│   ├── ray_distributed/   # Ray Train, Tune, Serve
│   ├── profiling/         # Torch Profiler, CUDA Timer
│   ├── io_optimize/       # DataLoader, Prefetcher
│   └── nccl/              # Communication Profiler
│
├── examples/              # Ready-to-run examples
│   ├── advanced_inference.py  # Speculative, Chunking, FP8, TP
│   ├── eval_serving_mmlu.py   # End-to-end Serving + MMLU Eval
│   ├── chat_ui.py             # Streamlit Chat Frontend
│   ├── ddp_fsdp_training.py   # Distributed training
│   ├── lora_finetuning.py     # LoRA/QLoRA
│   └── ...
│
├── configs/               # Configuration files
└── benchmarks/            # Benchmark suites
```
```bash
# Distributed training
torchrun --nproc_per_node=4 examples/ddp_fsdp_training.py --strategy fsdp

# Training optimization
python examples/training_optimization.py

# Fine-tuning
python examples/lora_finetuning.py

# Inference
python examples/speculative_decoding.py
python examples/quantization.py

# Evaluation
python examples/run_benchmarks.py
```

| Package | Version | Purpose |
|---|---|---|
| PyTorch | ≥2.0.0 | Core framework |
| Triton | ≥2.1.0 | Custom GPU kernels |
| CUDA | ≥12.0 | GPU acceleration |
| Transformers | ≥4.35.0 | Model support |
📦 Full requirements
```text
torch>=2.0.0
triton>=2.1.0
ray[all]>=2.9.0
sglang[all]>=0.2.0
numpy>=1.24.0
transformers>=4.35.0
accelerate>=0.25.0
bitsandbytes>=0.41.0
tensorboard>=2.14.0
```
Contributions are welcome! Please read our Contributing Guide.
MIT License - see LICENSE for details.
📖 Docs · 💻 Examples · 📊 Benchmarks · 🐛 Issues
Built for ML engineers who need production-ready GPU optimizations 🚀
