- Project with Fangchen Liu, Kuan Fang
- Summary: We investigated whether a pretrained high-level vision model could generate useful intermediate subgoals to help robots solve long-horizon manipulation tasks more reliably than end-to-end policies. We tried several architectures (MAE, VQ-VAE, CVAE), but ultimately found GPT-style autoregressive transformers to work best for predicting future states and goals from visual input. A major focus was prompt tuning, which means keeping a pretrained model fixed and adapting it to a new task by learning a small set of extra “prompt” tokens instead of changing all the weights. In our setting, the idea was to let the prompt specify what task the robot should do (e.g., which objects to manipulate and in what order), while the model used its pretrained knowledge to generate appropriate subgoals. In practice, this mostly failed because our pretrained models saw single-stage tasks with explicit goal images, while downstream tasks were multi-stage and required reasoning about order and structure, making prompt-only adaptation too weak. In contrast, full finetuning or encoder-only finetuning worked reliably and produced diverse, semantically meaningful subgoals that transferred across tasks. The main takeaway is that strong pretraining plus flexible finetuning is far more effective than strict prompt tuning for robotic control, and that autoregressive transformers are a strong backbone for this role. Promising future directions include reinforcement learning over latent subgoal variables, language- or task-conditioned subgoal generation, cross-domain transfer from human videos, and tighter integration with real-robot policies.
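The prompt-tuning idea above can be made concrete with a toy sketch (illustrative names, not the project's code): the pretrained weights stay frozen, and a gradient step is taken only on the learned prompt vector prepended to the input.

```python
# Toy prompt tuning: frozen "model" weights, trainable prompt tokens only.
# forward() and prompt_update() are hypothetical helpers for illustration.

def forward(w_frozen, prompt, x):
    """Score = frozen weights applied to the concatenation [prompt ; x]."""
    tokens = list(prompt) + list(x)
    return sum(w * t for w, t in zip(w_frozen, tokens))

def prompt_update(w_frozen, prompt, x, target, lr=0.1):
    """One hand-derived squared-error gradient step on the prompt only."""
    err = forward(w_frozen, prompt, x) - target
    # d(err^2)/d(prompt_i) = 2 * err * w_frozen[i]; frozen weights untouched.
    return [p - lr * 2 * err * w for p, w in zip(prompt, w_frozen)]
```

In the real setting the prompt is a small set of embedding vectors and the frozen model is a transformer, but the optimization structure is the same: only the prompt parameters receive updates.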
- Offline RL baseline
- Trained GC-IQL / GC-BC with scripted subgoals; scripted subgoals significantly improved performance on some tasks but task2 remained very poor.
- Bottlenecks traced to dataset mismatch (OOD objects/positions) and brittle low-level policy.
- Dataset iteration
- Created task-specific datasets (esp. task2), increased dataset size (3k → 15k → 50k).
- Edited task geometry (e.g. spam box position/shape/color).
- VQ-VAE
- Implemented and swept VQ-VAE sizes; pretrained encoders used for GC-IQL.
- VQ-VAE reconstructions reasonable in sim, but concerns about scaling to real data.
- Subgoal evaluation tooling
- Implemented subgoal switching during evaluation and visualization (scripted subgoals overlaid in videos).
Observation: Scripted subgoals help, but learning good subgoals is the real bottleneck, especially for long-horizon tasks.
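The subgoal-switching logic used during evaluation can be sketched as follows (a minimal version, assuming a distance threshold in observation space; names are illustrative):

```python
def maybe_switch_subgoal(obs, subgoals, idx, threshold=0.5):
    """Advance to the next scripted subgoal once the current one is reached.

    obs and subgoals[idx] are same-length feature vectors; switching happens
    when the Euclidean distance to the current subgoal drops below threshold.
    """
    dist = sum((o - g) ** 2 for o, g in zip(obs, subgoals[idx])) ** 0.5
    if dist < threshold and idx < len(subgoals) - 1:
        idx += 1
    return idx
```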
- MAE / VQ-MAE
- Encoded full datasets with VQ; trained MAE/MaskDP-style models.
- Prompt tuning implemented (update only prompt tokens during task-specific finetuning).
- Prompt tuning results
- Works poorly in practice; strong signs of overfitting and pretrain/finetune mismatch.
- New long-horizon tasks (kitchen, cooking, table setting) are multi-stage, unlike pretraining tasks.
- Architecture changes
- Designed faster VQ architecture (3× speedup).
- Eventually ditched VQ-VAE and trained MAE directly with MSE loss:
- (+) Better gripper/object structure
- (−) Blur remains, especially on real data
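The MAE-style MSE objective computes reconstruction error only over masked patches. A minimal sketch (flattened pixel values and a 0/1 patch mask, as an illustration of the loss shape rather than the project's implementation):

```python
def masked_mse(pred, target, mask):
    """MSE over masked patches only, as in MAE-style pixel reconstruction.

    pred/target: flattened patch values; mask: 1 where the patch was masked
    (and so contributes to the loss), 0 where it was visible to the encoder.
    """
    num = sum(m * (p - t) ** 2 for p, t, m in zip(pred, target, mask))
    den = sum(mask)
    return num / max(den, 1)  # guard against an all-visible batch
```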
- Real data
- Began real-world data collection; trained MAE on Bridge data for offline debugging.
Observation:
Prompt tuning is brittle when task specification differs structurally from pretraining (single-stage vs multi-stage). Full finetuning works much better.
- GPT-style autoregressive transformer
- Switched from MAE/BERT-style to GPT-style causal transformer for state/goal prediction.
- Large gains in image quality (less blur, better gripper reconstruction) on sim + real.
- Finetuning vs prompt tuning
- Full finetuning + goal image → works extremely well.
- Removing goal image → performance drops.
- Prompt/prefix tuning alone → consistently poor.
- Prompt tuning analysis
- Tried multi-prompt tokens, auxiliary MSE losses pushing prompts toward subgoals, partial layer finetuning.
- Still fails → attributed to fundamental mismatch between pretraining (single goal) and finetuning (multi-stage ordering).
- Explored alternatives
- Read PARROT, GCPC, DualMind.
- Identified RL on latent z (PARROT-style) as promising future direction.
- MAE refinements
- Tried GAN discriminator, ViT encoder swaps, VQ-VAE variants → limited gains.
- SL: emphasized that good VQ-VAEs are hard; reconstruction alone isn’t enough.
Observation:
GPT-style model is a clear win. Prompt tuning is not essential; use pretraining + full finetuning flexibly.
- Prompt tuning abandoned as core goal
- Extensive hacks failed in hierarchical/state-prediction models.
- SL: focus on effective model usage, not zero-shot prompting.
- New environment
- Began port to LIBERO environment.
- Considered task-ID conditioning instead of goal images.
- Broader inspiration
- Looked at image/video generation (MUSE, diffusion editing) for object movement/editing.
- Model training
- ViT + CVAE goal image predictors:
- β = 0 → good recon, no diversity
- β > 0 → diversity improves, blur increases
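The β trade-off comes directly from the CVAE objective: reconstruction loss plus β times the KL term. A minimal sketch (toy scalar version; `cvae_loss` and `reparameterize` are illustrative names):

```python
import random

def cvae_loss(recon_mse, kl, beta):
    """beta-weighted CVAE objective. beta = 0 drops the KL term entirely,
    giving sharp reconstructions but no sample diversity; beta > 0 keeps the
    latent close to the prior, improving diversity at the cost of blur."""
    return recon_mse + beta * kl

def reparameterize(mu, sigma):
    """Sample z = mu + sigma * eps. With beta = 0 there is no pressure on
    sigma, so the latent tends to collapse and samples lose diversity."""
    return [m + s * random.gauss(0, 1) for m, s in zip(mu, sigma)]
```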
- Multi-step predictor
- Trained on heavily downsampled trajectories (3–4 steps).
- Reconstructions imperfect but samples diverse and semantically meaningful (different tasks in same scene).
- MAE vs GPT decoder: no clear winner yet.
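The heavy downsampling above amounts to keeping a handful of evenly spaced frames per trajectory. A sketch (assuming first and last frames are always kept; the helper name is illustrative):

```python
def downsample(traj, k):
    """Keep k evenly spaced frames from a trajectory, including endpoints."""
    if k >= len(traj):
        return list(traj)
    if k == 1:
        return [traj[0]]
    step = (len(traj) - 1) / (k - 1)
    return [traj[round(i * step)] for i in range(k)]
```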
- Prompt tuning (again)
- Naive prompt tuning (replace encoder with prompt) fails.
- Partial encoder finetuning or alternative task representations suggested.
- Bridge + robot integration
- Coordinating with Homer/Andre to test models on real robot.
- Plan to test whether existing GCBC policy can handle generated subgoals.
- Language conditioning
- Considering adding language during pretraining or finetuning.
- Bridge generalization experiment
- Pretrained on all Bridge tasks except 3 held-out tasks.
- Finetuned encoder only on each held-out task (decoder frozen).
- Results: very strong, comparable to full finetuning.
- KL regularization
- Tested β = pretraining value and β = 0.
- Hypothesis (from Fangchen): small nonzero β may keep z in-distribution without hurting performance.
- Planned experiments
- Data efficiency: number of trajectories needed vs full finetuning.
- Train GCBC from scratch using model-generated goals.
Key Result:
Freezing decoder + encoder finetuning (“pseudo-prompt tuning”) works well, unlike classic prompt tuning.
- Based on Ilya's JaxRL repo
conda create -n jaxrl python=3.10
conda activate jaxrl
pip install -e .
pip install -r requirements.txt
For GPU:
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
For TPU:
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
See the JAX GitHub page for more details on installing JAX.
Check out contributing.md
Sample command:
python3 jaxrl_minimal2/experiments/procgen_cql.py --max_steps=3000000 --n_lev=1000 --eval_interval=10000 "--config.network_kwargs.hidden_dims=(256, 256)" --wandb.exp_descriptor=cql_encoder_{encoder}_{impala_scaling} --config.cql_alpha=0.5 --encoder=impala --impala_scaling=4
WandB flags:
--wandb.project: Which project to log to
--wandb.exp_prefix: Group name
--wandb.exp_descriptor: Experiment identifier
--wandb.entity: In case you want to log to a different team
To see all available flags, run:
python3 experiments/procgen_cql.py --help