Fork of DreamerV3 for studying spatial representations learned by world model agents in Crafter. Includes tools for trajectory recording, position decoding, layer-wise analysis, spatial tuning curves, and dream decoding.
Based on Hafner et al. 2025, "Mastering diverse control tasks through world models" (Nature).
Requires Python 3.11+. Tested on Linux (SLURM clusters with A100/H100 GPUs).
```sh
pip install -U -r requirements.txt
```

For CPU-only machines (e.g. a local Mac):

```sh
pip install jax==0.4.35 jaxlib==0.4.35 chex==0.1.87 optax==0.2.3
```

**JAX platform:** the default is `cuda`. On CPU-only machines, always pass `--jax.platform cpu` or JAX will fail to initialize.
The typical workflow is: train → eval trajectory → analysis (decoding, tuning curves, plotting).
All analysis scripts log provenance to run_info.json in their output directory via dreamerv3/run_info.py. Each entry records the timestamp, git SHA, full command line, SLURM job ID, and all arguments. Training runs also save hyperparams.txt and config.yaml to the logdir.
```sh
python dreamerv3/main.py \
  --configs crafter_small size25m \
  --logdir ./logdir/my_run \
  --env.crafter.random_spawn True \
  --env.crafter.fixed_seed False
```

All config options are in `dreamerv3/configs.yaml`. Use `--configs` to stack presets (applied left to right), then override individual flags. Key Crafter env flags:
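Preset stacking amounts to left-to-right dict merging, with explicit flags winning last. A toy sketch of the semantics (illustrative only, not the actual DreamerV3 config loader):

```python
# Toy sketch of --configs preset stacking (not the real loader):
# later presets override earlier ones, and individual CLI flag
# overrides win over both.
def resolve(presets, overrides):
    cfg = {}
    for preset in presets:    # applied left to right
        cfg.update(preset)
    cfg.update(overrides)     # explicit --flag values applied last
    return cfg

# Hypothetical preset contents, for illustration only.
crafter_small = {"task": "crafter", "env.crafter.area": [32, 32]}
size25m = {"model_size": "25m"}
cfg = resolve([crafter_small, size25m],
              {"env.crafter.random_spawn": True})
```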
| Flag | Default | Description |
|---|---|---|
| `--env.crafter.fixed_seed` | `True` | Same world every episode (resets crafter's internal episode counter) |
| `--env.crafter.random_spawn` | `False` | Relocate player to a random walkable tile each episode |
| `--env.crafter.egocentric_view` | `0` | N×N ego-centered view (odd int, e.g. 7); 0 = disabled |
| `--env.crafter.disable_mobs` | `False` | Remove hostile mobs (zombies, skeletons) |
| `--env.crafter.area` | `[64,64]` | World size in tiles (the `crafter_small` preset uses `[32,32]`) |
| `--seed` | `0` | Global seed. Env seed derived as `hash((seed, env_index))` |
Seed mechanics: Crafter generates worlds via hash((env_seed, episode_number)). With fixed_seed=True, the episode counter is reset to 0 before each reset(), so every episode gets the same world. With fixed_seed=False, episode numbers increment naturally (1, 2, 3, ...) giving a different world each episode. Different --seed values produce different environments even with fixed_seed=True.
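The seed mechanics above can be sketched as follows (illustrative pseudologic only; the real logic lives in crafter's world generation and this repo's env wrapper):

```python
# Illustrative sketch of the seeding described above, not the
# actual crafter internals.
def env_seed(global_seed, env_index):
    return hash((global_seed, env_index))

def world_seed(env_seed_value, episode_number, fixed_seed):
    # With fixed_seed=True the episode counter is reset to 0 before
    # every reset(), so each episode regenerates the same world.
    episode = 0 if fixed_seed else episode_number
    return hash((env_seed_value, episode))

s = env_seed(0, 0)
fixed = {world_seed(s, ep, fixed_seed=True) for ep in range(5)}   # 1 world
varied = {world_seed(s, ep, fixed_seed=False) for ep in range(5)} # 5 worlds
```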
Size presets: size1m, size12m, size25m, size50m, size100m, size200m, size400m.
Records agent positions, world model activations (all layers), images, and achievements per timestep. Saves per-episode .pkl files plus all_episodes.pkl with metadata.
```sh
python dreamerv3/main.py \
  --configs crafter_small size25m \
  --logdir ./logdir/my_run \
  --script eval_trajectory \
  --run.from_checkpoint ./logdir/my_run/ckpt/TIMESTAMP_DIR \
  --eval_trajectory.num_episodes 100 \
  --eval_trajectory.save_path ./logdir/my_run/trajectories \
  --seed 42 --jax.platform cpu
```

The checkpoint is auto-detected from `logdir/ckpt/latest` if `--run.from_checkpoint` is omitted. Env settings (`egocentric_view`, `random_spawn`, etc.) are inherited from the training run's saved `config.yaml`.
**GPU eval:** drop `--jax.platform cpu` on GPU nodes for a 50-100× speedup (important for 400M-parameter models).
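Recorded episodes are plain pickles, so they can be inspected directly. A minimal loader sketch (the per-episode schema below is an assumption; check the recorder code for the exact keys):

```python
import pickle
from pathlib import Path

def load_episodes(traj_dir):
    """Load all per-episode pickles from an eval_trajectory output dir."""
    episodes = []
    for path in sorted(Path(traj_dir).glob("*.pkl")):
        if path.name == "all_episodes.pkl":  # metadata file, handled separately
            continue
        with open(path, "rb") as f:
            episodes.append(pickle.load(f))
    return episodes

# Each episode is expected to hold per-timestep data such as positions
# and layer activations; print the keys to see the actual schema:
# episodes = load_episodes("./logdir/my_run/trajectories")
# print(episodes[0].keys())
```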
```sh
MPLBACKEND=Agg python dreamerv3/plot_trajectories.py \
  --data ./logdir/my_run/trajectories \
  --plot all \
  --save ./logdir/my_run/plots
```

Plot types: `trajectories`, `heatmap`, `activation`, `spatial`, `world`, `fullworld`, `animate`, `animate_world`, `worldview`, `world_only`, `all`.
Animation/worldview options: `--egocentric_view N`, `--view_half N`, `--window_tiles N`, `--step_ms N`, `--mp4` (MP4 instead of GIF; requires ffmpeg).
Trains pRNN-style linear classification decoders to predict (x,y) from latent representations using cross-validation.
```sh
MPLBACKEND=Agg python dreamerv3/decode_position.py \
  --data ./logdir/my_run/trajectories \
  --save ./logdir/my_run/decoder_results \
  --n_jobs -1
```

Key options:

- `--repr {deter,stoch,combined,all}` — which representation to decode
- `--save_model` — save fitted decoders (needed for dream decode)
- `--n_jobs N` — parallel workers (`-1` = all CPUs)
- `--device {cpu,cuda}` — torch device for classification
- `--n_iters N` — classification training iterations (default 5000)
- `--patience N` — early-stopping patience (default 500; 0 = disable)
- `--min_bbox N` — filter out episodes with bounding-box area < N tiles² (0 = no filter)
Outputs: `classification_*.png`, `probmap_*.png`, `occupancy_vs_error_*.png`, `decode_results.pkl`.
Scans every recorded layer (encoder, dynamics, policy, value) and produces a per-layer comparison boxplot.
```sh
# Holdout mode (fastest — train on one set, eval on another):
MPLBACKEND=Agg python dreamerv3/decode_position.py \
  --data ./trajectories_train --test_data ./trajectories_test \
  --save ./decoder_results \
  --mode layers --device cuda --n_jobs -1

# CV mode (single trajectory set, 5-fold KFold):
MPLBACKEND=Agg python dreamerv3/decode_position.py \
  --data ./trajectories --save ./decoder_results \
  --mode layers --device cuda --n_jobs -1 --holdout_frac 0

# Save trained decoders for reuse:
MPLBACKEND=Agg python dreamerv3/decode_position.py \
  --data ./trajectories --save ./decoder_results \
  --mode layers --device cuda --n_jobs -1 --save_model

# Eval saved decoders on new data (no training):
MPLBACKEND=Agg python dreamerv3/decode_position.py \
  --data ./new_trajectories --save ./new_results \
  --mode layers --from_model ./decoder_results/layer_decoders
```

Layer-specific options:

- `--max_samples N` — subsample timesteps (default 10000)
- `--n_cv_folds N` — KFold folds (default 5)
- `--holdout_frac F` — automatic train/test split fraction (default 0.2; use 0 for CV mode)
- `--test_data PATH` — separate held-out trajectory dir (overrides `holdout_frac`)
- `--resume PATH` — resume from a partial checkpoint
- `--from_model PATH` — eval-only from a saved `layer_decoders/` dir
Outputs: `layer_comparison.png`, `layer_decode_results.pkl`, `layer_decode_checkpoint.pkl`.
Applies a pretrained classification position decoder to imagined (dream) rollouts to test spatial coherence.
```sh
# Step 1: Save the decoder model (if not done already)
MPLBACKEND=Agg python dreamerv3/decode_position.py \
  --data ./logdir/my_run/trajectories \
  --save ./logdir/my_run/decoder_results \
  --save_model

# Step 2: Run dream decode
python dreamerv3/main.py \
  --configs crafter_small size25m \
  --logdir ./logdir/my_run \
  --script dream_decode \
  --run.from_checkpoint ./logdir/my_run/ckpt/TIMESTAMP_DIR \
  --dream_decode.decoder_model ./logdir/my_run/decoder_results/classifier_deter.pkl \
  --dream_decode.save_path ./logdir/my_run/dream_results
```

Options: `--dream_decode.num_batches N`, `--dream_decode.num_episodes N`.
Outputs: `dream_trajectories_world.png`, `dream_probmap_*.png`, `dream_vs_real.png`, `dream_results.pkl`.
```sh
MPLBACKEND=Agg python dreamerv3/plot_training.py \
  --logdir ./logdir/my_run \
  --save ./logdir/my_run/plots \
  --smooth 50
```

Reads `scores.jsonl` and `metrics.jsonl` from the logdir. Produces `training_progress.png` with panels for episode score, cumulative reward, Crafter score (geometric mean of achievement success rates), and per-achievement unlock rates.
Options: `--no_achievements` (skip the per-achievement panel), `--no_losses` (skip the loss/reward/value panels).
Classifies neurons into spatial cell types (place, border, HD, etc.) using pynapple. Analyzes all recorded layers.
```sh
MPLBACKEND=Agg python dreamerv3/tuning_curve.py \
  --data ./logdir/my_run/trajectories \
  --save ./logdir/my_run/tuning_results \
  --n_jobs -1
```

Options:

- `--layers LAYER [...]` — filter to specific layers (default: all)
- `--test_data PATH` — held-out trajectories for EV reliability
- `--max_neurons N` — subsample large layers (0 = all)
- `--no_hd` — skip head-direction analysis
- `--no_plots` — skip plot generation
- `--interactive` — show an interactive SI vs. EV scatter
- `--min_bbox N` — filter episodes by bounding-box area
- `--SI_thresh`, `--EV_thresh`, `--EV_unthresh`, `--HD_thresh` — classification thresholds
Outputs: `tuning_results.pkl`, `{layer}_si_ev_scatter.png`, `{layer}_cell_types.png`, `{layer}_example_tuning_curves.png`, `layer_summary.png`.
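The SI threshold above refers to spatial information, conventionally the Skaggs et al. measure. A numpy sketch of that formula (the script's actual computation via pynapple may differ in binning and smoothing):

```python
import numpy as np

def spatial_information(rate_map, occupancy):
    """Skaggs spatial information in bits per activation unit:
    SI = sum_i p_i * (r_i / r_mean) * log2(r_i / r_mean),
    where p_i is the occupancy probability of spatial bin i and
    r_i the unit's mean activation in that bin."""
    p = occupancy / occupancy.sum()
    r = rate_map
    r_mean = float((p * r).sum())
    if r_mean == 0:
        return 0.0
    with np.errstate(divide="ignore", invalid="ignore"):
        terms = p * (r / r_mean) * np.log2(r / r_mean)
    return float(np.nansum(terms))  # bins with zero rate contribute 0

# A unit active in exactly 1 of 4 equally visited bins carries 2 bits.
occ = np.ones((2, 2))
rate = np.zeros((2, 2))
rate[0, 0] = 1.0
```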
Interactive viewer (from precomputed results):

```sh
python dreamerv3/tuning_curve.py --from_pkl tuning_results.pkl
python dreamerv3/tuning_curve.py --from_pkl tuning_results.pkl --layers dyn/deter
```

Shell scripts for cluster submission are in the repo root:
| Script | Purpose | GPU | Key settings |
|---|---|---|---|
| `run_Crafter.sh` | Train model | H100 | Parametrized env/size/activation/wd |
| `run_Trajectory.sh` | Eval trajectories | CPU | Matches training hyperparams |
| `run_Loop.sh` | Full pipeline (train → eval → all analysis) | A100 | End-to-end |
| `run_Decoding.sh` | Standard position decoding | GPU | method, repr, device |
| `run_LayerDecoding.sh` | Layer-wise decoding | A100 | mode, ridge/classification, holdout |
| `run_Plotting.sh` | Plot trajectories + training | CPU | plot type, animation settings |
| `run_Tuning.sh` | Tuning curve analysis | A100 | layers, thresholds, n_jobs |
Training scripts save hyperparams.txt to the logdir with all SLURM-level settings.
Every analysis script (decode_position.py, plot_trajectories.py, plot_training.py, tuning_curve.py) appends a JSON entry to <save_dir>/run_info.json via dreamerv3/run_info.py. Each entry contains:
- `stage` — pipeline stage name
- `timestamp` — ISO timestamp
- `git_sha` — short git SHA of HEAD
- `command` — full command line
- `slurm_job_id` — SLURM job ID (if running under SLURM)
- `args` — all parsed arguments
- `outputs` — list of output files produced
- `extra` — additional metadata (varies by stage)
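Since `run_info.json` is plain JSON, provenance can be inspected in a few lines. A sketch assuming the file holds a JSON array of entries with the fields listed above (if it turns out to be line-delimited JSON, parse per line instead):

```python
import json

def print_provenance(path):
    """Print the timestamp, stage, and git SHA of each recorded run.
    Assumes run_info.json is a JSON array of entry dicts."""
    with open(path) as f:
        entries = json.load(f)
    for e in entries:
        print(f"{e['timestamp']}  {e['stage']}  @{e['git_sha']}")

# print_provenance("./logdir/my_run/decoder_results/run_info.json")
```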
Training runs additionally save:
- `config.yaml` — full resolved DreamerV3 config (all hyperparams)
- `hyperparams.txt` — SLURM script-level settings (when using `run_Crafter.sh` or `run_Loop.sh`)
- All config options are in `dreamerv3/configs.yaml`. Override as CLI flags.
- Stack config presets: `--configs crafter_small size50m` (applied left to right).
- The `debug` preset reduces network size and batch size for fast iteration.
- Use `MPLBACKEND=Agg` for headless plotting on cluster nodes.
- Checkpoint path: pass the timestamped directory (e.g. `ckpt/20260129T183613F519148`), not the `latest` file.
- If you get `Too many leaves for PyTreeDef`, the checkpoint doesn't match the current config.
- CUDA errors: scroll up — they are often caused by OOM or a JAX/CUDA version mismatch. Try `--batch_size 1`.
- To resume training, rerun the same command with the same `--logdir`.
```sh
PYTHONPATH=. python embodied/tests/test_crafter_world.py
```

Tests `fixed_seed` determinism, walkable spawn validation, and world consistency.
```bibtex
@article{hafner2025dreamerv3,
  title={Mastering diverse control tasks through world models},
  author={Hafner, Danijar and Pasukonis, Jurgis and Ba, Jimmy and Lillicrap, Timothy},
  journal={Nature},
  pages={1--7},
  year={2025},
  publisher={Nature Publishing Group}
}
```