A powerful self-supervised learning framework for Joint-Embedding Predictive Architecture (JEPA)
Installation • Quick Start • Documentation • Examples • Contributing
JEPA (Joint-Embedding Predictive Architecture) is a cutting-edge self-supervised learning framework that learns rich representations by predicting parts of the input from other parts. This implementation provides a flexible, production-ready framework for training JEPA models across multiple modalities.
Modular Design
- Flexible encoder-predictor architecture
- Support for any PyTorch model as encoder/predictor
- Easy to extend and customize for your specific needs
Multi-Modal Support
- Computer Vision: Images, videos, medical imaging
- Natural Language Processing: Text, documents, code
- Time Series: Sequential data, forecasting, anomaly detection
- Audio: Speech, music, environmental sounds
- Multimodal: Vision-language, audio-visual learning
High Performance
- Mixed precision training (FP16/BF16)
- Native DistributedDataParallel (DDP) support
- Memory-efficient implementations
- Optimized for both research and production
Comprehensive Logging
- Weights & Biases integration
- TensorBoard support
- Console logging with rich formatting
- Multi-backend logging system
Production Ready
- CLI interface for easy deployment
- Flexible YAML configuration system
- Comprehensive testing suite
- Docker support and containerization
- Type hints throughout
JEPA follows a simple yet powerful architecture:
graph LR
A[Input Data] --> B[Context/Target Split]
B --> C[Encoder]
C --> D[Joint Embedding Space]
D --> E[Predictor]
E --> F[Target Prediction]
F --> G[Loss Computation]
The model learns by:
- Splitting input into context and target regions
- Encoding both context and target separately
- Predicting target embeddings from context embeddings
- Learning representations that capture meaningful relationships (see the minimal sketch below)
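To make these steps concrete, here is a minimal, self-contained sketch of the idea in plain PyTorch. It is illustrative only: the encoders, the context/target masking strategy, and the exact loss used by this framework may differ.

```python
import torch
import torch.nn as nn

# Stand-in encoder and predictor; in the framework these can be any nn.Module
encoder = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 256))
predictor = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))

x = torch.randn(8, 16, 128)              # a batch of sequences (batch, time, features)
context, target = x[:, :12], x[:, 12:]   # 1. split the input into context and target regions

z_context = encoder(context)             # 2. encode the context
with torch.no_grad():
    z_target = encoder(target)           #    encode the target (often via a frozen or EMA copy)

z_pred = predictor(z_context.mean(dim=1))                     # 3. predict the target embedding from the context
loss = nn.functional.mse_loss(z_pred, z_target.mean(dim=1))   # 4. regress embeddings, not raw inputs
loss.backward()
```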
# Install from PyPI
pip install jepa

# Or install from source
git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e .

# Development install (dev and docs extras)
git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e ".[dev,docs]"

# Docker
docker pull dipsivenkatesh/jepa:latest
docker run -it dipsivenkatesh/jepa:latest

import torch
from torch.utils.data import DataLoader, TensorDataset
from jepa.models import JEPA
from jepa.models.encoder import Encoder
from jepa.models.predictor import Predictor
from jepa.trainer import create_trainer
# Toy dataset of (state_t, state_t1) pairs
dataset = TensorDataset(torch.randn(256, 16, 128), torch.randn(256, 16, 128))
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
# Build model components
encoder = Encoder(hidden_dim=128)
predictor = Predictor(hidden_dim=128)
model = JEPA(encoder=encoder, predictor=predictor)
# Trainer with sensible defaults
trainer = create_trainer(model, learning_rate=3e-4, device="auto")
# Train for a couple of epochs
trainer.train(train_loader, num_epochs=2)
# Optional: stream metrics to Weights & Biases
trainer_wandb = create_trainer(
model,
learning_rate=3e-4,
device="auto",
logger="wandb",
logger_project="jepa-experiments",
logger_run_name="quickstart-run",
)
# Persist weights for downstream inference
model.save_pretrained("artifacts/jepa-small")
# Reload using the same model class
reloaded = JEPA.from_pretrained("artifacts/jepa-small", encoder=encoder, predictor=predictor)

Launch multi-GPU jobs with PyTorch's launcher:
torchrun --nproc_per_node=4 scripts/train.py --config config/default_config.yaml

Inside your training script, enable DDP when you create the trainer:
import os

trainer = create_trainer(
model,
distributed=True,
world_size=int(os.environ["WORLD_SIZE"]),
local_rank=int(os.environ.get("LOCAL_RANK", 0)),
)

The trainer wraps the model in DistributedDataParallel, synchronizes losses, and restricts logging/checkpointing to rank zero automatically.
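One practical note for the torchrun setup above: if your script builds its own DataLoader (rather than relying on the trainer to construct one), each rank should read a distinct shard of the data. A minimal sketch with PyTorch's DistributedSampler, assuming the process group has already been initialized by the time the loader is used:

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy (state_t, state_t1) pairs, as in the quick start above
dataset = TensorDataset(torch.randn(256, 16, 128), torch.randn(256, 16, 128))

# DistributedSampler reads world size and rank from the initialized process group,
# so each rank launched by torchrun draws a disjoint shard of the dataset.
sampler = DistributedSampler(dataset, shuffle=True)
train_loader = DataLoader(dataset, batch_size=8, sampler=sampler)

# If you drive the epoch loop yourself, call sampler.set_epoch(epoch) at the start
# of every epoch so the shuffle order changes across epochs and ranks.
```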
Use JEPAAction when actions influence the next state. Provide a state encoder, an action encoder, and a predictor that consumes the concatenated [z_t, a_t] embedding.
from jepa import JEPAAction
import torch.nn as nn
state_dim = 512
action_dim = 64
# Example encoders (replace with your own)
state_encoder = nn.Sequential(
nn.Flatten(),
nn.Linear(784, state_dim),
)
action_encoder = nn.Sequential(
nn.Linear(10, 128), nn.ReLU(),
nn.Linear(128, action_dim),
)
# Predictor takes [state_dim + action_dim] → state_dim
predictor = nn.Sequential(
nn.Linear(state_dim + action_dim, 512), nn.ReLU(),
nn.Linear(512, state_dim),
)
model = JEPAAction(state_encoder, action_encoder, predictor)
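As a quick sanity check, the pieces above can be exercised by hand to see the concatenated [z_t, a_t] input that the predictor consumes (dummy shapes chosen to match the toy encoders above; during training, JEPAAction performs this wiring for you):

```python
import torch

# Dummy batch: states that flatten to 784 features, 10-dimensional actions
state_t = torch.randn(8, 1, 28, 28)
action_t = torch.randn(8, 10)

z_t = state_encoder(state_t)     # (8, state_dim) state embedding
a_t = action_encoder(action_t)   # (8, action_dim) action embedding

# Predicted next-state embedding from the concatenated [z_t, a_t]
z_next_pred = predictor(torch.cat([z_t, a_t], dim=-1))  # (8, state_dim)
```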
# Train a model
jepa-train --config config/default_config.yaml
# Train with custom parameters
jepa-train --config config/vision_config.yaml \
--batch-size 64 \
--learning-rate 0.001 \
--num-epochs 100
# Evaluate a trained model
jepa-evaluate --config config/default_config.yaml \
--checkpoint checkpoints/best_model.pth
# Generate a configuration template
jepa-train --generate-config my_config.yaml
# Get help
jepa-train --help

JEPA uses YAML configuration files for easy experiment management:
# config/my_experiment.yaml
model:
  encoder_type: "transformer"
  encoder_dim: 768
  predictor_type: "mlp"
  predictor_hidden_dim: 2048
training:
  batch_size: 32
  learning_rate: 0.0001
  num_epochs: 100
  warmup_epochs: 10
data:
  train_data_path: "data/train"
  val_data_path: "data/val"
  sequence_length: 16
logging:
  wandb:
    enabled: true
    project: "jepa-experiments"
  tensorboard:
    enabled: true
    log_dir: "./tb_logs"

- Image Classification: Pre-train backbones for downstream tasks
- Object Detection: Learn robust visual representations
- Medical Imaging: Analyze medical scans and imagery
- Satellite Imagery: Process large-scale geographic data
- Language Models: Pre-train transformer architectures
- Document Understanding: Learn document-level representations
- Code Analysis: Understand code structure and semantics
- Cross-lingual Learning: Build multilingual representations
- Forecasting: Pre-train models for prediction tasks
- Anomaly Detection: Learn normal patterns in sequential data
- Financial Modeling: Analyze market trends and patterns
- IoT Sensors: Process sensor data streams
- Vision-Language: Combine images and text understanding
- Audio-Visual: Learn from synchronized audio and video
- Cross-Modal Retrieval: Search across different modalities
- Embodied AI: Integrate multiple sensor modalities
from jepa import JEPADataset, JEPATrainer, load_config
import torch.nn as nn
# Custom vision encoder
class VisionEncoder(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=512):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_dim, 64, 7, 2, 3),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, hidden_dim)
        )

    def forward(self, x):
        return self.conv_layers(x)

# Load config and customize
config = load_config("config/vision_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=VisionEncoder)
trainer.train()

from transformers import AutoModel
from jepa import JEPA, JEPATrainer, load_config
import torch.nn as nn

# Use a pre-trained transformer as the encoder
class TransformerEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state.mean(dim=1)  # Pool over the sequence dimension

config = load_config("config/nlp_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=TransformerEncoder)
trainer.train()

from jepa import create_dataset, JEPATrainer, load_config
# Create time series dataset
dataset = create_dataset(
data_path="data/timeseries.csv",
sequence_length=50,
prediction_length=10,
features=['sensor1', 'sensor2', 'sensor3']
)
config = load_config("config/timeseries_config.yaml")
trainer = JEPATrainer(config=config, train_dataset=dataset)
trainer.train()

- Full Documentation - Complete API reference and guides
- Installation Guide - Detailed installation instructions
- Configuration Guide - How to configure your experiments
- Training Guide - Training best practices
- API Reference - Complete API documentation
git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install in development mode
pip install -e ".[dev,docs]"
# Install pre-commit hooks
pre-commit install

# Run all tests
pytest
# Run with coverage
pytest --cov=jepa --cov-report=html
# Run specific test
pytest tests/test_model.py::test_jepa_forward

# Format code
black jepa/
isort jepa/
# Type checking
mypy jepa/
# Linting
flake8 jepa/

We welcome contributions! Please see our Contributing Guide for details.
- Bug Reports: Submit detailed bug reports
- Feature Requests: Suggest new features or improvements
- Documentation: Improve documentation and examples
- Code: Submit pull requests with bug fixes or new features
- Use Cases: Share your JEPA applications and results
- Fork the repository
- Create a feature branch: git checkout -b feature-name
- Make your changes and add tests
- Ensure all tests pass: pytest
- Submit a pull request
If you use JEPA in your research, please cite:
@software{jepa2025,
title = {JEPA: Joint-Embedding Predictive Architecture Framework},
author = {Venkatesh, Dilip},
year = {2025},
url = {https://github.com/dipsivenkatesh/jepa},
version = {0.1.0}
}

This project is licensed under the MIT License - see the LICENSE file for details.
- Inspired by the original JEPA paper and Meta's research
- Built with PyTorch, Transformers, and other amazing open-source libraries
- Thanks to all contributors and users of the framework
- GitHub Issues: Report bugs or request features
- Documentation: Read the full documentation
- Discussions: Join community discussions
Steps to publish the latest version:
rm -rf dist build *.egg-info
python -m build
twine upload dist/*
Star this repo | Read the docs | Report issues
Built with ❤️ by Dilip Venkatesh