
Proximal Policy Optimization (PPO) with Memory-Enhanced Models

A high-performance implementation of Proximal Policy Optimization using JAX and Equinox, featuring state-of-the-art memory-enhanced neural network architectures for partially observable environments.

🚀 Features

  • Dual Training Modes: Standard PPO and PPO with Recurrent Neural Networks (PPO-RNN)
  • Advanced Memory Models: 15+ cutting-edge memory architectures including LRU, LSTM, GRU, and novel research models
  • PopGym Arcade Integration: Built-in support for partially observable RL environments
  • High Performance: Vectorized JAX implementation with JIT compilation
  • Experiment Tracking: Integrated Weights & Biases (WandB) logging
  • Flexible Configuration: Comprehensive hyperparameter control via command-line interface

🧠 Supported Memory Models

Traditional RNNs

  • LSTM: Long Short-Term Memory
  • GRU: Gated Recurrent Unit
  • MGU: Minimal Gated Unit
  • Elman: Elman Recurrent Network

Modern Linear Models

  • LRU: Linear Recurrent Unit
  • MinGRU: Minimal GRU variant
  • Linear RNN: Simple linear recurrence

Research Models

  • FART: Fast Autoregressive Recurrent Transformer
  • GILR: Gated Information Linear Recurrence
  • NAbs: Neural Absolute networks
  • NMax: Neural Max networks
  • PSpherical: Polar Spherical networks
  • LogBayes: Logarithmic Bayesian networks
  • DLSE: Decaying LogSumExp
  • S6: Structured State Space model variant

📦 Installation

Prerequisites

  • Python 3.8+
  • JAX (with GPU support recommended)

Dependencies

pip install "jax[cuda12]"  # or "jax[cpu]" for CPU-only (quotes avoid shell bracket expansion)
pip install equinox
pip install optax
pip install popgym-arcade
pip install wandb
pip install distreqx
pip install numpy

Clone Repository

git clone https://github.com/Aphlios/Proximal-Policy-Optimization.git
cd Proximal-Policy-Optimization

🎮 Usage

Standard PPO Training

python train.py PPO \
    --ENV_NAME CartPoleHard \
    --TOTAL_TIMESTEPS 1000000 \
    --LR 3e-4 \
    --NUM_ENVS 16 \
    --SEED 42

PPO with Memory Models

python train.py PPO_RNN \
    --MEMORY_TYPE lru \
    --ENV_NAME CartPoleHard \
    --TOTAL_TIMESTEPS 1000000 \
    --LR 3e-4 \
    --NUM_ENVS 16 \
    --PARTIAL \
    --SEED 42

Configuration Options

Core Training Parameters

  • --LR: Learning rate (default: 1e-4)
  • --TOTAL_TIMESTEPS: Total training timesteps (default: 1e7)
  • --NUM_ENVS: Number of parallel environments (default: 16)
  • --NUM_STEPS: Steps per environment per update (default: 128)
  • --UPDATE_EPOCHS: PPO update epochs (default: 4)
  • --NUM_MINIBATCHES: Minibatches per update (default: 16)

PPO Hyperparameters

  • --GAMMA: Discount factor (default: 0.99)
  • --GAE_LAMBDA: GAE lambda (default: 0.95)
  • --CLIP_EPS: PPO clipping epsilon (default: 0.2)
  • --ENT_COEF: Entropy coefficient (default: 0.01)
  • --VF_COEF: Value function coefficient (default: 0.5)
  • --MAX_GRAD_NORM: Gradient clipping (default: 0.5)
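To make the roles of --CLIP_EPS and the advantage term concrete, here is a minimal NumPy sketch of the clipped surrogate objective that PPO optimizes. This is an illustration, not the repository's JAX implementation; the names `ratio` and `adv` are hypothetical stand-ins for the probability ratio and advantage estimates.

```python
import numpy as np

def ppo_clipped_loss(ratio, adv, clip_eps=0.2):
    """Clipped surrogate objective: the policy update is limited to
    a +/- clip_eps change in the probability ratio (see --CLIP_EPS)."""
    unclipped = ratio * adv
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
    # PPO maximizes the minimum of the two terms; as a loss we negate it.
    return -np.mean(np.minimum(unclipped, clipped))

# A ratio far outside the clip range contributes no extra objective value:
# with clip_eps=0.2 and positive advantage, ratios above 1.2 are clipped.
loss = ppo_clipped_loss(np.array([2.0]), np.array([1.0]))
```

In the full training loss, this term is combined with the value loss (weighted by --VF_COEF) and an entropy bonus (weighted by --ENT_COEF).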

Environment Configuration

  • --ENV_NAME: Environment name (default: "CartPoleHard")
  • --PARTIAL: Enable partial observations
  • --OBS_SIZE: Observation size (128 or 256)

Memory Model Selection (PPO_RNN only)

  • --MEMORY_TYPE: Choose from:
    • Traditional: lstm, gru, mgu, elman
    • Modern: lru, mingru, linear_rnn
    • Research: fart, gilr, nabs, nmax, pspherical, log_bayes, dlse, s6

Experiment Tracking

  • --PROJECT: WandB project name
  • --ENTITY: WandB entity name
  • --WANDB_MODE: WandB mode (online/offline/disabled)

๐Ÿ—๏ธ Architecture

Standard PPO

  • ActorCritic Network: CNN feature extraction + MLP heads
  • Dual Pathways: Separate actor and critic CNN backbones
  • Adaptive Learning: Optional learning rate annealing

PPO-RNN

  • Memory-Enhanced Architecture: CNN + Recurrent layers + MLP heads
  • Flexible Memory Backend: Plug-and-play memory model selection
  • Stateful Processing: Maintains hidden states across timesteps

Memory Model Framework (Memorax)

The project includes a sophisticated framework for memory models based on Generalized Recurrent Algebraic Structures (GRAS):

  • Unified Interface: All memory models implement the same GRAS interface
  • Algebraic Foundation: Built on mathematical structures (semigroups, monoids)
  • Residual Connections: Multi-layer models with skip connections
  • Reset Mechanism: Proper handling of episode boundaries
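As a small illustration of the algebraic idea (not the actual Memorax API), consider a linear recurrence h_t = a_t * h_{t-1} + b_t. Each step is an affine map, and composing two affine maps is associative, so the sequence can be reduced in any grouping — exactly the monoid property that makes parallel scans possible. All names below (`combine`, `steps`) are hypothetical.

```python
from functools import reduce

# Each element is a pair (a, b) representing the affine map h -> a*h + b.
# Composing two such maps is associative, which is the monoid property
# a GRAS-style framework relies on.
def combine(left, right):
    a1, b1 = left
    a2, b2 = right
    return (a2 * a1, a2 * b1 + b2)

identity = (1.0, 0.0)  # monoid identity: h -> h

# Toy (decay, input) pairs for the recurrence h_t = a_t * h_{t-1} + b_t.
steps = [(0.5, 1.0), (0.5, 2.0), (0.5, 3.0)]

# Left-to-right (sequential) evaluation...
a, b = reduce(combine, steps, identity)
h_seq = a * 0.0 + b  # apply the composed map to h_0 = 0

# ...matches a differently grouped evaluation, because combine is associative.
a2, b2 = combine(combine(steps[0], steps[1]), steps[2])
h_par = a2 * 0.0 + b2
```

In JAX, this same associativity is what lets a recurrence be evaluated with a parallel scan instead of a step-by-step loop.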

📊 Supported Environments

The implementation supports PopGym Arcade environments, particularly:

  • CartPole variants
  • Partially observable control tasks
  • Visual observation spaces (128x128 or 256x256)

🔬 Research Context

This implementation is designed for research in:

  • Partial Observability: POMDP environments requiring memory
  • Memory Architectures: Comparative studies of recurrent models
  • Reinforcement Learning: Policy gradient methods with function approximation

📈 Performance Features

  • JAX Vectorization: Efficient parallel environment stepping
  • JIT Compilation: Optimized computation graphs
  • Gradient Clipping: Stable training with configurable norms
  • Automatic Evaluation: Built-in video generation and logging

๐Ÿ› ๏ธ Model Persistence

Trained models are automatically saved as .pkl files with descriptive names:

PPO_RNN_lru_CartPoleHard_model_Partial=False_SEED=0.pkl
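Assuming the checkpoint is a plain pickle of the model object, it can be reloaded as sketched below; the exact file contents depend on how train.py serializes the model. The path and the stand-in object here are hypothetical (the snippet round-trips a dummy dict so it is self-contained).

```python
import pickle

# Hypothetical checkpoint path following the naming scheme above.
path = "PPO_RNN_lru_CartPoleHard_model_Partial=False_SEED=0.pkl"

dummy_model = {"params": [1.0, 2.0]}  # stand-in for the trained model object
with open(path, "wb") as f:
    pickle.dump(dummy_model, f)

with open(path, "rb") as f:
    model = pickle.load(f)
```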

๐Ÿ“ Example Training Script

# Standard PPO training
python train.py PPO --ENV_NAME CartPoleHard --TOTAL_TIMESTEPS 1000000

# Memory-enhanced training with LRU
python train.py PPO_RNN --MEMORY_TYPE lru --PARTIAL --ENV_NAME CartPoleHard

# Research model comparison
python train.py PPO_RNN --MEMORY_TYPE fart --PARTIAL --LR 1e-3

๐Ÿค Contributing

Contributions are welcome! Areas of interest:

  • New memory model implementations
  • Additional environment support
  • Performance optimizations
  • Documentation improvements

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

๐Ÿ™ Acknowledgments

  • Built with JAX and Equinox
  • Memory models inspired by recent research in efficient sequence modeling
  • Environment support via POPGym Arcade

📚 Citation

If you use this code in your research, please cite:

@article{wang2025popgym,
  title={POPGym Arcade: Parallel Pixelated POMDPs},
  author={Wang, Zekang and He, Zhe and Zhang, Borong and Toledo, Edan and Morad, Steven},
  journal={arXiv preprint arXiv:2503.01450},
  year={2025}
}
