
Feature Enhancement: Add Advanced Post-Training and Fine-Tuning Capabilities to Simply #2

@vector-one

Description

Issue Summary

Simply is an excellent minimal and scalable research codebase for LLM training in JAX, but it currently lacks comprehensive post-training capabilities (RLHF, DPO, PPO) and advanced fine-tuning features that are standard in comparable JAX LLM frameworks. This issue proposes adding these capabilities while maintaining Simply's core philosophy of minimalism and rapid iteration.

Background & Motivation

Simply positions itself as a "minimal and scalable research codebase" focused on pre-training and basic training of autoregressive models. However, when compared to other mature JAX-based LLM frameworks, Simply is missing critical post-training capabilities that are essential for modern LLM research and deployment.

Current State of Simply

What Simply Has:

  • Clean JAX/Flax-based architecture
  • Basic transformer implementations
  • Orbax checkpoint management
  • SeqIO data pipeline integration
  • Fast pre-training capabilities
  • Minimal abstractions for quick iteration

What Simply Lacks:

  • Post-training algorithms (RLHF, PPO, DPO, GRPO)
  • Preference tuning capabilities
  • Knowledge distillation
  • Supervised fine-tuning (SFT) utilities
  • Advanced evaluation frameworks
  • Model serving utilities

Comparison with Similar JAX LLM Frameworks

1. MaxText (Google DeepMind)

Features Simply is Missing:

  • Post-Training Suite: MaxText integrates with Tunix for comprehensive post-training

    • Supervised Fine-Tuning (SFT)
    • Group Relative Policy Optimization (GRPO)
    • Preference tuning
    • Knowledge distillation
  • Multi-Modal Support: Gemma 3 and Llama 4 VLM training

  • Advanced Data Pipeline:

    • Multiple input pipeline options (Grain with ArrayRecord, TFDS, HuggingFace)
    • Global shuffle for ArrayRecord format
    • Configurable data sharding strategies
  • Production Features:

    • Automatic log upload to Vertex AI Tensorboard
    • Stack trace collection for debugging distributed training
    • Comprehensive XLA compiler flag configurations
    • Integration with vLLM for serving
  • Scale Optimization: Proven to scale to 51K+ chips with high MFU (55-60%)

2. Levanter (Stanford CRFM)

Features Simply is Missing:

  • Named Tensors via Haliax: Dramatically improves code legibility while maintaining performance

    • Eliminates positional index confusion
    • Makes tensor operations self-documenting
    • Simplifies parallelism strategies
  • Bitwise Reproducibility: Deterministic training even with preemption and restarts

    • Critical for research reproducibility
    • Enables exact experiment replication
  • Flexible Data Handling:

    • Online but cached data preprocessing
    • Ability to tune data mixtures without retokenization
    • Sharded data loading
  • Advanced Checkpointing:

    • Can resume training on different number of hosts
    • Distributed checkpointing via TensorStore
  • Hugging Face Integration:

    • Automatic export to Hugging Face Hub
    • Import/export via SafeTensors
    • Compatible with PyTorch ecosystem
  • Optimizer Innovation: Support for Sophia optimizer (2x faster than Adam)

3. EasyLM

Features Simply is Missing:

  • RLHF Methods:

    • Direct Preference Optimization (DPO)
    • Proximal Policy Optimization (PPO)
  • Modular Configuration System: Built on MLXU for easy flag-based configuration

  • Comprehensive Model Support: Pre-built implementations with fine-tuning examples

  • Serving Capabilities: Built-in serving utilities for deployment

4. Tunix (Google - Post-Training Focused)

Features Simply is Missing:

  • Complete Post-Training Algorithm Suite:

    • Supervised Fine-Tuning (SFT)
    • Preference tuning
    • Knowledge distillation (logit strategy, attention transfer, feature pooling)
    • PPO, GRPO, GSPO for reinforcement learning
  • Agentic AI Support: Training agents that reason with LLMs and interact with environments

  • White-Box Design: Easy customization of training loops without abstraction layers

  • Game Reinforcement Learning: Multi-turn RL on challenging games via GRL integration

Proposed Solution

Add a new post-training module to Simply that maintains its minimal philosophy while providing essential post-training capabilities. The implementation should be:

  1. Minimal but Complete: Core algorithms without excessive abstraction
  2. Easy to Fork: Following Simply's design philosophy
  3. Self-Contained: Minimal new dependencies beyond what Simply already uses
  4. Well-Documented: Clear examples for each algorithm

Implementation Plan

Phase 1: Foundation (Core Infrastructure)

1.1 Reward Modeling Infrastructure

  • Create simply/posttraining/ directory structure
  • Implement base reward model class using Flax
  • Add reward model training utilities
  • Support for preference data loading via SeqIO
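
To make 1.1 concrete, the reward model could be an existing Simply transformer backbone with a small scalar head trained on a pairwise (Bradley-Terry) loss. A hedged sketch; the class and function names here are illustrative, not existing Simply APIs:

# Hypothetical scalar reward head plus pairwise (Bradley-Terry) loss.
import flax.linen as nn
import jax
import jax.numpy as jnp


class RewardHead(nn.Module):
    """Maps backbone hidden states to a scalar reward per sequence."""

    @nn.compact
    def __call__(self, hidden_states):  # [batch, seq_len, d_model]
        last_token = hidden_states[:, -1]       # final token's hidden state
        reward = nn.Dense(1)(last_token)        # project to a scalar
        return reward.squeeze(-1)               # [batch]


def pairwise_reward_loss(chosen_rewards, rejected_rewards):
    """Chosen responses should score above rejected ones."""
    return -jnp.mean(jax.nn.log_sigmoid(chosen_rewards - rejected_rewards))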

1.2 Configuration System Enhancement

  • Extend existing config system to support post-training parameters
  • Add configs for SFT, DPO, PPO hyperparameters
  • Maintain Simply's flat, minimal config style
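
For illustration only (Simply's actual config machinery may differ), the new hyperparameters could live in a single flat dataclass rather than a nested config tree; all field names below are assumptions:

# Hypothetical flat post-training config; field names are illustrative.
import dataclasses


@dataclasses.dataclass
class PostTrainingConfig:
    training_mode: str = 'sft'       # 'sft' | 'dpo' | 'ppo'
    learning_rate: float = 1e-5
    # DPO
    dpo_beta: float = 0.1
    # PPO
    ppo_clip_epsilon: float = 0.2
    ppo_epochs_per_batch: int = 4
    kl_penalty_coef: float = 0.05
    gae_lambda: float = 0.95
    gae_gamma: float = 1.0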

Files to Create:

simply/posttraining/
├── __init__.py
├── reward_model.py
├── base_trainer.py
└── configs/
    └── posttraining_config.py

Phase 2: Supervised Fine-Tuning (SFT)

2.1 SFT Trainer Implementation

  • Create SFT trainer that extends base training loop
  • Support for instruction-tuning datasets
  • Loss masking for instruction/response separation
  • Integration with existing checkpoint system
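
The key detail in 2.1 is the loss mask: instruction tokens are excluded from the loss and only response tokens contribute. A minimal sketch, with assumed shapes:

# Masked next-token cross-entropy: instruction tokens get weight 0,
# response tokens weight 1. logits: [B, T, V]; targets, response_mask: [B, T].
import jax
import jax.numpy as jnp


def masked_sft_loss(logits, targets, response_mask):
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_logp = jnp.take_along_axis(
        log_probs, targets[..., None], axis=-1).squeeze(-1)
    masked_nll = -token_logp * response_mask
    return masked_nll.sum() / jnp.maximum(response_mask.sum(), 1.0)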

2.2 Data Pipeline Enhancement

  • Add utilities for instruction-tuning data formats
  • Support for common formats (Alpaca, ShareGPT)
  • Preprocessing for prompt/completion pairs
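
As an example of the preprocessing in 2.2, an Alpaca-style record (instruction / input / output fields) could be flattened into a prompt/completion pair along these lines; the exact prompt template is an assumption:

# Convert one Alpaca-style record into a (prompt, completion) pair.
def alpaca_to_prompt_completion(record):
    instruction = record['instruction']
    context = record.get('input', '')
    prompt = f'Instruction: {instruction}\n'
    if context:
        prompt += f'Input: {context}\n'
    prompt += 'Response: '
    return prompt, record['output']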

Files to Create:

simply/posttraining/
├── sft_trainer.py
├── data/
│   ├── __init__.py
│   ├── instruction_data.py
│   └── preprocessors.py

Example Usage:

# Simply philosophy: minimal code, maximum clarity
python -m simply.main \
  --experiment_config TransformerSFT \
  --training_mode sft \
  --instruction_dataset alpaca \
  --experiment_dir /tmp/sft_experiment

Phase 3: Direct Preference Optimization (DPO)

3.1 DPO Implementation

  • Implement DPO loss function
  • Reference model management (frozen copy)
  • Preference pair data loading
  • Beta hyperparameter tuning support

3.2 DPO Training Loop

  • Integrate DPO loss into training loop
  • Memory-efficient reference model handling
  • Logging of preference accuracy metrics

Files to Create:

simply/posttraining/
├── dpo_trainer.py
├── losses.py (for DPO loss)
└── data/preference_data.py

Key Implementation Details:

# DPO Loss (simplified, following Simply's minimal style)
import jax


def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps, beta=0.1):
    """Minimal DPO loss on per-sequence log-probabilities."""
    # Log-ratios of chosen vs. rejected under the policy and the frozen reference.
    policy_logratios = policy_chosen_logps - policy_rejected_logps
    reference_logratios = reference_chosen_logps - reference_rejected_logps
    # Standard DPO objective: push the policy's margin above the reference's.
    losses = -jax.nn.log_sigmoid(beta * (policy_logratios - reference_logratios))
    return losses.mean()
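
The per-sequence log-probabilities fed into this loss are masked sums of token log-probs over the response tokens, computed once with the policy parameters and once with the frozen reference parameters (the reference side can additionally go through jax.lax.stop_gradient, or be precomputed and cached to save memory). A hedged helper mirroring the SFT masking above, but with a per-sequence sum instead of a mean:

# Sum the log-probabilities of each sequence's response tokens.
# logits: [B, T, V]; targets, response_mask: [B, T] (mask is 1 on response tokens).
import jax
import jax.numpy as jnp


def sequence_logps(logits, targets, response_mask):
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_logp = jnp.take_along_axis(
        log_probs, targets[..., None], axis=-1).squeeze(-1)
    return (token_logp * response_mask).sum(axis=-1)  # [B]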

Phase 4: Proximal Policy Optimization (PPO)

4.1 PPO Infrastructure

  • Actor-critic architecture
  • Value head implementation
  • Advantage computation (GAE)
  • PPO clipping objective
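
For reference, GAE over a single trajectory can be expressed as a reverse scan; a minimal sketch (batching and padding are omitted, and the helper name is an assumption):

# Generalized Advantage Estimation for one trajectory of length T.
import jax
import jax.numpy as jnp


def compute_gae(rewards, values, last_value, gamma=1.0, lam=0.95):
    """Returns advantages and value targets (returns) for one trajectory."""
    next_values = jnp.append(values[1:], last_value)
    deltas = rewards + gamma * next_values - values

    def backward_step(carry, delta):
        advantage = delta + gamma * lam * carry
        return advantage, advantage

    init = jnp.zeros((), dtype=deltas.dtype)
    _, advantages = jax.lax.scan(backward_step, init, deltas, reverse=True)
    returns = advantages + values
    return advantages, returns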

4.2 Rollout Generation

  • Efficient batch generation
  • KL penalty computation
  • Reward integration
  • Memory-efficient rollout buffer

4.3 PPO Training Loop

  • Multiple optimization epochs per batch
  • Value function training
  • Policy updates with clipping
  • KL divergence monitoring

Files to Create:

simply/posttraining/
├── ppo_trainer.py
├── value_head.py
├── rollout_buffer.py
└── advantages.py

Architecture:

# Value head (minimal implementation)
import flax.linen as nn


class ValueHead(nn.Module):
    """Scalar value prediction head for PPO."""
    hidden_dim: int = 1024

    @nn.compact
    def __call__(self, hidden_states):
        # hidden_states: [batch, seq_len, d_model]; use the last token's state.
        x = nn.Dense(self.hidden_dim)(hidden_states[:, -1])
        x = nn.relu(x)
        value = nn.Dense(1)(x)
        return value.squeeze(-1)  # [batch]
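
The policy update in 4.3 then uses the standard clipped surrogate objective on log-prob ratios; a hedged sketch, assuming advantages and the rollout-time log-probs are precomputed:

# PPO clipped surrogate loss; approx_kl is tracked to monitor divergence
# from the rollout policy across the multiple optimization epochs.
import jax.numpy as jnp


def ppo_policy_loss(new_logps, old_logps, advantages, clip_epsilon=0.2):
    ratio = jnp.exp(new_logps - old_logps)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
    loss = -jnp.mean(jnp.minimum(unclipped, clipped))
    approx_kl = jnp.mean(old_logps - new_logps)  # first-order KL estimate
    return loss, approx_kl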

Phase 5: Evaluation & Serving Utilities

5.1 Evaluation Framework

  • Integration with existing eval metrics
  • Post-training specific metrics (win rates, preference accuracy)
  • Generation quality assessment utilities
  • Human preference simulation
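
Most of these metrics are small utilities; for example, preference accuracy is just the fraction of pairs where the chosen response out-scores the rejected one (illustrative helper):

import jax.numpy as jnp


def preference_accuracy(chosen_scores, rejected_scores):
    """Fraction of pairs where the chosen response scores higher."""
    return jnp.mean((chosen_scores > rejected_scores).astype(jnp.float32))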

5.2 Model Export

  • SafeTensors export functionality
  • Hugging Face Hub compatibility
  • Checkpoint conversion utilities
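
As a sketch of the SafeTensors path, a Flax parameter pytree can be flattened into a flat dict of arrays and written with the safetensors library; real Hugging Face compatibility also needs a parameter-name mapping, which is omitted here, and the key naming scheme below is an assumption:

# Flatten a Flax params pytree and write it to a .safetensors file.
from flax.traverse_util import flatten_dict
from safetensors.flax import save_file


def export_safetensors(params, path):
    flat = flatten_dict(params, sep='.')  # e.g. {'layer_0.attn.kernel': array}
    save_file(flat, path)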

Files to Create:

simply/
├── eval/
│   ├── posttraining_metrics.py
│   └── generation_eval.py
└── export/
    ├── __init__.py
    ├── safetensors_export.py
    └── hf_export.py

Phase 6: Documentation & Examples

6.1 Documentation

  • Post-training guide in README
  • Algorithm explanations with references
  • Hyperparameter tuning guidelines
  • Common pitfalls and solutions

6.2 Examples

  • End-to-end SFT example
  • DPO fine-tuning example
  • PPO from reward model example
  • Dataset preparation scripts

Files to Create:

docs/
├── posttraining_guide.md
├── dpo_tutorial.md
├── ppo_tutorial.md
└── sft_tutorial.md

examples/
├── sft_example.py
├── dpo_example.py
└── ppo_example.py

Dependencies to Add

Minimal new dependencies, leveraging existing Simply infrastructure:

# In requirements.txt, add only:
# wandb (optional, for experiment tracking)
# datasets (optional, for loading preference/instruction data; Simply's existing SeqIO pipeline covers most cases)
# safetensors (Phase 5 only, for SafeTensors export)

Most functionality can be built on existing dependencies:

  • JAX (already included)
  • Flax (already included)
  • Optax (already included)
  • Orbax (already included)

Expected Benefits

For Researchers

  1. Rapid Prototyping: Quickly iterate on post-training ideas
  2. Minimal Learning Curve: If you know Simply's pre-training, you know post-training
  3. Research Quality: Reproducible results for publications
  4. Flexibility: Easy to modify algorithms for novel research

For Practitioners

  1. Production-Ready Models: Fine-tune for deployment
  2. Alignment Capabilities: RLHF for safer models
  3. Cost Effective: Efficient post-training reduces compute needs
  4. Easy Integration: Seamless workflow from pre-training to deployment

Maintaining Simply's Philosophy

  1. Minimal Abstractions: Direct implementations, not frameworks-within-frameworks
  2. Fork-Friendly: Each algorithm is ~200-400 lines, easy to modify
  3. Self-Contained: No hidden magic, everything visible
  4. Fast Iteration: Hours, not days, to implement research ideas

Success Metrics

  1. Code Simplicity: Each algorithm implementable in <500 lines
  2. Time to First Result: <1 hour from clone to running post-training
  3. Performance: Comparable results to MaxText/Tunix on standard benchmarks
  4. Adoption: Community contributions and forks
  5. Documentation Quality: Users can run examples without reading source

Testing Strategy

Unit Tests

  • Individual loss functions (DPO, PPO objectives)
  • Data preprocessing utilities
  • Reward model components
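
For example, a unit test for the DPO loss can check its direction of effect without any model in the loop (pytest-style sketch; the import path is the module proposed in Phase 3 and does not exist yet):

import jax.numpy as jnp

from simply.posttraining.losses import dpo_loss  # proposed Phase 3 module


def test_dpo_prefers_policy_that_favors_chosen():
    ref_chosen = jnp.array([-5.0])
    ref_rejected = jnp.array([-5.0])
    # A policy that favors the chosen response should get a lower loss
    # than one that favors the rejected response.
    loss_good = dpo_loss(jnp.array([-4.0]), jnp.array([-6.0]),
                         ref_chosen, ref_rejected)
    loss_bad = dpo_loss(jnp.array([-6.0]), jnp.array([-4.0]),
                        ref_chosen, ref_rejected)
    assert float(loss_good) < float(loss_bad)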

Integration Tests

  • End-to-end SFT training
  • DPO training loop
  • PPO training loop
  • Checkpoint save/restore

Benchmark Tests

  • GSM8K with GRPO (target: ~12% improvement like Tunix)
  • Anthropic HH-RLHF with DPO
  • AlpacaEval for SFT quality

Alternative Approaches Considered

1. External Integration

Approach: Recommend users switch to Tunix for post-training

  • ❌ Breaks Simply's "one-stop" philosophy
  • ❌ Requires learning another codebase
  • ❌ Loses Simply's simplicity advantage

2. Wrapper Around Tunix

Approach: Create thin wrapper around Tunix

  • ❌ Adds abstraction layers (against Simply's philosophy)
  • ❌ Introduces heavy dependencies
  • ❌ Less hackable/forkable

3. Minimal Post-Training Only

Approach: Add only SFT, skip RLHF

  • ❌ Incomplete for modern LLM workflows
  • ❌ Users still need external tools
  • ✓ Maintains simplicity (partial benefit)

4. Full-Featured (Proposed)

Approach: Add SFT, DPO, PPO with minimal implementation

  • ✓ Complete workflow in one codebase
  • ✓ Maintains Simply's philosophy
  • ✓ Enables full research cycle
  • ✓ Competitive with other frameworks

Implementation Considerations

Memory Efficiency

  • Use JAX's memory optimization features
  • Gradient checkpointing for large models
  • Efficient rollout buffers for PPO
  • Reference model sharing for DPO
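
For gradient checkpointing specifically, Flax's nn.remat wrapper recomputes a block's activations in the backward pass instead of storing them; a toy sketch (the block itself is a stand-in, not a Simply module):

import flax.linen as nn


class MLPBlock(nn.Module):  # stand-in for a transformer block
    dim: int = 1024

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.dim)(nn.relu(nn.Dense(4 * self.dim)(x)))


# Wrapping with nn.remat trades extra compute for lower activation memory.
RematMLPBlock = nn.remat(MLPBlock)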

Scalability

  • Support for FSDP (already in Simply)
  • Multi-host training for large preference datasets
  • Efficient data loading for instruction datasets

Compatibility

  • Maintain backward compatibility with existing Simply code
  • Optional imports for post-training (don't break pre-training users)
  • Config-based enabling of features

Migration Path for Users

Before (Pre-Training Only)

python -m simply.main \
  --experiment_config TransformerLM \
  --experiment_dir /tmp/pretrain

After (With Post-Training)

# Still works exactly the same for pre-training
python -m simply.main \
  --experiment_config TransformerLM \
  --experiment_dir /tmp/pretrain

# New: Post-training on your checkpoint
python -m simply.main \
  --experiment_config TransformerSFT \
  --training_mode sft \
  --load_checkpoint /tmp/pretrain/checkpoint_1000 \
  --experiment_dir /tmp/finetune

Timeline Estimate

Note: the figures below are rough effort and complexity estimates, not a committed schedule.

  • Phase 1 (Foundation): ~2-3 weeks (moderate complexity)
  • Phase 2 (SFT): ~1-2 weeks (low complexity)
  • Phase 3 (DPO): ~2 weeks (moderate complexity)
  • Phase 4 (PPO): ~3-4 weeks (high complexity)
  • Phase 5 (Eval/Export): ~1-2 weeks (low-moderate complexity)
  • Phase 6 (Documentation): ~1 week (ongoing)

Complexity Factors:

  • Low: Well-established algorithms, clear implementation path
  • Moderate: Requires careful design, some edge cases
  • High: Complex algorithm with many moving parts, requires extensive testing

Open Questions

  1. Reward Model Training: Should we provide reward model pre-training utilities, or assume users bring pre-trained reward models?

    • Recommendation: Start with assumption of external reward models, add training later if needed
  2. Multi-Modal Support: Should post-training support multi-modal models immediately, or focus on LLM-only?

    • Recommendation: LLM-only first, add multi-modal as Phase 7 if there's demand
  3. Serving: Should we include inference serving utilities, or focus purely on training?

    • Recommendation: Include export utilities only, recommend vLLM for serving
  4. Hyperparameter Defaults: How opinionated should we be about hyperparameters?

    • Recommendation: Provide sensible defaults from literature, but make everything configurable

Community Engagement

To ensure this feature meets community needs:

  1. RFC Period: Gather feedback on this proposal
  2. Design Review: Share detailed design docs before implementation
  3. Incremental Releases: Ship phases incrementally for community testing
  4. Example Showcase: Run examples on common benchmarks, share results
  5. Tutorial Series: Blog posts explaining each algorithm

Conclusion

Adding post-training capabilities to Simply will transform it from a pre-training-focused library into a complete LLM research and deployment toolkit, while maintaining its core philosophy of minimalism and ease of use. This positions Simply competitively against MaxText, Levanter, and EasyLM, while preserving what makes Simply special: simplicity, hackability, and rapid iteration speed.

The proposed implementation is designed to be:

  • Minimal: Core algorithms without unnecessary abstraction
  • Scalable: Works from single GPU to large TPU pods
  • Maintainable: Clear, simple code that's easy to modify
  • Complete: Covers the full spectrum of modern LLM post-training

By following the phased implementation plan, Simply can gain these capabilities without sacrificing its essential character, making it the go-to choice for researchers who want a clean, simple, yet complete LLM training framework in JAX.
