A high-performance implementation of Proximal Policy Optimization using JAX and Equinox, featuring state-of-the-art memory-enhanced neural network architectures for partially observable environments.
- Dual Training Modes: Standard PPO and PPO with Recurrent Neural Networks (PPO-RNN)
- Advanced Memory Models: 15 memory architectures, from classical LSTM and GRU to recent research models such as LRU and S6
- PopGym Arcade Integration: Built-in support for partially observable RL environments
- High Performance: Vectorized JAX implementation with JIT compilation
- Experiment Tracking: Integrated Weights & Biases (WandB) logging
- Flexible Configuration: Comprehensive hyperparameter control via command-line interface
- LSTM: Long Short-Term Memory
- GRU: Gated Recurrent Unit
- MGU: Minimal Gated Unit
- Elman: Elman Recurrent Network
- LRU: Linear Recurrent Unit
- MinGRU: Minimal GRU variant
- Linear RNN: Simple linear recurrence
- FART: Fast Autoregressive Recurrent Transformer
- GILR: Gated Information Linear Recurrence
- NAbs: Neural Absolute networks
- NMax: Neural Max networks
- PSpherical: Polar Spherical networks
- LogBayes: Logarithmic Bayesian networks
- DLSE: Decaying LogSumExp
- S6: Structured State Space model variant
- Python 3.8+
- JAX (with GPU support recommended)
```bash
# Install dependencies
pip install jax[cuda12]   # or jax[cpu] for CPU-only
pip install equinox
pip install optax
pip install popgym-arcade
pip install wandb
pip install distreqx
pip install numpy

# Clone the repository
git clone https://github.com/Aphlios/Proximal-Policy-Optimization.git
cd Proximal-Policy-Optimization
```

Standard PPO training:

```bash
python train.py PPO \
    --ENV_NAME CartPoleHard \
    --TOTAL_TIMESTEPS 1000000 \
    --LR 3e-4 \
    --NUM_ENVS 16 \
    --SEED 42
```

PPO with a recurrent memory model:

```bash
python train.py PPO_RNN \
    --MEMORY_TYPE lru \
    --ENV_NAME CartPoleHard \
    --TOTAL_TIMESTEPS 1000000 \
    --LR 3e-4 \
    --NUM_ENVS 16 \
    --PARTIAL \
    --SEED 42
```

Training parameters:

- `--LR`: Learning rate (default: 1e-4)
- `--TOTAL_TIMESTEPS`: Total training timesteps (default: 1e7)
- `--NUM_ENVS`: Number of parallel environments (default: 16)
- `--NUM_STEPS`: Steps per environment per update (default: 128)
- `--UPDATE_EPOCHS`: PPO update epochs (default: 4)
- `--NUM_MINIBATCHES`: Minibatches per update (default: 16)
- `--GAMMA`: Discount factor (default: 0.99)
- `--GAE_LAMBDA`: GAE lambda (default: 0.95)
- `--CLIP_EPS`: PPO clipping epsilon (default: 0.2)
- `--ENT_COEF`: Entropy coefficient (default: 0.01)
- `--VF_COEF`: Value function coefficient (default: 0.5)
- `--MAX_GRAD_NORM`: Gradient clipping norm (default: 0.5)

Environment parameters:

- `--ENV_NAME`: Environment name (default: "CartPoleHard")
- `--PARTIAL`: Enable partial observations
- `--OBS_SIZE`: Observation size (128 or 256)

Memory model selection:

- `--MEMORY_TYPE`: Choose from:
  - Traditional: `lstm`, `gru`, `mgu`, `elman`
  - Modern: `lru`, `mingru`, `linear_rnn`
  - Research: `fart`, `gilr`, `nabs`, `nmax`, `pspherical`, `log_bayes`, `dlse`, `s6`

Logging parameters:

- `--PROJECT`: WandB project name
- `--ENTITY`: WandB entity name
- `--WANDB_MODE`: WandB mode (online/offline/disabled)
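Several of the hyperparameters above (`--GAMMA`, `--GAE_LAMBDA`) control generalized advantage estimation. As a minimal illustrative sketch of what those two flags do (plain Python for clarity, not the repository's vectorized JAX code):

```python
def gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized advantage estimation over one trajectory.

    `values` carries one extra bootstrap entry: len(values) == len(rewards) + 1.
    `dones[t]` is 1.0 if the episode ended at step t, else 0.0.
    """
    advantages = [0.0] * len(rewards)
    last = 0.0
    for t in reversed(range(len(rewards))):
        nonterminal = 1.0 - dones[t]
        # one-step TD error, zeroing the bootstrap value at episode boundaries
        delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t]
        # exponentially weighted sum of TD errors, controlled by lambda
        last = delta + gamma * lam * nonterminal * last
        advantages[t] = last
    return advantages
```

Raising `--GAE_LAMBDA` toward 1 trades lower bias for higher variance in the advantage estimates.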
- ActorCritic Network: CNN feature extraction + MLP heads
- Dual Pathways: Separate actor and critic CNN backbones
- Adaptive Learning: Optional learning rate annealing
- Memory-Enhanced Architecture: CNN + Recurrent layers + MLP heads
- Flexible Memory Backend: Plug-and-play memory model selection
- Stateful Processing: Maintains hidden states across timesteps
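Stateful processing across timesteps is typically expressed in JAX as a `lax.scan` that carries the hidden state and zeroes it at episode boundaries. A minimal sketch, where `cell` stands in for any of the memory models above:

```python
import jax
import jax.numpy as jnp

def rollout(cell, h0, xs, dones):
    """Carry a hidden state across timesteps, resetting it when an episode ends."""
    def step(h, inp):
        x, done = inp
        h = jnp.where(done, jnp.zeros_like(h), h)  # reset state at episode boundary
        h = cell(h, x)
        return h, h
    final_h, all_h = jax.lax.scan(step, h0, (xs, dones))
    return final_h, all_h
```

A toy usage with an accumulator cell shows the reset: after a `done` at some timestep, the state restarts from zero rather than leaking information across episodes.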
The project includes a sophisticated framework for memory models based on Generalized Recurrent Algebraic Structures (GRAS):
- Unified Interface: All memory models implement the same GRAS interface
- Algebraic Foundation: Built on mathematical structures (semigroups, monoids)
- Residual Connections: Multi-layer models with skip connections
- Reset Mechanism: Proper handling of episode boundaries
The implementation supports PopGym Arcade environments, particularly:
- CartPole variants
- Partially observable control tasks
- Visual observation spaces (128x128 or 256x256)
This implementation is designed for research in:
- Partial Observability: POMDP environments requiring memory
- Memory Architectures: Comparative studies of recurrent models
- Reinforcement Learning: Policy gradient methods with function approximation
- JAX Vectorization: Efficient parallel environment stepping
- JIT Compilation: Optimized computation graphs
- Gradient Clipping: Stable training with configurable norms
- Automatic Evaluation: Built-in video generation and logging
Trained models are automatically saved as .pkl files with descriptive names:
`PPO_RNN_lru_CartPoleHard_model_Partial=False_SEED=0.pkl`
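Since the models are saved as plain `.pkl` files, restoring one is a standard pickle round-trip. A minimal sketch (the helper name is illustrative, not part of the repository):

```python
import pickle

def load_model(path: str):
    """Load a trained model saved by train.py (plain pickle round-trip)."""
    with open(path, "rb") as f:
        return pickle.load(f)
```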
```bash
# Standard PPO training
python train.py PPO --ENV_NAME CartPoleHard --TOTAL_TIMESTEPS 1000000

# Memory-enhanced training with LRU
python train.py PPO_RNN --MEMORY_TYPE lru --PARTIAL --ENV_NAME CartPoleHard

# Research model comparison
python train.py PPO_RNN --MEMORY_TYPE fart --PARTIAL --LR 1e-3
```

Contributions are welcome! Areas of interest:
- New memory model implementations
- Additional environment support
- Performance optimizations
- Documentation improvements
This project is licensed under the MIT License - see the LICENSE file for details.
- Built with JAX and Equinox
- Memory models inspired by recent research in efficient sequence modeling
- Environment support via POPGym Arcade
If you use this code in your research, please cite:
```bibtex
@article{wang2025popgym,
  title={POPGym Arcade: Parallel Pixelated POMDPs},
  author={Wang, Zekang and He, Zhe and Zhang, Borong and Toledo, Edan and Morad, Steven},
  journal={arXiv preprint arXiv:2503.01450},
  year={2025}
}
```