JEPA Framework

PyPI version · Python 3.8+ · License: MIT · Documentation

A powerful self-supervised learning framework for Joint-Embedding Predictive Architecture (JEPA)

Installation • Quick Start • Documentation • Examples • Contributing

🚀 Overview

JEPA (Joint-Embedding Predictive Architecture) is a cutting-edge self-supervised learning framework that learns rich representations by predicting parts of the input from other parts. This implementation provides a flexible, production-ready framework for training JEPA models across multiple modalities.

Key Features

🔧 Modular Design

  • Flexible encoder-predictor architecture
  • Support for any PyTorch model as encoder/predictor
  • Easy to extend and customize for your specific needs

🌍 Multi-Modal Support

  • Computer Vision: Images, videos, medical imaging
  • Natural Language Processing: Text, documents, code
  • Time Series: Sequential data, forecasting, anomaly detection
  • Audio: Speech, music, environmental sounds
  • Multimodal: Vision-language, audio-visual learning

⚡ High Performance

  • Mixed precision training (FP16/BF16; see the torch.amp sketch after this list)
  • Native DistributedDataParallel (DDP) support
  • Memory-efficient implementations
  • Optimized for both research and production
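
As a point of reference, FP16/BF16 training builds on the standard torch.amp pattern sketched below. This is plain PyTorch, not this framework's internal code; the trainer is meant to handle it for you.

import torch
import torch.nn.functional as F

# Requires a CUDA device; illustrative toy model and data.
model = torch.nn.Linear(128, 128).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()  # loss scaling for FP16; BF16 usually needs none

x = torch.randn(8, 128, device="cuda")
y = torch.randn(8, 128, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = F.mse_loss(model(x), y)    # forward pass runs in reduced precision

scaler.scale(loss).backward()         # scale loss to avoid FP16 gradient underflow
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()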

📊 Comprehensive Logging

  • Weights & Biases integration
  • TensorBoard support
  • Console logging with rich formatting
  • Multi-backend logging system

πŸŽ›οΈ Production Ready

  • CLI interface for easy deployment
  • Flexible YAML configuration system
  • Comprehensive testing suite
  • Docker support and containerization
  • Type hints throughout

πŸ—οΈ Architecture

JEPA follows a simple yet powerful architecture:

graph LR
    A[Input Data] --> B[Context/Target Split]
    B --> C[Encoder]
    C --> D[Joint Embedding Space]
    D --> E[Predictor]
    E --> F[Target Prediction]
    F --> G[Loss Computation]

The model learns by:

  1. Splitting input into context and target regions
  2. Encoding both context and target separately
  3. Predicting target embeddings from context embeddings
  4. Learning representations that capture meaningful relationships
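
In plain PyTorch, those steps boil down to the minimal sketch below. It is illustrative only: toy linear modules stand in for real encoders and predictors, and the stop-gradient mimics the EMA/stop-grad target branch used by many JEPA variants to prevent collapse.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins; any nn.Module pair works in practice.
encoder = nn.Linear(128, 64)
predictor = nn.Linear(64, 64)

context_x = torch.randn(8, 128)      # step 1: context region of the input
target_x = torch.randn(8, 128)       # step 1: target region of the input

z_context = encoder(context_x)       # step 2: encode context
with torch.no_grad():                # stop-gradient on the target branch
    z_target = encoder(target_x)     # step 2: encode target
z_pred = predictor(z_context)        # step 3: predict target embedding from context
loss = F.mse_loss(z_pred, z_target)  # step 4: loss computed in embedding space
loss.backward()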

📦 Installation

From PyPI (Recommended)

pip install jepa

From Source

git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e .

Development Installation

git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e ".[dev,docs]"

Docker

docker pull dipsivenkatesh/jepa:latest
docker run -it dipsivenkatesh/jepa:latest

🚀 Quick Start

Python API

import torch
from torch.utils.data import DataLoader, TensorDataset

from jepa.models import JEPA
from jepa.models.encoder import Encoder
from jepa.models.predictor import Predictor
from jepa.trainer import create_trainer

# Toy dataset of (state_t, state_t1) pairs
dataset = TensorDataset(torch.randn(256, 16, 128), torch.randn(256, 16, 128))
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Build model components
encoder = Encoder(hidden_dim=128)
predictor = Predictor(hidden_dim=128)
model = JEPA(encoder=encoder, predictor=predictor)

# Trainer with sensible defaults
trainer = create_trainer(model, learning_rate=3e-4, device="auto")

# Train for a couple of epochs
trainer.train(train_loader, num_epochs=2)

# Optional: stream metrics to Weights & Biases
trainer_wandb = create_trainer(
    model,
    learning_rate=3e-4,
    device="auto",
    logger="wandb",
    logger_project="jepa-experiments",
    logger_run_name="quickstart-run",
)

# Persist weights for downstream inference
model.save_pretrained("artifacts/jepa-small")

# Reload using the same model class
reloaded = JEPA.from_pretrained("artifacts/jepa-small", encoder=encoder, predictor=predictor)

Distributed Training (DDP)

Launch multi-GPU jobs with PyTorch's launcher:

torchrun --nproc_per_node=4 scripts/train.py --config config/default_config.yaml

Inside your training script, enable DDP when you create the trainer:

import os

trainer = create_trainer(
    model,
    distributed=True,
    world_size=int(os.environ["WORLD_SIZE"]),
    local_rank=int(os.environ.get("LOCAL_RANK", 0)),
)

The trainer wraps the model in DistributedDataParallel, synchronizes losses, and restricts logging/checkpointing to rank zero automatically.
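
Conceptually, that is equivalent to the hand-rolled pattern below: a sketch of what the trainer automates, not the library's actual code. It assumes one process per GPU launched via torchrun.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")           # torchrun sets rank/world-size env vars
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(128, 128).cuda(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])   # gradients sync across ranks

if dist.get_rank() == 0:
    # logging/checkpointing happen on rank zero only
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(ddp_model.module.state_dict(), "checkpoints/latest.pth")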

Action-Conditioned Variant

Use JEPAAction when actions influence the next state. Provide a state encoder, an action encoder, and a predictor that consumes the concatenated [z_t, a_t] embedding.

from jepa import JEPAAction
import torch.nn as nn

state_dim = 512
action_dim = 64

# Example encoders (replace with your own)
state_encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, state_dim),
)
action_encoder = nn.Sequential(
    nn.Linear(10, 128), nn.ReLU(),
    nn.Linear(128, action_dim),
)

# Predictor takes [state_dim + action_dim] → state_dim
predictor = nn.Sequential(
    nn.Linear(state_dim + action_dim, 512), nn.ReLU(),
    nn.Linear(512, state_dim),
)

model = JEPAAction(state_encoder, action_encoder, predictor)

Command Line Interface

# Train a model
jepa-train --config config/default_config.yaml

# Train with custom parameters
jepa-train --config config/vision_config.yaml \
           --batch-size 64 \
           --learning-rate 0.001 \
           --num-epochs 100

# Evaluate a trained model
jepa-evaluate --config config/default_config.yaml \
              --checkpoint checkpoints/best_model.pth

# Generate a configuration template
jepa-train --generate-config my_config.yaml

# Get help
jepa-train --help

Configuration

JEPA uses YAML configuration files for easy experiment management:

# config/my_experiment.yaml
model:
  encoder_type: "transformer"
  encoder_dim: 768
  predictor_type: "mlp"
  predictor_hidden_dim: 2048

training:
  batch_size: 32
  learning_rate: 0.0001
  num_epochs: 100
  warmup_epochs: 10

data:
  train_data_path: "data/train"
  val_data_path: "data/val"
  sequence_length: 16

logging:
  wandb:
    enabled: true
    project: "jepa-experiments"
  tensorboard:
    enabled: true
    log_dir: "./tb_logs"
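
A saved config can then be loaded and handed straight to the trainer; this mirrors the pattern the examples below use (here without a custom encoder or dataset):

from jepa import JEPATrainer, load_config

# Load the YAML experiment config and train with it
config = load_config("config/my_experiment.yaml")
trainer = JEPATrainer(config=config)
trainer.train()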

🎯 Use Cases

Computer Vision

  • Image Classification: Pre-train backbones for downstream tasks
  • Object Detection: Learn robust visual representations
  • Medical Imaging: Analyze medical scans and imagery
  • Satellite Imagery: Process large-scale geographic data

Natural Language Processing

  • Language Models: Pre-train transformer architectures
  • Document Understanding: Learn document-level representations
  • Code Analysis: Understand code structure and semantics
  • Cross-lingual Learning: Build multilingual representations

Time Series Analysis

  • Forecasting: Pre-train models for prediction tasks
  • Anomaly Detection: Learn normal patterns in sequential data
  • Financial Modeling: Analyze market trends and patterns
  • IoT Sensors: Process sensor data streams

Multimodal Learning

  • Vision-Language: Combine images and text understanding
  • Audio-Visual: Learn from synchronized audio and video
  • Cross-Modal Retrieval: Search across different modalities
  • Embodied AI: Integrate multiple sensor modalities

📚 Examples

Vision Example

from jepa import JEPATrainer, load_config
import torch.nn as nn

# Custom vision encoder
class VisionEncoder(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=512):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_dim, 64, 7, 2, 3),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, hidden_dim)
        )
    
    def forward(self, x):
        return self.conv_layers(x)

# Load config and customize
config = load_config("config/vision_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=VisionEncoder)
trainer.train()

NLP Example

import torch.nn as nn

from transformers import AutoModel
from jepa import JEPATrainer, load_config

# Use pre-trained transformer as encoder
class TransformerEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)
    
    def forward(self, input_ids, attention_mask=None):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state.mean(dim=1)  # Pool over sequence

config = load_config("config/nlp_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=TransformerEncoder)
trainer.train()

Time Series Example

from jepa import create_dataset, load_config, JEPATrainer

# Create time series dataset
dataset = create_dataset(
    data_path="data/timeseries.csv",
    sequence_length=50,
    prediction_length=10,
    features=['sensor1', 'sensor2', 'sensor3']
)

config = load_config("config/timeseries_config.yaml")
trainer = JEPATrainer(config=config, train_dataset=dataset)
trainer.train()

📖 Documentation

🔧 Development

Setting up Development Environment

git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install in development mode
pip install -e ".[dev,docs]"

# Install pre-commit hooks
pre-commit install

Running Tests

# Run all tests
pytest

# Run with coverage
pytest --cov=jepa --cov-report=html

# Run specific test
pytest tests/test_model.py::test_jepa_forward

Code Quality

# Format code
black jepa/
isort jepa/

# Type checking
mypy jepa/

# Linting
flake8 jepa/

🤝 Contributing

We welcome contributions! Please see our Contributing Guide for details.

Ways to Contribute

  • πŸ› Bug Reports: Submit detailed bug reports
  • ✨ Feature Requests: Suggest new features or improvements
  • πŸ“– Documentation: Improve documentation and examples
  • πŸ”§ Code: Submit pull requests with bug fixes or new features
  • 🎯 Use Cases: Share your JEPA applications and results

Development Workflow

  1. Fork the repository
  2. Create a feature branch: git checkout -b feature-name
  3. Make your changes and add tests
  4. Ensure all tests pass: pytest
  5. Submit a pull request

📄 Citation

If you use JEPA in your research, please cite:

@software{jepa2025,
  title = {JEPA: Joint-Embedding Predictive Architecture Framework},
  author = {Venkatesh, Dilip},
  year = {2025},
  url = {https://github.com/dipsivenkatesh/jepa},
  version = {0.1.0}
}

πŸ“ License

This project is licensed under the MIT License - see the LICENSE file for details.

πŸ™ Acknowledgments

  • Inspired by the original JEPA paper and Meta's research
  • Built with PyTorch, Transformers, and other amazing open-source libraries
  • Thanks to all contributors and users of the framework

📞 Support

Steps to push the latest version:

rm -rf dist build *.egg-info
python -m build
twine upload dist/*

