Algorithms
Please see the paper main text and appendix for a detailed description of the FloWM algorithm. `algorithms/flowm_video.py` contains the main implementation of the Flow Equivariant World Model algorithm for video data, including model initialization and PyTorch Lightning function overrides. The model code for the 2D model is stored in `algorithms/mem_wm/backbones/flowm/flowm_models_2d.py`, and the model code for the 3D model in `algorithms/mem_wm/backbones/flowm/flowm_models_3d.py`.
Please see the paper appendix for more information on the baseline algorithms. `algorithms/dfot_video.py` is the shared PyTorch Lightning algorithm file for both DFoT and DFoT-SSM. The model code for DFoT is based on CogVideoX's backbone (official code here), combined with the official code of DFoT here. Our implementation is stored in `algorithms/mem_wm/backbones/cogv/cogv_transformers.py`, and the diffusion implementation can be found at `algorithms/mem_wm/diffusion/gaussian_diffusion.py`.
The model code for DFoT-SSM is based on the official implementation shared by the author. Our implementation is stored in `algorithms/mem_wm/backbones/cogv/block_ssm.py`.
To add a new algorithm or model, see the examples in the PyTorch Lightning file `flowm_video.py` and the PyTorch model file `flowm_models_3d.py`. A new algorithm will need its own configurations. The model code is standard PyTorch code and is initialized by the algorithm code; the standard practice is to pass the model's config options through the algorithm branch of the hydra config.
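The pattern above, where the algorithm owns model construction and the model's options ride along on the algorithm branch of the config, can be sketched as follows. All class and field names here are hypothetical, chosen only to illustrate the shape of the real classes in `flowm_video.py` and `flowm_models_3d.py`; plain dataclasses stand in for the composed hydra config.

```python
from dataclasses import dataclass, field

# Hypothetical config classes mirroring a hydra algorithm branch;
# the model options nest under the algorithm config.
@dataclass
class ModelConfig:
    hidden_dim: int = 64
    num_layers: int = 4

@dataclass
class AlgorithmConfig:
    lr: float = 1e-4
    model: ModelConfig = field(default_factory=ModelConfig)

class MyWorldModel:
    """Stand-in for a standard PyTorch nn.Module (e.g. flowm_models_3d.py)."""
    def __init__(self, hidden_dim: int, num_layers: int):
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

class MyVideoAlgorithm:
    """Stand-in for a pytorch_lightning.LightningModule (e.g. flowm_video.py)."""
    def __init__(self, cfg: AlgorithmConfig):
        self.cfg = cfg
        # The algorithm code initializes the model from its own
        # config branch, so no separate model config file is needed.
        self.model = MyWorldModel(cfg.model.hidden_dim, cfg.model.num_layers)

algo = MyVideoAlgorithm(AlgorithmConfig(model=ModelConfig(hidden_dim=128)))
```

Keeping model hyperparameters under the algorithm branch means a single `algorithm=...` override on the command line selects both the training logic and a consistent model configuration.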
The DFoT and DFoT-SSM code uses latent diffusion. For each dataset, we have trained a separate VAE to encode a latent space for that data distribution.
Diffusion works better with normalized data, so we compute the latent stats and normalize before passing the latents to the diffusion model. The latent stats are based on the output of the trained VAE, so they must be calculated and set separately for each dataset. This can be done by invoking the `parallel_estimate_latent_stats.py` script, for example:
```shell
python -m algorithms.vae.parallel_estimate_latent_stats dataset=blockworld dataset.latent.type=online algorithm=train_mae_vae experiment=video_latent_learning ckpt_map=default
```

To run this, make sure to set the algorithm properly and set the `vae` field in `train_mae_vae.yaml` (i.e. set the correct vae config for the one you just trained). Be sure to also set `pretrained_path` properly. Once you are done, update the corresponding `latent_mean` and `latent_std` fields in the `configurations/algorithm/vae` configuration.
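The normalization these stats feed can be sketched as below, under the assumption that the stats are a per-channel mean and standard deviation computed over VAE latents; the function names are illustrative and not the script's actual API.

```python
import numpy as np

def estimate_latent_stats(latents):
    # latents: (N, C, H, W) array of VAE outputs for one dataset.
    # Reduce over everything except the channel axis.
    mean = latents.mean(axis=(0, 2, 3))
    std = latents.std(axis=(0, 2, 3))
    return mean, std

def normalize_latents(latents, mean, std, eps=1e-6):
    # Applied before the latents are fed to the diffusion model,
    # using the per-dataset latent_mean / latent_std from the config.
    return (latents - mean[None, :, None, None]) / (std[None, :, None, None] + eps)

rng = np.random.default_rng(0)
z = rng.normal(loc=3.0, scale=2.0, size=(32, 4, 8, 8))
mean, std = estimate_latent_stats(z)
z_norm = normalize_latents(z, mean, std)
# After normalization each channel is roughly zero-mean, unit-variance.
```

Because each dataset gets its own trained VAE, these stats must be re-estimated whenever the VAE or dataset changes.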
Training the VAE uses the same command structure. Make sure to implement the dataset loader first, then create a config for your VAE under `configurations/algorithm/vae`, then fill in the remaining config options in the `pretrain_vae_...` yaml file. For example:
```shell
python -m main shortcode=exp/blockworld/pretrain_vae_blockworld +name=imagevae_train_blockworld dataset=blockworld algorithm=train_mae_vae experiment=video_latent_learning ckpt_map=default
```

The default is to process the data online with the VAE: data goes through the VAE at load time to produce latents, which are then fed to the diffusion model during training. Depending on the server setup, training speed can be improved significantly by precomputing the latents for the dataset. For example:
```shell
python -m main shortcode=preprocess/preprocess_blockworld +name=preprocess_latent_blockworld algorithm=mae_vit_vae_preprocessor experiment=video_latent_preprocessing dataset=blockworld ckpt_map=default
```

This will output the latents to a new folder based on the `latent.suffix` specified in the dataset yaml. Then, to use them later, set `latent.type=pre_sample` during training.
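The two `latent.type` modes can be sketched as a single loader with two branches; the class and field names below are hypothetical and do not match the repository's actual dataset API.

```python
import numpy as np
from pathlib import Path

class LatentSource:
    """Illustrative sketch of online vs. precomputed latent loading."""

    def __init__(self, latent_type, vae=None, latent_dir=None):
        assert latent_type in ("online", "pre_sample")
        self.latent_type = latent_type
        self.vae = vae                # encoder, used when latent_type == "online"
        self.latent_dir = latent_dir  # folder named by latent.suffix, for "pre_sample"

    def get(self, index, frames=None):
        if self.latent_type == "online":
            # Encode raw frames through the VAE at load time.
            return self.vae(frames)
        # Precomputed path: read the cached latent straight from disk,
        # skipping the VAE forward pass entirely.
        return np.load(Path(self.latent_dir) / f"{index}.npy")
```

Precomputing trades disk space for throughput: the VAE forward pass is removed from the training loop, which helps most when the data loader, not the GPU, is the bottleneck.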