jepax is a JAX/Equinox implementation of Joint-Embedding Predictive Architecture (JEPA) models and related self-supervised learning methods. The focus is on straightforward implementations that allow quick experimentation with new regularizers, losses, and downstream tasks. We are actively building this library: PRs welcome!
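For example, the JEPA objective itself is only a few lines, so trying a different distance or adding a regularizer is a local change. Below is a minimal, hypothetical sketch of an I-JEPA-style prediction loss in plain JAX; the function name, shapes, and choice of L2 distance are illustrative assumptions, not jepax's exact implementation.

```python
import jax
import jax.numpy as jnp

def jepa_prediction_loss(predicted, target):
    """Illustrative JEPA-style loss (not jepax's actual code).

    predicted: (num_targets, num_patches, dim) predictor outputs for the
               masked target blocks
    target:    (num_targets, num_patches, dim) embeddings of the same blocks
               from the EMA target encoder
    """
    # The target encoder is updated by EMA, not by this loss, so its
    # outputs are treated as constants.
    target = jax.lax.stop_gradient(target)
    return jnp.mean((predicted - target) ** 2)
```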
We aimed for 1-to-1 parity with the original PyTorch implementation in configs, losses, and logging. Below is a reproduction of IJEPA-B trained with data parallelism on 8xA100.
Figure: Training loss and linear probe accuracy for IJEPA-B trained for 300 epochs on 8xA100 80GB.
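For context, data parallelism in JAX amounts to replicating the parameters, sharding the global batch across devices, and averaging gradients with a collective. The sketch below (with a placeholder loss) illustrates the pattern with `jax.pmap`; it is an assumption-laden illustration, not jepax's actual training loop.

```python
from functools import partial
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3  # placeholder value

def loss_fn(params, batch):
    # Placeholder objective standing in for the context-encoder/predictor
    # forward pass and the JEPA loss.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

@partial(jax.pmap, axis_name="data")
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients (and loss, for logging) across the 8 devices.
    grads = jax.lax.pmean(grads, axis_name="data")
    loss = jax.lax.pmean(loss, axis_name="data")
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss
```

Parameters are replicated so they carry a leading device axis, and each device sees its own slice of the global batch, i.e. arrays of shape (num_devices, per_device_batch, ...).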
git clone https://github.com/sugolov/jepax.git
cd jepax
pip install -e .

A straightforward way to download ImageNet-1k is with HuggingFace. The script below caches it in a target directory for your dataloader.
export HF_TOKEN=hf_xxxxxxxxxxxxx
python -m jepax.data.download_imagenet --data_dir ~/your/data/dir
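Under the hood this amounts to a HuggingFace `datasets` call; a rough, hypothetical equivalent is sketched below. The dataset id and flags may differ slightly from the actual script, and ImageNet-1k is gated, so the HF_TOKEN above must belong to an account that has accepted the license (depending on your `datasets` version, additional flags such as `trust_remote_code=True` may be needed).

```python
import os
from datasets import load_dataset

# Illustrative sketch of the download step: HF_TOKEN is read from the
# environment, and cache_dir is the directory your dataloader points at.
data_dir = os.path.expanduser("~/your/data/dir")
ds = load_dataset("ILSVRC/imagenet-1k", cache_dir=data_dir)
print(ds)  # DatasetDict with train / validation / test splits
```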
Launching training for any of the IJEPA-B/L/H model sizes in configs/:

python -m jepax.train.train_ijepa \
--config configs/ijepa_b.yaml \
--data_dir ~/your/data/dir \
--save_dir ~/your/save/dir \
--save_interval 10 \
--exp_name ijepa-b \
--use_wandb \
--wandb_project ijepa
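The linear-probe accuracy reported in the figure above comes from fitting a linear classifier on frozen encoder features. Below is a minimal optax sketch of that kind of probe; the feature dimension, pooling, and function names are assumptions rather than jepax's evaluation code.

```python
import jax
import jax.numpy as jnp
import optax

def init_probe(rng, feat_dim=768, num_classes=1000):
    # feat_dim matches a ViT-B embedding size; adjust for L/H models.
    w = 0.01 * jax.random.normal(rng, (feat_dim, num_classes))
    b = jnp.zeros((num_classes,))
    return {"w": w, "b": b}

def probe_loss(probe, feats, labels):
    # feats: (batch, feat_dim) frozen encoder features, e.g. mean-pooled
    # patch embeddings; labels: (batch,) integer class ids.
    logits = feats @ probe["w"] + probe["b"]
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

optimizer = optax.adamw(learning_rate=1e-3)

@jax.jit
def probe_step(probe, opt_state, feats, labels):
    loss, grads = jax.value_and_grad(probe_loss)(probe, feats, labels)
    updates, opt_state = optimizer.update(grads, opt_state, probe)
    return optax.apply_updates(probe, updates), opt_state, loss
```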
Roadmap:
- Reproduce ImageNet results from IJEPA
- RCDM Visualization
- Model sharding across GPUs
- Benchmarking mixed precision training (MPX)
- LeJEPA
- V-JEPA
- Update tests
- Pre-trained model weights
- Benchmarks against PyTorch implementation
References:
- Awesome JEPA (list of JEPA papers/code)
- I-JEPA (paper) (repo)
- DINOv2 (paper) (repo)
- V-JEPA (paper) (repo)
Other JAX libraries: Awesome JAX.
