A unified reinforcement learning framework featuring memory-augmented algorithms and POMDP environment implementations. This repository provides modular components for building, configuring, and running a variety of RL algorithms on classic and memory-intensive environments.
- 🤖 Memory-RL: JAX implementations of DQN, PPO (Discrete & Continuous), SAC (Discrete & Continuous), PQN, IPPO, R2D2, and their memory-augmented variants with burn-in support for recurrent networks.
- 📦 Pure JAX Episode Buffer: A fully JAX-native episode buffer implementation enabling efficient storage and sampling of complete episodes for recurrent training, with support for Prioritized Experience Replay (see the sketch after this list).
- 🔁 Sequence Models: LSTM/GRU (via Flax), sLSTM/mLSTM, FFM/SHM, S5/LRU/Mamba/MinGRU, plus Self-Attention and Linear Attention blocks. GPT-2/GTrXL/xLSTM-style architectures are composed from these primitives (see examples/architectures).
- 🧬 Networks: MLP, CNN, and ViT encoders with support for RoPE and ALiBi positional embeddings, and Mixture of Experts (MoE) for horizontal scaling.
- 🎮 Environments: Support for Gymnax, PopJym, PopGym Arcade, Navix, Craftax, Brax, MuJoCo, gxm, XMiniGrid, and JaxMARL.
- 📊 Logging & Sweeps: Support for a CLI Dashboard, Weights & Biases, TensorboardX, and Neptune.
- 🔧 Easy to Extend: Clear directory structure for adding new networks, algorithms, or environments.
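To make the episode-buffer idea concrete, here is a minimal, self-contained sketch of a JAX-native buffer that stores fixed-length episodes and samples whole episodes for recurrent training. It is illustrative only: `EpisodeBuffer`, `init_buffer`, `add_episode`, and `sample_episodes` are hypothetical names, not Memorax's actual buffer API (which builds on Flashbax and additionally supports prioritized sampling).

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EpisodeBuffer(NamedTuple):
    """Hypothetical episode-buffer state (not Memorax's real API)."""
    obs: jnp.ndarray      # [capacity, episode_len, obs_dim]
    actions: jnp.ndarray  # [capacity, episode_len]
    rewards: jnp.ndarray  # [capacity, episode_len]
    index: jnp.ndarray    # next write position (scalar)
    size: jnp.ndarray     # number of stored episodes (scalar)


def init_buffer(capacity: int, episode_len: int, obs_dim: int) -> EpisodeBuffer:
    return EpisodeBuffer(
        obs=jnp.zeros((capacity, episode_len, obs_dim)),
        actions=jnp.zeros((capacity, episode_len), dtype=jnp.int32),
        rewards=jnp.zeros((capacity, episode_len)),
        index=jnp.zeros((), dtype=jnp.int32),
        size=jnp.zeros((), dtype=jnp.int32),
    )


@jax.jit
def add_episode(buf: EpisodeBuffer, obs, actions, rewards) -> EpisodeBuffer:
    """Write one complete episode at the current position (circular overwrite)."""
    i = buf.index
    capacity = buf.obs.shape[0]
    return EpisodeBuffer(
        obs=buf.obs.at[i].set(obs),
        actions=buf.actions.at[i].set(actions),
        rewards=buf.rewards.at[i].set(rewards),
        index=(i + 1) % capacity,
        size=jnp.minimum(buf.size + 1, capacity),
    )


@partial(jax.jit, static_argnames="batch_size")
def sample_episodes(key, buf: EpisodeBuffer, batch_size: int = 32):
    """Sample whole episodes uniformly; a recurrent net can then unroll over time."""
    idx = jax.random.randint(key, (batch_size,), 0, buf.size)
    return buf.obs[idx], buf.actions[idx], buf.rewards[idx]
```

Storing episodes as fixed-shape arrays keeps every operation jit- and vmap-friendly, which is the property a fully JAX-native buffer needs in order to stay on device.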
Install Memorax using pip:

```bash
pip install memorax
```

Or using uv:

```bash
uv add memorax
```

Optionally, you can add CUDA support with:

```bash
pip install memorax[cuda]
```

Optionally, set up Weights & Biases logging by logging in:

```bash
wandb login
```

Run a default DQN experiment on CartPole:

```bash
uv run examples/dqn_cartpole.py
```

Or build an agent yourself. The example below trains PPO with a GTrXL-style transformer torso on Gymnax's CartPole:

```python
import jax
import optax
from memorax.algorithms import PPO, PPOConfig
from memorax.environments import environment
from memorax.networks import (
    MLP, FFN, ALiBi, FeatureExtractor, GatedResidual, Network,
    PreNorm, SegmentRecurrence, SelfAttention, Stack, heads,
)
env, env_params = environment.make("gymnax::CartPole-v1")
cfg = PPOConfig(
name="PPO-GTrXL",
num_envs=8,
num_eval_envs=16,
num_steps=128,
gamma=0.99,
gae_lambda=0.95,
num_minibatches=4,
update_epochs=4,
normalize_advantage=True,
clip_coef=0.2,
clip_vloss=True,
ent_coef=0.01,
vf_coef=0.5,
)
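
# Build a GTrXL-style torso: pre-norm self-attention and FFN sublayers, each
# wrapped in a gated residual connection, with segment-level recurrence.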
features, num_heads, num_layers = 64, 4, 2
feature_extractor = FeatureExtractor(observation_extractor=MLP(features=(features,)))
attention = GatedResidual(PreNorm(SegmentRecurrence(
    SelfAttention(features, num_heads, context_length=128, positional_embedding=ALiBi(num_heads)),
    memory_length=64,
    features=features,
)))
ffn = GatedResidual(PreNorm(FFN(features=features, expansion_factor=4)))
torso = Stack(blocks=(attention, ffn) * num_layers)
actor_network = Network(feature_extractor, torso, heads.Categorical(env.action_space(env_params).n))
critic_network = Network(feature_extractor, torso, heads.VNetwork())
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))
agent = PPO(cfg, env, env_params, actor_network, critic_network, optimizer, optimizer)
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=10_000)
```
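The torso above follows the Gated Transformer-XL recipe: each sublayer (self-attention or FFN) is pre-normalized and wrapped in a gated residual connection, while SegmentRecurrence adds Transformer-XL-style recurrence across rollout segments (here, presumably, a 64-step memory over a 128-step attention context, per memory_length and context_length). Because architectures are composed from the sequence primitives listed above, the attention blocks can be swapped for LSTM, S5, FFM, or Mamba layers to obtain a different memory architecture; see examples/architectures for more compositions.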
The repository is organized as follows:

```
memorax/
├─ examples/          # Small runnable scripts (e.g., DQN CartPole)
├─ memorax/
│  ├─ algorithms/     # DQN, PPO, SAC, PQN, ...
│  ├─ networks/       # MLP, CNN, ViT, RNN, heads, ...
│  ├─ environments/   # Gymnax / PopGym / Brax / ...
│  ├─ buffers/        # Custom flashbax buffers
│  ├─ loggers/        # CLI, WandB, TensorBoardX integrations
│  └─ utils/
```
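The directory layout is also the extension path: new encoders, heads, or sequence blocks go into memorax/networks/, new agents into memorax/algorithms/, and so on. As a rough illustration (the exact interface a custom encoder must implement is an assumption here; the built-in modules in memorax/networks/ are the authoritative reference), a pixel encoder could be an ordinary Flax module that is then passed as the observation_extractor in the quick-start example:

```python
import flax.linen as nn
import jax.numpy as jnp


class TinyConvEncoder(nn.Module):
    """Hypothetical pixel encoder; the call signature is assumed, not Memorax's spec."""
    features: int = 64

    @nn.compact
    def __call__(self, observation: jnp.ndarray) -> jnp.ndarray:
        x = observation / 255.0                              # normalize raw pixels
        x = nn.Conv(32, kernel_size=(3, 3), strides=(2, 2))(x)
        x = nn.relu(x)
        x = nn.Conv(64, kernel_size=(3, 3), strides=(2, 2))(x)
        x = nn.relu(x)
        x = x.reshape((*x.shape[:-3], -1))                   # flatten spatial dims
        return nn.Dense(self.features)(x)


# It could then (presumably) replace the MLP from the quick-start:
# feature_extractor = FeatureExtractor(observation_extractor=TinyConvEncoder(features=64))
```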
Memorax is designed to work alongside a growing suite of JAX-native tools focused on partial observability and memory. These projects provide the foundational architectures and benchmarks for modern memory-augmented RL:
- Memax: A library for efficient sequence and recurrent modeling in JAX. It provides unified interfaces for fast recurrent state resets and associative scans, serving as a powerful primitive for building memory architectures.
- Flashbax: The library powering Memorax's buffer system. It provides high-performance, JAX-native experience replay buffers optimized for sequence storage and prioritized sampling.
- Gymnax: The standard for JAX-native RL environments. Memorax provides seamless wrappers to run recurrent agents on these vectorized tasks.
- PopGym Arcade: A JAX-native suite of "pixel-perfect" POMDP environments. It features Atari-style games specifically designed to test long-term memory with hardware-accelerated rendering.
- PopJym: A fast, JAX-native implementation of the POPGym benchmark suite, providing a variety of classic POMDP tasks optimized for massive vectorization.
- Navix: Accelerated MiniGrid-style environments. These are excellent for testing spatial reasoning and navigation in partially observable grid worlds.
- XLand-MiniGrid: A high-throughput meta-RL environment suite that provides massive task diversity for testing agent generalization in POMDPs.
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
If you use Memorax in your work, please cite:
```bibtex
@software{memorax2025github,
  title  = {Memorax: A Unified Framework for Memory-Augmented Reinforcement Learning},
  author = {Noah Farr},
  year   = {2025},
  url    = {https://github.com/memory-rl/memorax}
}
```
Special thanks to @huterguier for the valuable discussions and advice on the API design.
