Merged
4 changes: 2 additions & 2 deletions .github/workflows/base_tests.yaml
@@ -17,6 +17,6 @@ jobs:
      - name: Set up Python
        run: uv python install
      - name: Install the project
        run: uv sync --locked --all-extras --dev
        run: uv sync --all-extras --dev
      - name: Run tests
        run: uv run pytest tests
6 changes: 3 additions & 3 deletions .github/workflows/linters.yaml
@@ -14,7 +14,7 @@ on:
    paths:
      - 'eb_jepa/**'
      - 'examples/**'
      - 'tests/'

jobs:
  run-linters:
@@ -30,12 +30,12 @@ jobs:
      - name: Set up Python
        run: uv python install
      - name: Install the project
        run: uv sync --locked --all-extras --dev
        run: uv sync --all-extras --dev
      - name: Set lint paths
        run: echo "lint_paths=eb_jepa examples tests" >> "$GITHUB_ENV"
      - name: Run isort
        run: |
          uv run python -m isort $lint_paths --check
      - name: Run black
        if: always()
        run: |
6 changes: 3 additions & 3 deletions .gitignore
@@ -151,6 +151,7 @@ venv/
ENV/
env.bak/
venv.bak/
uv.lock

# Spyder project settings
.spyderproject
@@ -220,10 +221,9 @@ wandb
.streamlit/secrets.toml

# example text multistep ignores
apps/example_text_multistep/data/fineweb10B/
apps/example_text_multistep/logs/
logs/
tests/logs/
checkpoints/
*.pth
*.npy
eb_jepa_ICLR/
CLAUDE.md
19 changes: 5 additions & 14 deletions Makefile
@@ -1,24 +1,15 @@
## paths
.ONESHELL:
.PHONY: help
.DEFAULT_GOAL := help

## print a help msg to display the comments
help:
help: ## Show this help message
	@grep -hE '^[A-Za-z0-9_ \-]*?:.*##.*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

USER := $(shell whoami)
PWD := $(shell pwd)
ROOT := $(shell cd ..; pwd)
run_image_jepa: ## Run the image JEPA example
	uv run python examples/image_jepa/main.py

run_example_video_jepa:
	uv run python apps/example_video_jepa/main.py

run_example_video_jepa_chunked_vc:
run_video_jepa: ## Run the video JEPA example
	uv run python examples/video_jepa/main.py

run_example_ac_video_jepa:
run_ac_video_jepa: ## Run the action-conditioned video JEPA example
	uv run python examples/ac_video_jepa/main.py

run_example_text_multistep:
	uv run torchrun --standalone --nproc_per_node=8 examples/text_multistep/main.py
180 changes: 155 additions & 25 deletions README.md
@@ -1,55 +1,185 @@
# Energy Based Joint Embedding Predictive Architectures
![EB JEPA](docs/teaser.png)

An open source library and tutorial aimed at learning representations for prediction and planning using joint embedding predictive architectures.
Examples include learning image (a), video (b), and action conditioned video (c) predictive model representations, as well as planning with them (d).
![EB JEPA](docs/archi-schema-eb-jepa.png)

Each example is (almost) self-contained and training takes up to a few hours on a single GPU card.
An open source library and tutorial for learning representations for prediction and planning using joint embedding predictive architectures.

### [Image Representations](examples/image_jepa/README.md)
> Each example is (almost) self-contained and training takes up to a few hours on a single GPU card.

This example demonstrates learning self-supervised representations from unlabeled images on CIFAR-10, evaluated on image classification.
![Moving MNIST](examples/image_jepa/assets/arch_figure.png)
---

### [Predictive Video Representations](examples/video_jepa/README.md)
![Moving MNIST](examples/video_jepa/assets/viz.png)
## 📚 Examples

### [Image JEPA](examples/image_jepa/README.md)

Self-supervised representations from unlabeled images on CIFAR-10, evaluated on classification.

![Image JEPA Architecture](examples/image_jepa/assets/arch_figure.png)

A model is trained to predict the next image representation in a sequence
### [Video JEPA](examples/video_jepa/README.md)

Predict next image representation in a sequence.

![Moving MNIST](examples/video_jepa/assets/viz.png)

### [Action Conditioned Prediction and Planning](examples/ac_video_jepa/README.md)
### [AC Video JEPA](examples/ac_video_jepa/README.md)

This example demonstrates a Joint Embedding Predictive Architecture (JEPA) for action-conditioned world modeling in the Two Rooms environment. The model learns to predict future states based on current observations and actions. These representations enable planning towards a goal observation embedding.
JEPA for world modeling + planning in Two Rooms environment.

| Planning Episode | Task Definition |
|------------------|-----------------|
| <img src="examples/ac_video_jepa/assets/top_randw_agent_steps_succ.gif" alt="Successful planning episode" width="155" /> | <img src="examples/ac_video_jepa/assets/top_randw_state.png" alt="Episode task definition" width="300" /> |
| *Successful planning episode* | *Episode task definition: from init to goal state* |
| *Successful planning episode* | *From init to goal state* |

## Installation
---

We use the [uv](https://docs.astral.sh/uv/guides/projects/) package manager to install and maintain packages. Once you have installed uv (e.g. `pip install --upgrade uv`), run the following to create a new virtual environment.
## 🚀 Installation

We use [uv](https://docs.astral.sh/uv/guides/projects/) for package management.

```bash
# Install dependencies
uv sync
# Option 1: Activate virtual environment
source .venv/bin/activate
python main.py
# Option 2: Run directly with uv
uv run python main.py
```
If you need conda-specific packages, you can use **Conda + uv**

This will create a virtual environment within the project folder at `.venv/`. To activate this environment, run `source .venv/bin/activate`.
```bash
# Create conda environment with Python 3.12
conda create -n eb_jepa python=3.12 -y
conda activate eb_jepa
# Install package in editable mode with dev dependencies (pytest, black, isort)
uv pip install -e . --group dev
```

Alternatively, if you don't want to run activate every time, you can just prepend `uv run` to your python scripts:
Add these to your `~/.bashrc` for persistent configuration.

```bash
uv run python main.py
# Required for SLURM jobs to find datasets
export EBJEPA_DSETS=/path/to/eb_jepa/datasets
# Optional: Directory for checkpoints and logs
export EBJEPA_CKPTS=/path/to/checkpoints
```
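If a script needs these paths, it can read them from the environment with a local fallback. This is a hypothetical helper for illustration; the examples may read the variables directly:

```python
import os
from pathlib import Path

# Hypothetical helper: resolve dataset/checkpoint roots from the environment,
# falling back to repo-local defaults when the variables are unset.
def resolve_dirs() -> tuple[Path, Path]:
    dsets = Path(os.environ.get("EBJEPA_DSETS", "datasets"))
    ckpts = Path(os.environ.get("EBJEPA_CKPTS", "checkpoints"))
    return dsets, ckpts

dsets, ckpts = resolve_dirs()
print(dsets, ckpts)
```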



---

## 🏋️ Training

### Quick Start

```bash
# Local training
python -m examples.{image_jepa,video_jepa,ac_video_jepa}.main
```
> Our default configs are tuned for H100 GPUs. With older GPUs (e.g., A100, V100), you may need to reduce batch size to fit in memory.

### 📂 Folder Structure

All experiments use a unified folder structure:

```
checkpoints/
└── {example_name}/
    ├── dev_2026-01-16_00-10/      # Single/local runs (dev_ prefix)
    │   └── {exp_name}_seed1/
    ├── sweep_2026-01-16_00-10/    # Auto-named 3-seed sweep
    │   ├── {exp_name}_seed1/
    │   ├── {exp_name}_seed1000/
    │   └── {exp_name}_seed10000/
    └── sweep_my_experiment/       # Custom-named sweep
        └── ...
```
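The timestamped folder names above follow a `dev_YYYY-MM-DD_HH-MM` pattern. As a small illustrative sketch (the exact format string used by the launcher may differ):

```python
from datetime import datetime

# Sketch: build a timestamped run-folder name with the dev_ prefix used for
# single/local runs (format assumed from the folder listing above).
def dev_run_dir(now: datetime) -> str:
    return f"dev_{now.strftime('%Y-%m-%d_%H-%M')}"

print(dev_run_dir(datetime(2026, 1, 16, 0, 10)))  # → dev_2026-01-16_00-10
```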

## Running test cases
`{exp_name}` encodes key hyperparameters to avoid folder collisions, e.g.:
- **image_jepa**: `resnet_vicreg_proj_bs256_ep300_ph2048_po2048_std1.0_cov80.0`
- **video_jepa**: `resnet_bs64_lr0.001_std10.0_cov100.0`
- **ac_video_jepa**: `impala_cov8_std16_simt12_idm1`
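A minimal sketch of how such a name could be assembled from hyperparameters (`make_exp_name` is illustrative, not the repo's actual function):

```python
# Illustrative only: join an architecture tag and key hyperparameters into a
# collision-resistant experiment name like the examples listed above.
def make_exp_name(arch: str, **hparams) -> str:
    return "_".join([arch] + [f"{k}{v}" for k, v in hparams.items()])

print(make_exp_name("resnet", bs=64, lr=0.001, std=10.0, cov=100.0))
# → resnet_bs64_lr0.001_std10.0_cov100.0
```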

Libraries added to eb-jepa [must have their own test cases](/tests/). To run the tests: `uv run pytest tests/`
<details>
<summary><span style="font-size: 1.17em; font-weight: bold;">🖥️ SLURM Launcher (optional)</span></summary>

| Command | Description |
|---------|-------------|
| `--example {name}` | Choose: `image_jepa`, `video_jepa`, `ac_video_jepa` |
| `--fname {path}` | Run the sweep specified in the config at `{path}` |
| `--single` | Launch single job (dev mode) |
| `--sweep {name}` | Custom sweep name |
| `--array-parallelism {N}` | Limits the maximum number of concurrent jobs to `N` |
| `--full-sweep` | Full hyperparameter sweep from config |
| `--use-wandb-sweep` | Enable wandb sweep UI |

```bash
# 3 seeds with wandb averaging (recommended)
python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml

# Custom sweep name
python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml --sweep my_experiment

# Single job
python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml --single

# Full hyperparameter sweep
python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml --full-sweep

# With wandb sweep UI for hyperparameter analysis
python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml --use-wandb-sweep
```

Replace `image_jepa` with `ac_video_jepa` or `video_jepa` for other examples.

**Full Sweep Configuration:** The `--full-sweep` flag reads the `sweep.param_grid` section from the example's YAML config file (e.g., `examples/image_jepa/cfgs/default.yaml`). Without this flag, only a 3-seed sweep is launched. To customize sweep parameters, edit the `sweep` section in the config:

```yaml
# Example: examples/image_jepa/cfgs/default.yaml
sweep:
  param_grid:
    loss.cov_coeff: [0.1, 1.0, 10.0, 100.0]
    loss.std_coeff: [1.0, 10.0]
    meta.seed: [1, 1000, 10000]
```
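For intuition (this is not the launcher's actual code, which lives in `examples/launch_sbatch.py`), a `param_grid` like the one above expands into the cross product of all listed values:

```python
from itertools import product

# The grid from the config above, written as a plain dict.
param_grid = {
    "loss.cov_coeff": [0.1, 1.0, 10.0, 100.0],
    "loss.std_coeff": [1.0, 10.0],
    "meta.seed": [1, 1000, 10000],
}

# One override dict per job: 4 * 2 * 3 = 24 configurations.
keys = list(param_grid)
runs = [dict(zip(keys, vals)) for vals in product(*param_grid.values())]
print(len(runs))  # → 24
```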

### Wandb Seed Averaging

Runs with the same hyperparameters but different seeds share the same wandb run name, enabling automatic averaging:

1. Go to the wandb web UI → Runs table
2. Click **"Group by"** → select **"Name"**: runs with identical hyperparameters (different seeds) are grouped together
3. To filter runs from a specific sweep, click **"Filter"** → **"Group"** → select your sweep name

For detailed wandb sweep analysis (parallel coordinates, hyperparameter importance):
1. Use `--use-wandb-sweep` flag when launching
2. Go to wandb web UI → left pane → **"Sweeps"** → click your sweep name

**SLURM Configuration:** To customize SLURM parameters (partition, account, memory, etc.), edit the `SLURM_DEFAULTS` dictionary at the top of `examples/launch_sbatch.py`.

</details>

## 🧪 Running test cases

Libraries added to eb_jepa [must have their own test cases](/tests/). To run the tests:

```bash
# With uv sync installation
uv run pytest tests/
# With conda + uv installation (no .venv created)
pytest tests/
```

## Development
## 👩‍💻 Development

- The uv package comes with `black` and `isort`, which must be run before adding any file in this repo. The continuous integration will check the linting of the PRs and new files.
- Every PR should be reviewed by folks tagged at [CODEOWNERS](docs/CODEOWNERS).
- The dev dependencies include `black` and `isort`, which must be run before contributing to this repo.

## 📄 License

## License
EB JEPA is Apache licensed, as found in the [LICENSE](LICENSE.md) file.
EB-JEPA is Apache licensed. See [LICENSE](LICENSE.md).
Empty file removed __init__.py
Empty file.
Binary file added docs/archi-schema-eb-jepa.png
3 changes: 0 additions & 3 deletions eb_jepa/__init__.py

This file was deleted.
