A research-grade Julia implementation of a continuous-depth Transformer language model built on Neural ODEs. Hidden states evolve continuously through depth (treated as time), with the dynamics parameterized by Transformer-style attention and feedforward blocks.
Traditional Transformers process sequences through a fixed stack of discrete layers. This project explores an alternative: continuous-time evolution of hidden states via a Neural ODE. The model integrates `dh/dt = f(h, t, θ)`, where `f` is parameterized by self-attention and feedforward blocks.
```
tokens → embeddings → Neural ODE Transformer → LM head → logits
                              ↓
                   dh/dt = TransformerBlock(h, t)
```
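For orientation, here is a minimal sketch of what such depth-wise dynamics can look like as a Flux layer whose output is read as `dh/dt`. The struct, the `Dense` stand-ins, and the residual-style combination are illustrative assumptions, not the actual code in `src/Models/NeuralODEBlock.jl`:

```julia
using Flux

# Illustrative dynamics block: its output is interpreted as dh/dt at depth-time t.
# `attn` and `ffn` stand in for the project's self-attention and feedforward blocks.
struct TransformerDynamics{A,F}
    attn::A
    ffn::F
end

# h: hidden states of shape (d_model, seq_len, batch); t: scalar depth-time (unused here).
(dyn::TransformerDynamics)(h, t) = dyn.attn(h) .+ dyn.ffn(h)

d_model = 32
dyn = TransformerDynamics(Dense(d_model => d_model, tanh),  # placeholder "attention"
                          Dense(d_model => d_model))        # placeholder feedforward
dh = dyn(randn(Float32, d_model, 16, 2), 0f0)               # derivative at t = 0
```

The ODE solver accumulates these residual-style updates continuously, replacing a fixed stack of N discrete Transformer layers.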
- Continuous-depth processing via ODE integration (DifferentialEquations.jl)
- Proper adjoint sensitivity methods for efficient backpropagation (InterpolatingAdjoint, BacksolveAdjoint, QuadratureAdjoint)
- Multiple ODE solvers (Tsit5, Vern7, Vern9, BS3, DP5, KenCarp4, TRBDF2, Euler, RK4)
- Custom continuous-attention kernel integrator (RK4-style fixed-step integration)
- Reversible ODE design for memory-efficient training
- KV caching for fast autoregressive generation
- TensorBoard logging for experiment tracking
- Discrete Transformer baseline for comparison
- Full training pipeline with checkpointing and validation
- Text generation with multiple sampling strategies (greedy, top-k, top-p)
- GPU support via CUDA.jl + cuDNN for full acceleration
- Large-scale datasets (440K+ words: literature, science, code)
- Type-stable, idiomatic Julia code
- Comprehensive benchmarking and performance profiling
- Modular architecture for easy experimentation
- Julia 1.10+
- CUDA-capable GPU (optional, but recommended for larger models)
v0.2.0 - Major Enhancements:
- Large-scale datasets: 440K+ word corpus (literature + science + code)
- GPU acceleration: Full CUDA + cuDNN support
- Extended ODE solvers: 10+ solvers (Vern7, KenCarp4, TRBDF2, etc.)
- Advanced adjoint methods: Multiple sensitivity methods for gradients
- Fixed text generation: Autoregressive sampling with sliding windows
- Comprehensive benchmarking: Performance comparison tools
- Neural ODE stability: Custom RK4 integrator working reliably
```bash
# Clone the repository
git clone <repo-url>
cd ContinuumLM

# Activate Julia environment
julia --project=.

# Install dependencies
julia -e 'using Pkg; Pkg.instantiate()'
```

Create a text corpus file:
```bash
mkdir -p data
# Add your text corpus to data/corpus.txt
```

Enhanced training with large dataset (440K+ words):
```bash
# Download/create datasets
julia scripts/download_data.jl

# Train with enhanced dataset
julia scripts/standalone_train.jl config/small_debug.toml
```

Run complete demo (shows all features):
```bash
julia scripts/demo.jl
```

Compare Neural ODE vs Transformer performance:
```bash
julia scripts/comprehensive_benchmark.jl
```

Original training scripts:
```bash
# Small debug model (fast, CPU-friendly)
julia scripts/train_neural_ode_lm.jl config/small_debug.toml

# Neural ODE Transformer
julia scripts/train_neural_ode_lm.jl config/neural_ode_transformer.toml

# Discrete Transformer baseline
julia scripts/train_neural_ode_lm.jl config/base_transformer.toml
```

Evaluate a trained checkpoint:

```bash
julia scripts/evaluate.jl config/neural_ode_transformer.toml checkpoints/best_model.bson
```

Generate text from a trained model:

```bash
julia scripts/generate.jl config/neural_ode_transformer.toml checkpoints/best_model.bson "Once upon a time"
```

With custom sampling:
```bash
julia scripts/generate.jl config/neural_ode_transformer.toml checkpoints/best_model.bson "The future of AI" --max_tokens 200 --temperature 0.8 --top_k 50 --top_p 0.9
```
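For reference, here is a minimal sketch of how temperature, top-k, and top-p (nucleus) filtering can be combined when sampling the next token from a logit vector. The function name and signature are illustrative; the project's actual sampler lives in `src/Generation.jl`:

```julia
using StatsBase  # provides `sample` and `Weights`

# Illustrative next-token sampler with temperature, top-k, and top-p filtering.
function sample_token(logits::AbstractVector; temperature = 1.0f0, top_k = 0, top_p = 1.0f0)
    scaled = logits ./ temperature
    probs  = exp.(scaled .- maximum(scaled))     # softmax, numerically stabilised
    probs ./= sum(probs)

    order = sortperm(probs; rev = true)          # candidate tokens, most probable first
    if top_k > 0
        order = order[1:min(top_k, end)]
    end
    if top_p < 1
        cum    = cumsum(probs[order])
        cutoff = something(findfirst(>=(top_p), cum), length(order))
        order  = order[1:cutoff]
    end

    kept = probs[order] ./ sum(probs[order])     # renormalise over the kept tokens
    return order[sample(1:length(kept), Weights(kept))]
end

# e.g. sample_token(randn(Float32, 8000); temperature = 0.8f0, top_k = 50, top_p = 0.9f0)
```

Greedy decoding corresponds to keeping only the single most probable token (or letting the temperature approach zero).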
Start TensorBoard to visualize training metrics:

```bash
tensorboard --logdir logs
```

Then open http://localhost:6006 in your browser to view:
- Training/validation loss curves
- Perplexity metrics
- Learning rate schedule
- Gradient norms
Run the test suite:

```bash
julia --project=. test/runtests.jl
```

Project structure:

```
ContinuumLM/
├── src/
│   ├── NeuralODELM.jl               # Main module
│   ├── Config.jl                    # Configuration management
│   ├── Utils.jl                     # Device selection, seeding
│   ├── Data.jl                      # Tokenization and batching
│   ├── Training.jl                  # Training loop, checkpointing
│   ├── Evaluation.jl                # Perplexity, validation metrics
│   ├── Generation.jl                # Text generation utilities
│   └── Models/
│       ├── Models.jl                # Model exports
│       ├── Embeddings.jl            # Token + positional embeddings
│       ├── Attention.jl             # Multi-head self-attention
│       ├── ContinuousTransformer.jl # Discrete stack baseline
│       ├── NeuralODEBlock.jl        # Continuous-time ODE block
│       └── LanguageModel.jl         # End-to-end LM composition
├── scripts/
│   ├── train_neural_ode_lm.jl       # Training entrypoint
│   ├── evaluate.jl                  # Evaluation script
│   └── generate.jl                  # Generation script
├── config/
│   ├── small_debug.toml             # Tiny model for debugging
│   ├── neural_ode_transformer.toml  # Neural ODE config
│   └── base_transformer.toml        # Discrete baseline config
├── test/
│   ├── runtests.jl                  # Test suite
│   ├── test_data.jl                 # Data pipeline tests
│   ├── test_models.jl               # Model component tests
│   └── test_training.jl             # Training loop tests
└── README.md                        # This file
```
Configuration files use TOML format. Key settings:
- `d_model`: Hidden dimension
- `n_heads`: Number of attention heads
- `d_ff`: Feedforward dimension
- `vocab_size`: Vocabulary size
- `is_neural_ode`: Use Neural ODE (`true`) or discrete stack (`false`)
- `ode_t0`, `ode_t1`: ODE integration time interval
- `ode_solver`: ODE solver (`"Tsit5"`, `"RK4"`, `"Euler"`)
- `ode_sensealg`: Adjoint sensitivity method (`"InterpolatingAdjoint"`, `"BacksolveAdjoint"`, `"QuadratureAdjoint"`)
- `ode_integrator`: Integration mode (`"generic"` or `"custom_fixed_step"`)
- `ode_nsteps`: Number of steps for the custom integrator (default: 4)
- `reversible`: Use reversible ODE for memory efficiency (default: `false`)
- `ode_atol`, `ode_rtol`: ODE solver tolerances
- `batch_size`: Batch size
- `seq_len`: Sequence length
- `num_steps`: Total training steps
- `lr`: Learning rate
- `weight_decay`: Weight decay for AdamW
- `grad_clip`: Gradient clipping threshold
- `warmup_steps`: Learning rate warmup steps
- `device`: `"cpu"`, `"gpu"`, or `"auto"`
- `log_dir`: Directory for TensorBoard logs (default: `"logs"`)
- `run_name`: Name for this training run (default: `"default_run"`)
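Putting these together, a config might look like the following. The key names are the ones listed above, while the values (and the flat, table-free layout) are illustrative rather than a copy of `config/small_debug.toml`:

```toml
# Illustrative config in the spirit of config/small_debug.toml; values are examples only.
d_model       = 128
n_heads       = 4
d_ff          = 256
vocab_size    = 8000
is_neural_ode = true

ode_t0         = 0.0
ode_t1         = 1.0
ode_solver     = "Tsit5"
ode_sensealg   = "InterpolatingAdjoint"
ode_integrator = "generic"
ode_nsteps     = 4
reversible     = false
ode_atol       = 1e-6
ode_rtol       = 1e-3

batch_size   = 16
seq_len      = 64
num_steps    = 2000
lr           = 3e-4
weight_decay = 0.01
grad_clip    = 1.0
warmup_steps = 100
device       = "auto"
log_dir      = "logs"
run_name     = "small_debug_run"
```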
The core innovation is the `NeuralODEBlock`, which:
- Takes the hidden state `h(t)` at depth `t`
- Computes the derivative `dh/dt = TransformerBlock(h, t)`
- Integrates from `t = 0` to `t = T` using an ODE solver
- Returns the transformed state `h(T)`
This replaces discrete layer stacking with continuous evolution, allowing the model to learn adaptive depth.
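A minimal sketch of this integration pattern with DifferentialEquations.jl is shown below. The dynamics is a toy stand-in for the real `TransformerBlock`, and the solver and adjoint choices simply mirror the `ode_solver` / `ode_sensealg` config keys (the adjoint types come from SciMLSensitivity, which the SciML ecosystem builds on):

```julia
using DifferentialEquations, SciMLSensitivity

# Toy stand-in for TransformerBlock(h, t): any smooth map on the hidden state
# illustrates the pattern; the real dynamics uses attention + feedforward blocks.
W = randn(Float32, 8, 8) .* 0.1f0
dynamics(h, p, t) = tanh.(W * h)

h0   = randn(Float32, 8, 4)                     # toy hidden state (d_model, seq_len)
prob = ODEProblem(dynamics, h0, (0.0f0, 1.0f0)) # integrate depth-time from t0 = 0 to t1 = 1

# Tsit5 + InterpolatingAdjoint mirror the `ode_solver` / `ode_sensealg` settings;
# the sensealg is what the AD system uses when differentiating through `solve`.
sol = solve(prob, Tsit5(); sensealg = InterpolatingAdjoint(), save_everystep = false)
hT  = sol.u[end]                                # transformed state h(T)
```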
Adjoint Sensitivity Methods:
- Uses `InterpolatingAdjoint` or `BacksolveAdjoint` for efficient gradient computation
- Avoids storing the full forward trajectory during backpropagation
- Configurable via `ode_sensealg` in the config (see the sketch below)
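One plausible way the `ode_sensealg` string could be mapped onto a SciMLSensitivity object is sketched here; this helper is hypothetical, and the actual lookup in `src/` may differ:

```julia
using SciMLSensitivity

# Hypothetical helper mapping the `ode_sensealg` config string to an adjoint method.
function select_sensealg(name::AbstractString)
    name == "InterpolatingAdjoint" && return InterpolatingAdjoint()
    name == "BacksolveAdjoint"     && return BacksolveAdjoint()
    name == "QuadratureAdjoint"    && return QuadratureAdjoint()
    error("Unknown ode_sensealg: $name")
end
```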
Custom Continuous-Attention Kernel:
- Optional RK4-style fixed-step integrator (`ode_integrator = "custom_fixed_step"`)
- Tailored specifically for Transformer dynamics
- Configurable number of steps via `ode_nsteps` (see the RK4 sketch below)
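The sketch below shows a generic fixed-step RK4 loop of the kind `ode_integrator = "custom_fixed_step"` suggests. The function name and the toy dynamics are illustrative, with the step count playing the role of `ode_nsteps`:

```julia
# Illustrative fixed-step RK4 integrator: advance h from t0 to t1 in `nsteps`
# equal steps, evaluating the dynamics f(h, t) four times per step.
function rk4_integrate(f, h, t0, t1; nsteps::Int = 4)
    dt = (t1 - t0) / nsteps
    t  = t0
    for _ in 1:nsteps
        k1 = f(h, t)
        k2 = f(h .+ (dt / 2) .* k1, t + dt / 2)
        k3 = f(h .+ (dt / 2) .* k2, t + dt / 2)
        k4 = f(h .+ dt .* k3, t + dt)
        h  = h .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
    end
    return h
end

# Toy linear dynamics standing in for TransformerBlock(h, t):
A  = randn(Float32, 8, 8) .* 0.1f0
hT = rk4_integrate((h, t) -> A * h, randn(Float32, 8, 4), 0f0, 1f0; nsteps = 4)
```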
Reversible ODE:
- Memory-efficient training with `reversible = true`
- Automatically uses `BacksolveAdjoint` for optimal memory usage
- Reconstructs intermediate states on the fly during backprop
KV Caching:
- Efficient autoregressive generation with cached keys/values
- Avoids recomputing attention for previous tokens
- Use `generate_text_with_cache()` for faster inference (cache idea sketched below)
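To make the caching idea concrete, here is a sketch of a single-head cache that appends the new token's key/value and attends against everything cached so far. The struct and function names are hypothetical, and the real layout in `src/Models/Attention.jl` may differ:

```julia
# Hypothetical single-head KV cache: keys/values of already-processed tokens are
# kept so each generation step only projects and attends for the newest token.
mutable struct KVCache
    K::Matrix{Float32}   # (d_head, n_cached_tokens)
    V::Matrix{Float32}   # (d_head, n_cached_tokens)
end

KVCache(d_head::Int) = KVCache(zeros(Float32, d_head, 0), zeros(Float32, d_head, 0))

# Append the new token's projections, then attend against the full cache.
function attend_with_cache!(cache::KVCache, q::Vector{Float32},
                            k_new::Vector{Float32}, v_new::Vector{Float32})
    cache.K = hcat(cache.K, k_new)
    cache.V = hcat(cache.V, v_new)
    scores  = (cache.K' * q) ./ sqrt(Float32(length(q)))  # (n_cached_tokens,)
    α = exp.(scores .- maximum(scores))
    α ./= sum(α)
    return cache.V * α                                     # output for the new token
end
```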
- Discrete Transformer: `h_{i+1} = TransformerBlock(h_i)` for `i = 1..N`
- Neural ODE: `h(T) = h(0) + ∫₀ᵀ TransformerBlock(h(t), t) dt`
The continuous formulation can be more parameter-efficient, since a single set of block parameters is reused across the entire depth, and adaptive-step solvers effectively let the model choose its depth of computation.
This is a research scaffold, not a production LLM. Current limitations:
- Small model sizes (for research/education)
- Basic tokenization (word-level)
- Limited dataset support
Potential extensions:
- KV caching for Neural ODE path
- Larger model scales
- Advanced ODE solvers and adjoint methods
- Additional regularization techniques
- Multi-GPU training
- Integration with HuggingFace tokenizers
- Neural ODEs: Chen et al., "Neural Ordinary Differential Equations" (NeurIPS 2018)
- Continuous Normalizing Flows: Grathwohl et al., "FFJORD" (ICLR 2019)
- Transformers: Vaswani et al., "Attention Is All You Need" (NeurIPS 2017)
MIT License (or as specified in your project)
This is a research codebase. Contributions welcome! Areas for improvement:
- Performance optimizations
- Additional ODE solvers
- Better tokenization support
- More comprehensive tests
- Documentation improvements
Built with:
- Flux.jl - Deep learning framework
- DifferentialEquations.jl - ODE solving
- DiffEqFlux.jl - Neural ODE integration
Note: This project is for research and educational purposes. For production language models, consider established frameworks like Transformers.jl or PyTorch implementations.