The best of all worlds — Synchronous stability + Asynchronous efficiency + Native simplicity
Installation • Quick Start • Features • Architecture • Documentation
Flux is a flexible and efficient reinforcement learning framework for LLM post-training (RLHF). Unlike existing frameworks that force a binary choice between synchronous stability and asynchronous throughput, Flux adaptively adjusts the sync/async ratio in real time based on measured staleness.
Key Insight: Sync vs Async is NOT a binary choice. Flux operates anywhere on this spectrum, adapting in real time to maximize both stability and throughput.
```
Sync ◄──────────────────────────────────────────────────► Async

VERL   ████████████░░░░░░░░░░░░░░░░░░   Stable but slow
AReaL  ░░░░░░░░░░░░░░░░░░████████████   Fast but risky
Flux   ◄═══════ adapts here ═══════►    Best of both
```
| Aspect | VERL | AReaL | Slime | Flux |
|---|---|---|---|---|
| Sync Strategy | Fixed sync | Fixed async | Both modes | Adaptive |
| Orchestration | Ray | Custom | HTTP | asyncio |
| Training Backend | Megatron/FSDP | Custom | Megatron | Pluggable (Transformers, Megatron) |
| Inference Backend | vLLM/SGLang | Custom | SGLang | SGLang |
| Weight Sync | Ray Object Store | Custom | CUDA IPC | CUDA IPC + HTTP |
| Staleness Handling | N/A | Staleness-aware PPO | APRIL | Unified |
| Code Complexity | ~15k LOC | ~25k LOC | ~8k LOC | <5k LOC† |
† Core runtime only; excludes third-party integrations (Megatron, SGLang) and test code.
Framework Notes:
- VERL (ByteDance/HKU): HybridFlow controller with Ray-based orchestration, stable but has GPU bubble time
- AReaL (Ant Group/Tsinghua): Fully async with importance weight correction, high throughput but potential instability
- Slime (THUDM/Zhipu): SGLang-native with APRIL strategy, simple but less flexible dataflow
Flux uses a PID controller to dynamically adjust the sync/async ratio based on measured staleness.
The async_ratio (float in [0.1, 0.9]) controls the maximum proportion of off-policy data allowed in training batches:
| async_ratio | Behavior |
|---|---|
| 0.1 (more sync) | Sync barrier triggered frequently; training waits for fresh rollouts |
| 0.5 (balanced) | Mixed fresh/stale data; moderate overlap between rollout and training |
| 0.9 (more async) | Training proceeds with older data; rollouts run independently |
The controller output affects three subsystems:
- Sync Policy: When `staleness > target + tolerance`, trigger a sync barrier (wait for all in-flight rollouts)
- Buffer Policy: Capacity limit = `(max_version_gap + current_version + 1) × batch_size - in_flight`
- Batch Policy: Stratified sampling balances fresh vs stale trajectories based on `async_ratio`
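As a sketch, the PID-driven adjustment of `async_ratio` might look like the following (class name and state handling are illustrative, not the Flux implementation; the gains and the [0.1, 0.9] clip range are the documented defaults):

```python
class AsyncRatioPID:
    """Minimal PID controller over EMA-smoothed staleness (illustrative)."""

    def __init__(self, kp=0.1, ki=0.01, kd=0.05, target=0.15):
        self.kp, self.ki, self.kd = kp, ki, kd
        self.target = target
        self.integral = 0.0
        self.prev_error = 0.0

    def update(self, async_ratio: float, staleness_ema: float) -> float:
        # error > 0: fresher than target, allow more async;
        # error < 0: too stale, push toward sync.
        error = self.target - staleness_ema
        self.integral += error
        derivative = error - self.prev_error
        self.prev_error = error
        output = self.kp * error + self.ki * self.integral + self.kd * derivative
        # Clip to the documented [0.1, 0.9] operating range
        return min(0.9, max(0.1, async_ratio + output))
```

Each controller step nudges `async_ratio` toward the staleness target rather than jumping to a fixed mode, which is what lets Flux sit anywhere on the sync-async spectrum.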
```
┌──────────────────────────────────────────────────────────────────┐
│                        Flux Control Loop                         │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│ 1. ROLLOUT GENERATION                                            │
│    SGLang generates responses → Trajectory Buffer                │
│    (tracks: behavior_logprobs, policy_version)                   │
└────────────────────────────────┬─────────────────────────────────┘
                                 │
                                 ▼
┌──────────────────────────────────────────────────────────────────┐
│ 2. STALENESS MEASUREMENT                                         │
│    For each batch, compute:                                      │
│    • KL divergence (π_current vs π_behavior)                     │
│    • Importance weight variance                                  │
│    • Version gap (current_version - trajectory_version)          │
│    → combined_staleness = weighted sum (see formula below)       │
└────────────────────────────────┬─────────────────────────────────┘
                                 │
                                 ▼
┌──────────────────────────────────────────────────────────────────┐
│ 3. PID CONTROLLER UPDATE                                         │
│    error  = target_staleness - EMA(staleness)                    │
│    output = kp*error + ki*∫error + kd*d(error)/dt                │
│    async_ratio = clip(async_ratio + output, [0.1, 0.9])          │
└────────────────────────────────┬─────────────────────────────────┘
                                 │
                                 ▼
┌──────────────────────────────────────────────────────────────────┐
│ 4. SYNC DECISION                                                 │
│    should_sync = (staleness > threshold) OR                      │
│                  (steps_since_sync > max_steps) OR               │
│                  (capacity_remaining <= 0)                       │
│                                                                  │
│    If should_sync:                                               │
│      → Wait for rollouts, sync weights to SGLang                 │
│    Else:                                                         │
│      → Continue async, submit more rollouts                      │
└────────────────────────────────┬─────────────────────────────────┘
                                 │
                                 ▼
┌──────────────────────────────────────────────────────────────────┐
│ 5. BATCH COMPOSITION                                             │
│    Smart composer creates batch:                                 │
│    • Length bucketing (minimize padding)                         │
│    • Staleness balancing (stratified by version_gap)             │
│    • Curriculum ordering (easy→hard as training proceeds)        │
└────────────────────────────────┬─────────────────────────────────┘
                                 │
                                 ▼
┌──────────────────────────────────────────────────────────────────┐
│ 6. TRAINING STEP                                                 │
│    • Apply importance correction weights                         │
│    • Compute policy loss (PPO/GRPO/etc.)                         │
│    • Gradient step, increment policy_version                     │
│    • Mark weights dirty → lazy sync to SGLang                    │
└────────────────────────────────┬─────────────────────────────────┘
                                 │
                                 └──────────► Loop back to step 1
```
Flux monitors three staleness signals to quantify how much the policy has drifted since data was generated:
| Signal | Formula | Computed Over | Unit | Danger Threshold |
|---|---|---|---|---|
| KL Divergence | `D_KL(π_behavior ‖ π_current) = E_behavior[log π_behavior - log π_current]` | Token-level, averaged over response tokens | nats | > 0.1 |
| IW Variance | `Var(w)` where `w = exp(mean_token(log π_current - log π_behavior))` | Per-trajectory, then variance across batch | unitless | > 2.0 |
| Version Gap | `current_version - trajectory_version` | Per-trajectory, averaged over batch | integer (steps) | > 5 |
```python
# Normalize each component to [0, 1]
kl_contrib = min(1, kl_divergence / kl_normalizer)       # default kl_normalizer = 0.1
iw_contrib = min(1, iw_variance / iw_normalizer)         # default iw_normalizer = 2.0
version_contrib = min(1, version_gap / max_version_gap)  # default max_version_gap = 5

# Weighted sum (heuristic defaults, can be tuned)
combined_staleness = 0.4 * kl_contrib + 0.3 * iw_contrib + 0.3 * version_contrib
```

The combined staleness is smoothed via EMA (α=0.1) before feeding into the PID controller.
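A runnable version of the staleness formula, including the EMA smoothing step, might look like this (the function names are ours, not the Flux API; defaults match the documented normalizers):

```python
def combined_staleness(kl, iw_var, version_gap,
                       kl_norm=0.1, iw_norm=2.0, max_gap=5):
    """Weighted sum of normalized staleness signals, each capped at 1."""
    kl_c = min(1.0, kl / kl_norm)
    iw_c = min(1.0, iw_var / iw_norm)
    ver_c = min(1.0, version_gap / max_gap)
    return 0.4 * kl_c + 0.3 * iw_c + 0.3 * ver_c

def ema_update(prev_ema, value, alpha=0.1):
    """Exponential moving average: smooths staleness before the PID sees it."""
    return (1 - alpha) * prev_ema + alpha * value
```

Note that each component saturates at 1, so a single exploding signal (e.g. a KL spike) cannot push the combined metric past its weight.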
Flux provides algorithm-agnostic off-policy correction that works with any RL algorithm. This corrects for the distribution shift between the behavior policy (which generated the data) and the current policy (which is being trained).
```python
# Per-trajectory importance weight
log_ratio = current_logprobs - behavior_logprobs              # [batch, seq_len]
mean_log_ratio = (log_ratio * mask).sum(dim=-1) / seq_len     # Average over tokens
base_weight = exp(clip(mean_log_ratio, -20, 20))              # [batch]

# Staleness decay (older data gets lower weight)
staleness_weight = staleness_decay ** version_gap             # default decay = 0.99

# Final importance weight
importance_weight = base_weight * staleness_weight
importance_weight = clip(importance_weight, min=0.2, max=5.0)

# Normalize to preserve gradient scale
importance_weight = importance_weight * (batch_size / sum(importance_weight))
```

| Algorithm | Native Clip | Importance Correction | Interaction |
|---|---|---|---|
| PPO | `clip(ratio, 1-ε, 1+ε)` | Per-sample weight | PPO clip applies to the ratio, IW multiplies the clipped loss. No conflict: IW adjusts magnitude, PPO clip bounds direction. |
| GRPO | Group normalization | Per-sample weight | IW applied before group normalization; stale samples contribute less to the group baseline. |
| DPO | N/A | Per-pair weight | IW downweights stale preference pairs; fresh pairs dominate the loss. |
| REINFORCE | N/A | Per-sample weight | Direct multiplication with the policy gradient. |
For on-policy algorithms (PPO, GRPO), importance correction is optional but beneficial when using async training. For truly on-policy training (async_ratio → 0), IW approaches 1.0 for all samples.
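As an illustration, the correction above can be written as a small NumPy function (the name `importance_weights` and the NumPy formulation are ours; the real path operates on GPU tensors):

```python
import numpy as np

def importance_weights(current_logprobs, behavior_logprobs, mask,
                       version_gap, staleness_decay=0.99,
                       clip_min=0.2, clip_max=5.0):
    """Per-trajectory importance weights with staleness decay (sketch)."""
    log_ratio = (current_logprobs - behavior_logprobs) * mask   # [B, T]
    seq_len = mask.sum(axis=-1)                                 # [B]
    mean_log_ratio = log_ratio.sum(axis=-1) / seq_len           # average over tokens
    base = np.exp(np.clip(mean_log_ratio, -20, 20))             # [B]
    w = base * staleness_decay ** version_gap                   # decay stale data
    w = np.clip(w, clip_min, clip_max)                          # bound extremes
    return w * (len(w) / w.sum())                               # preserve gradient scale
```

For fully on-policy data (identical logprobs, zero version gap), every weight collapses to 1.0, matching the statement above.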
Active Partial Rollout (APRIL) for efficient generation with long-tail handling:
| Step | Description | Parameters | Behavior |
|---|---|---|---|
| Oversample | Generate more prompts than needed | `oversample_ratio=1.5` | Submit 1.5× prompts to buffer against long-tail |
| Abort | Cancel long-running generations | `batch_timeout=30s` | After timeout, abort remaining and yield completed |
| Reuse | Save partial trajectories | `partial_reuse_threshold=0.5` | If >50% tokens generated, save prefix for continuation |
Abort Rule: A generation is aborted when `elapsed_time > batch_timeout` AND `completed_count >= target_count`.
Partial Reuse Consistency:
- Partial trajectories store: `(prompt, generated_prefix, behavior_logprobs, version)`
- On continuation: new tokens use the current policy, prefix tokens keep their original logprobs
- Version annotation: `trajectory.version_segments = [(0, prefix_len, old_version), (prefix_len, total_len, new_version)]`
- Consistency weighting: each segment weighted by `staleness_decay ** segment_version_gap`
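The consistency-weighting rule can be sketched as a small helper (the function name is illustrative; segment tuples follow the `(start, end, version)` layout of `trajectory.version_segments`):

```python
def segment_weights(version_segments, current_version, staleness_decay=0.99):
    """Per-segment staleness weights: older segments get exponentially less weight."""
    return [
        (start, end, staleness_decay ** (current_version - version))
        for start, end, version in version_segments
    ]
```

This keeps a continued trajectory internally consistent: the reused prefix is discounted by its age while freshly generated tokens carry full weight.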
Buffer Management:
- Partial buffer max size: `2 × batch_size` (FIFO eviction)
- Oversample pool max size: `oversample_ratio × prompt_queue_size`
```python
class SmartBatchComposer:
    def compose_batches(
        self,
        trajectories: List[Trajectory],
        current_version: int,
    ) -> Iterator[TrajectoryBatch]:
        """
        Yields batches optimized for:
        1. Minimal padding (length bucketing)
        2. Balanced staleness (stratified sampling)
        3. Curriculum progression (difficulty ordering)
        """
```

| Bucket | Length Range | Purpose |
|---|---|---|
| short | 0-512 tokens | Fast iteration, minimal padding |
| medium | 512-1024 tokens | Balanced |
| long | 1024-2048 tokens | Complex reasoning |
| very_long | 2048+ tokens | Full context tasks |
Trajectories are grouped by bucket; each batch is drawn from a single bucket to minimize padding waste.
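Assuming half-open boundaries at 512/1024/2048 tokens (the table does not pin down which bucket owns the exact boundary value), bucket assignment might be:

```python
def length_bucket(total_length: int) -> str:
    """Map a trajectory length to its bucket (boundaries assumed half-open)."""
    if total_length < 512:
        return "short"
    if total_length < 1024:
        return "medium"
    if total_length < 2048:
        return "long"
    return "very_long"
```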
Stratified sampling by version_gap buckets:
```python
strata = [
    {"version_gap": 0,    "trajectories": [...]},  # Fresh (on-policy)
    {"version_gap": 1,    "trajectories": [...]},  # 1 step stale
    {"version_gap": 2,    "trajectories": [...]},  # 2 steps stale
    {"version_gap": "3+", "trajectories": [...]},  # 3+ steps stale
]
# Sample proportionally from each stratum to balance the batch
```

Difficulty signal sources (configurable):
- Length: `difficulty = min(1, total_length / 2048)`
- Reward: `difficulty = 1 - (reward + 1) / 2` (lower reward = harder)
- Custom: `composer.set_difficulty_function(my_fn)`
Ordering with decaying randomness:
- Early training: `randomness = 1.0` (fully shuffled)
- Late training: `randomness → 0` (strict easy→hard ordering)
- Decay: `randomness = 1 / (1 + decay_rate × curriculum_step)`
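One way to realize this ordering is to add noise to each difficulty score and shrink the noise scale by the decay formula above (the function name and noise model are our assumptions, not the composer's actual implementation):

```python
import random

def curriculum_order(items, difficulties, curriculum_step,
                     decay_rate=0.01, rng=random):
    """Sort items easy->hard, perturbed by randomness that decays over training."""
    randomness = 1.0 / (1.0 + decay_rate * curriculum_step)
    spread = (max(difficulties) - min(difficulties)) or 1.0
    # Noise up to `randomness * spread`: early steps are near-shuffled,
    # late steps approach a strict easy->hard ordering.
    noisy = [d + rng.uniform(0, randomness * spread) for d in difficulties]
    return [item for _, item in sorted(zip(noisy, items), key=lambda p: p[0])]
```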
| Deployment | Method | Mechanism | Latency |
|---|---|---|---|
| Same node (colocated) | CUDA IPC | Zero-copy GPU-to-GPU via IPC handles | ~10ms |
| Same node (separate process) | CUDA IPC | Shared memory handles | ~10ms |
| Cross-node | NCCL broadcast | Distributed broadcast from training ranks | ~100ms |
| Cross-node (fallback) | HTTP + torch.save | Serialized state_dict transfer | ~1s+ |
Current Implementation Status:
- ✅ Same-node CUDA IPC (implemented)
- ✅ HTTP-based state_dict sync (implemented)
- ✅ Delta compression (implemented)
- 🚧 Multi-node NCCL broadcast (planned)
- 🚧 Checkpoint-based sync (planned)
When enabled (`method="delta"`):
- Baseline snapshot taken every `snapshot_interval` steps
- Delta = `current_weights - baseline_weights`
- Sparse encoding if `sparsity_ratio < sparsity_target`
- Optional quantization (8/16-bit)
Typical compression: 60-80% bandwidth reduction for fine-tuning workloads.
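Conceptually, the delta path might look like this sketch (`encode_delta` and the `sparsity_target` handling are illustrative, not the Flux API; quantization is omitted):

```python
import numpy as np

def encode_delta(current, baseline, sparsity_target=0.5, atol=1e-8):
    """Delta against a baseline snapshot; sparse encoding when few entries changed."""
    delta = current - baseline
    nonzero = np.abs(delta) > atol
    sparsity_ratio = nonzero.mean()  # fraction of changed entries
    if sparsity_ratio < sparsity_target:
        idx = np.flatnonzero(nonzero)
        return {"format": "sparse", "indices": idx, "values": delta.flat[idx]}
    return {"format": "dense", "values": delta}
```

Fine-tuning updates touch relatively few weights between snapshots, which is why the sparse branch yields the bandwidth reductions quoted above.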
| Algorithm | Type | Off-policy Support | Required Interface |
|---|---|---|---|
| PPO | On-policy | ✅ Optional (IW correction) | log_prob, value, advantage |
| GRPO | On-policy | ✅ Optional (IW correction) | log_prob, group_rewards |
| DPO | Preference | ✅ Recommended | log_prob, ref_log_prob, preference_pairs |
| REINFORCE | On-policy | ✅ Optional | log_prob, reward |
| DAPO | On-policy | ✅ Optional | log_prob, dynamic_clip_range |
| GSPO | On-policy | ✅ Optional | log_prob, group_stability_term |
| RLOO | On-policy | ✅ Optional | log_prob, leave_one_out_baseline |
Algorithm Extension Interface:

```python
@register_adv_estimator("my_algorithm")
def compute_my_advantage(rewards, mask, **kwargs) -> tuple[Tensor, Tensor]:
    """Return (advantages, returns)"""
    ...

@register_policy_loss("my_loss")
def compute_my_loss(old_logp, logp, adv, mask, **kwargs) -> tuple[Tensor, dict]:
    """Return (loss, metrics_dict)"""
    ...
```

The framework handles staleness, importance correction, and batching automatically.
- Python 3.10+
- CUDA 12.0+
- SGLang server (for inference)
- Megatron-LM (optional, for production distributed training)
- HuggingFace Transformers (for development/single-GPU training)
```bash
git clone https://github.com/flux-team/flux.git
cd flux

# Basic installation
pip install -e .

# With development dependencies
pip install -e ".[dev]"

# Full installation (includes all dependencies)
pip install -e ".[all]"
```

Flux expects prompts in JSONL format:

```jsonl
{"prompt": "Explain quantum computing in simple terms."}
{"prompt": "Write a Python function to compute Fibonacci numbers."}
{"prompt": "What are the benefits of exercise?"}
```

```bash
# Single GPU
python -m sglang.launch_server --model-path Qwen/Qwen3-8B --port 8000

# Multi-GPU (tensor parallel)
python -m sglang.launch_server --model-path Qwen/Qwen3-8B --port 8000 --tp 4
```

```python
from flux import FluxTrainer, FluxConfig
from flux.core.config import AdaptiveAsyncConfig, SGLangConfig

config = FluxConfig(
    model_path="Qwen/Qwen3-8B",
    sglang=SGLangConfig(base_url="http://localhost:8000"),
    adaptive_async=AdaptiveAsyncConfig(
        target_staleness=0.15,     # Target staleness level
        min_async_ratio=0.1,       # Never fully synchronous
        max_async_ratio=0.9,       # Never fully asynchronous
        kp=0.1, ki=0.01, kd=0.05,  # PID controller gains
    ),
    learning_rate=1e-6,
    batch_size=32,
    num_steps=10000,
)

trainer = FluxTrainer(config)
trainer.fit(
    prompts="data/prompts.jsonl",
    eval_prompts="data/eval_prompts.jsonl",
)
```

```
[Step 100]  loss=0.342 | staleness=0.12 | async_ratio=0.45 | throughput=1.2k tok/s
[Step 200]  loss=0.298 | staleness=0.14 | async_ratio=0.52 | throughput=1.4k tok/s
[Step 300]  loss=0.267 | staleness=0.16 | async_ratio=0.48 (sync triggered)
[Step 400]  loss=0.241 | staleness=0.13 | async_ratio=0.55 | throughput=1.5k tok/s
...
[Step 1000] loss=0.156 | staleness=0.15 | async_ratio=0.62 | throughput=1.8k tok/s
            ↑ Controller stabilizes around target
```
```bash
flux train --config configs/qwen3-8b-8gpu.yaml --prompts data/prompts.jsonl
```

Example config (`configs/qwen3-8b-8gpu.yaml`):

```yaml
model_path: Qwen/Qwen3-8B
sglang:
  base_url: http://localhost:8000
adaptive_async:
  target_staleness: 0.15
  min_async_ratio: 0.1
  max_async_ratio: 0.9
  kp: 0.1
  ki: 0.01
  kd: 0.05
training:
  learning_rate: 1e-6
  batch_size: 32
  gradient_accumulation_steps: 4
  algorithm: grpo  # or ppo, dpo, reinforce
```

| Mode | Training | Inference | Weight Sync | Recommended For |
|---|---|---|---|---|
| Colocated | GPU 0-3 | GPU 4-7 | CUDA IPC | Single node, low latency |
| Separated | Node A | Node B | HTTP | Multi-node, high throughput |
| Hybrid | GPU 0-5 | GPU 6-7 + Node B | Mixed | Large-scale production |
Default recommendation: Colocated on single node for simplicity; separated for 64+ GPU training.
The FluxCoordinator is a single-machine orchestrator (not a distributed control plane):
- Runs on rank 0 of the training process
- Uses an `asyncio` event loop for non-blocking I/O
- Coordinates with SGLang servers via ZeroMQ (local) or HTTP (remote)
- Checkpoint/resume is handled by Megatron's distributed checkpointing
Fault Tolerance (current):
- Timeout on rollout requests (default: 60s)
- Retry on transient HTTP failures (3 attempts)
- Graceful degradation: skip failed rollouts, continue training
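The timeout/retry/skip behavior above can be sketched with plain asyncio (function and parameter names are illustrative, not the Flux client API):

```python
import asyncio

async def rollout_with_retry(submit_fn, prompt, timeout=60.0, attempts=3):
    """Submit a rollout with a per-request timeout; retry transient failures,
    then degrade gracefully by skipping the rollout."""
    for attempt in range(attempts):
        try:
            return await asyncio.wait_for(submit_fn(prompt), timeout=timeout)
        except (asyncio.TimeoutError, ConnectionError):
            if attempt == attempts - 1:
                return None  # graceful degradation: skip, keep training
            await asyncio.sleep(2 ** attempt)  # exponential backoff
```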
Fault Tolerance (planned):
- SGLang server health monitoring and failover
- Training checkpoint on staleness spike
- Automatic rollout worker restart
```
flux/
├── core/                      # Core abstractions
│   ├── config.py              # Pydantic configuration classes
│   ├── types.py               # Type definitions
│   └── trajectory.py          # Trajectory data structures
├── controller/                # Adaptive control plane
│   ├── adaptive_async.py      # PID-based async ratio controller
│   ├── staleness.py           # Staleness measurement
│   ├── mode_gate.py           # Sync/async state machine (NEW)
│   └── importance.py          # Importance weight correction
├── rollout/                   # Rollout generation
│   ├── manager.py             # Streaming rollout with APRIL
│   └── sglang_client.py       # SGLang HTTP client
├── training/                  # Training backends
│   ├── base.py                # TrainingBackend ABC, GPUBatch (NEW)
│   ├── backends/              # Backend implementations (NEW)
│   │   ├── transformers.py    # HuggingFace Transformers backend
│   │   └── __init__.py        # Backend factory
│   ├── megatron_engine.py     # Megatron-LM integration
│   ├── batch_composer.py      # Smart batch composition
│   └── algorithms/            # PPO, GRPO, etc.
├── sync/                      # Weight synchronization
│   ├── weight_sync.py         # Sync manager
│   └── cuda_ipc.py            # Zero-copy CUDA IPC
└── coordinator/               # Lightweight coordinator
    └── coordinator.py         # Main asyncio coordinator
```
Flux uses a native trainer contract that enables GPU-direct training with pluggable backends:
```python
from flux.training import TrainingBackend, GPUBatch, create_training_backend

# Create backend from config (TransformersBackend, MegatronBackend, etc.)
backend = create_training_backend(config)
backend.initialize(training_config)

# Training loop with GPU-direct batches
for batch in trajectory_store.sample_batches():
    gpu_batch = batch.as_gpu_batch(backend.device)
    result = backend.train_step(gpu_batch)
    if result.version % sync_interval == 0:
        sync_weights(backend.get_state_dict())
```

| Backend | Use Case | Parallelism |
|---|---|---|
| TransformersBackend | Development, single/multi-GPU | DataParallel, DDP |
| MegatronBackend | Production, large models | 3D (TP + PP + DP) |
| FSDPBackend | Memory-efficient large models | FSDP sharding |
The ModeGate controls transitions between sync and async training modes:
```
                      ┌──────────────────┐
                      │  ASYNC_RUNNING   │ ◄── Normal operation
                      └────────┬─────────┘
                               │
              ┌────────────────┼────────────────┐
              │                │                │
              ▼                ▼                ▼
  staleness > threshold   capacity = 0    buffer > 90%
              │                │                │
              ▼                ▼                ▼
     ┌───────────────┐  ┌─────────────┐  ┌─────────────┐
     │ SYNC_BARRIER  │  │  THROTTLED  │  │  THROTTLED  │
     │ (wait for     │  │ (backpres-  │  │ (backpres-  │
     │  in-flight)   │  │  sure)      │  │  sure)      │
     └───────────────┘  └─────────────┘  └─────────────┘
```
The gate evaluates staleness, capacity, and buffer fill ratio to determine whether to:
- Continue async operation
- Trigger a sync barrier (wait for all in-flight rollouts)
- Apply backpressure (pause new rollout submissions)
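The three-way decision can be sketched as a pure function (the real ModeGate is a state machine; the threshold values here are illustrative parameters):

```python
def gate_decision(staleness, capacity_remaining, buffer_fill,
                  staleness_threshold=0.4, buffer_high_water=0.9):
    """Decide the next mode from staleness, buffer capacity, and fill ratio."""
    if staleness > staleness_threshold:
        return "SYNC_BARRIER"   # wait for in-flight rollouts, sync weights
    if capacity_remaining <= 0 or buffer_fill > buffer_high_water:
        return "THROTTLED"      # backpressure: pause new rollout submissions
    return "ASYNC_RUNNING"      # continue normal async operation
```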
Flux operates anywhere on the sync-async spectrum, adapting in real time based on training dynamics.
- Direct Megatron-LM integration — not wrapped
- SGLang HTTP API — simple and efficient
- Pure asyncio — no Ray abstraction layer
- < 5000 lines of core runtime code (excludes backends and tests)
- No magic, explicit control flow
- Easy to debug and extend
| Metric | Target | Measurement Method |
|---|---|---|
| GPU Utilization | > 80% | nvidia-smi dmon average |
| Throughput | 2× VERL baseline | samples/hour on same hardware |
| Staleness | Mean < 0.2, Max < 0.4 | Combined staleness metric |
| KL Blow-up | < 5% of runs | KL divergence spike detection |
| Scaling Efficiency | > 85% at 64 GPUs | Throughput ratio vs linear |
Note: These are design targets. Benchmarks in progress.
Full documentation: https://ivis4ml.github.io/flux/
| Resource | Description |
|---|---|
| Getting Started | Installation and quick start guide |
| Tutorials | Step-by-step training tutorials |
| Algorithms | PPO, GRPO, DPO, and custom algorithms |
| Configuration | Full configuration reference |
| API Reference | Complete API documentation |
```bash
# Install with dev dependencies
pip install -e ".[dev]"

# Run tests
pytest                 # all tests
pytest tests/unit/     # unit tests only
pytest -m "not slow"   # skip slow tests

# Code quality
ruff check . && black --check . && mypy flux/
```

We welcome contributions! Please see our contributing guidelines.
```bash
# Setup development environment
pip install -e ".[dev]"

# Run all checks before commit
ruff check . && black --check . && mypy flux/ && pytest
```

```bibtex
@software{flux2025,
  title = {Flux: An Adaptive Post-Training Framework for LLMs},
  year = {2025},
  url = {https://github.com/flux-team/flux}
}
```

Apache 2.0
