diff --git a/.github/workflows/base_tests.yaml b/.github/workflows/base_tests.yaml index 6c67bf0..75833ae 100644 --- a/.github/workflows/base_tests.yaml +++ b/.github/workflows/base_tests.yaml @@ -17,6 +17,6 @@ jobs: - name: Set up Python run: uv python install - name: Install the project - run: uv sync --locked --all-extras --dev + run: uv sync --all-extras --dev - name: Run tests - run: uv run pytest tests + run: uv run pytest tests diff --git a/.github/workflows/linters.yaml b/.github/workflows/linters.yaml index 924f94e..eec6344 100644 --- a/.github/workflows/linters.yaml +++ b/.github/workflows/linters.yaml @@ -14,7 +14,7 @@ on: paths: - 'eb_jepa/**' - 'examples/**' - - 'tests/' + - 'tests/' jobs: run-linters: @@ -30,12 +30,12 @@ jobs: - name: Set up Python run: uv python install - name: Install the project - run: uv sync --locked --all-extras --dev + run: uv sync --all-extras --dev - name: Set lint paths run: echo "lint_paths=eb_jepa examples tests" >> "$GITHUB_ENV" - name: Run isort run: | - uv run python -m isort $lint_paths --check + uv run python -m isort $lint_paths --check - name: Run black if: always() run: | diff --git a/.gitignore b/.gitignore index 1283615..56a1258 100644 --- a/.gitignore +++ b/.gitignore @@ -151,6 +151,7 @@ venv/ ENV/ env.bak/ venv.bak/ +uv.lock # Spyder project settings .spyderproject @@ -220,10 +221,9 @@ wandb .streamlit/secrets.toml # example text multistep ignores -apps/example_text_multistep/data/fineweb10B/ -apps/example_text_multistep/logs/ -logs/ +tests/logs/ checkpoints/ *.pth *.npy eb_jepa_ICLR/ +CLAUDE.md diff --git a/Makefile b/Makefile index 41df7b1..ea193d6 100644 --- a/Makefile +++ b/Makefile @@ -1,24 +1,15 @@ -## paths .ONESHELL: .PHONY: help .DEFAULT_GOAL := help -## print a help msg to display the comments -help: +help: ## Show this help message @grep -hE '^[A-Za-z0-9_ \-]*?:.*##.*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -USER := $(shell whoami) -PWD := $(shell pwd) -ROOT := $(shell cd ..; pwd) +run_image_jepa: ## Run the image JEPA example + uv run python examples/image_jepa/main.py -run_example_video_jepa: - uv run python apps/example_video_jepa/main.py - -run_example_video_jepa_chunked_vc: +run_video_jepa: ## Run the video JEPA example uv run python examples/video_jepa/main.py -run_example_ac_video_jepa: +run_ac_video_jepa: ## Run the action-conditioned video JEPA example uv run python examples/ac_video_jepa/main.py - -run_example_text_multistep: - uv run torchrun --standalone --nproc_per_node=8 examples/text_multistep/main.py diff --git a/README.md b/README.md index 85058eb..9a30679 100644 --- a/README.md +++ b/README.md @@ -1,55 +1,185 @@ # Energy Based Joint Embedding Predictive Architectures -![EB JEPA](docs/teaser.png) -An open source library and tutorial aimed at learning representations for prediction and planning using joint embedding predictive arhictectures. -Examples include learning image (a), video (b), and action conidtioned video (c) predictive models representations, as well as planning with them (d). +![EB JEPA](docs/archi-schema-eb-jepa.png) -Each example is (almost) self-contained and training takes up to few hours on a single GPU card. +An open source library and tutorial for learning representations for prediction and planning using joint embedding predictive architectures. -### [Image Representations](examples/image_jepa/README.md) +> Each example is (almost) self-contained and training takes up to a few hours on a single GPU card. -This example demonstrates learning self-supervised representations from unlabeled images on CIFAR 10, and evaluated on image classification. -![Moving MNIST](examples/image_jepa/assets/arch_figure.png) +--- -### [Predictive Video Representations](examples/video_jepa/README.md) -![Moving MNIST](examples/video_jepa/assets/viz.png) +## πŸ“š Examples + +### [Image JEPA](examples/image_jepa/README.md) + +Self-supervised representations from unlabeled images on CIFAR-10, evaluated on classification. + +![Image JEPA Architecture](examples/image_jepa/assets/arch_figure.png) -A model is trained to predict the next image representation in a sequence +### [Video JEPA](examples/video_jepa/README.md) + +Predict next image representation in a sequence. + +![Moving MNIST](examples/video_jepa/assets/viz.png) -### [Action Conditioned Prediction and Planning](examples/ac_video_jepa/README.md) +### [AC Video JEPA](examples/ac_video_jepa/README.md) -This example demonstrates a Joint Embedding Predictive Architecture (JEPA) for action-conditioned world modeling in the Two Rooms environment. The model learns to predict future states based on current observations and actions. These representations enable planning towards a goal observation embedding. +JEPA for world modeling + planning in Two Rooms environment. | Planning Episode | Task Definition | |------------------|-----------------| | Successful planning episode | Episode task definition | -| *Successful planning episode* | *Episode task definition: from init to goal state* | +| *Successful planning episode* | *From init to goal state* | -## Installation +--- -We use [uv](https://docs.astral.sh/uv/guides/projects/) package manager to install and maintain packages. Once you have [installed uv](pip install --upgrade uv/), run the following to create a new virtual environment. +## πŸš€ Installation + +We use [uv](https://docs.astral.sh/uv/guides/projects/) for package management. ```bash +# Install dependencies uv sync +# Option 1: Activate virtual environment +source .venv/bin/activate +python main.py +# Option 2: Run directly with uv +uv run python main.py ``` +If you need conda-specific packages, you can use **Conda + uv** -This will create a virtual environment within the project folder at `.venv/`. To activate this environment, run `source .venv/bin/activate`. +```bash +# Create conda environment with Python 3.12 +conda create -n eb_jepa python=3.12 -y +conda activate eb_jepa +# Install package in editable mode with dev dependencies (pytest, black, isort) +uv pip install -e . --group dev +``` -Alternatively, if you don't want to run activate everytime, you can just prepend `uv run` before your python scripts: +Add these to your `~/.bashrc` for persistent configuration. ```bash -uv run python main.py +# Required for SLURM jobs to find datasets +export EBJEPA_DSETS=/path/to/eb_jepa/datasets +# Optional: Directory for checkpoints and logs +export EBJEPA_CKPTS=/path/to/checkpoints +``` + + + +--- + +## πŸ‹οΈ Training + +### Quick Start + +```bash +# Local training +python -m examples.{image_jepa,video_jepa,ac_video_jepa}.main +``` +> Our default configs are tuned for H100 GPUs. With older GPUs (e.g., A100, V100), you may need to reduce batch size to fit in memory. + +### πŸ“‚ Folder Structure + +All experiments use a unified folder structure: + +``` +checkpoints/ +└── {example_name}/ + β”œβ”€β”€ dev_2026-01-16_00-10/ # Single/local runs (dev_ prefix) + β”‚ └── {exp_name}_seed1/ + β”‚ + β”œβ”€β”€ sweep_2026-01-16_00-10/ # Auto-named 3-seed sweep + β”‚ β”œβ”€β”€ {exp_name}_seed1/ + β”‚ β”œβ”€β”€ {exp_name}_seed1000/ + β”‚ └── {exp_name}_seed10000/ + β”‚ + └── sweep_my_experiment/ # Custom-named sweep + └── ... ``` -## Running test cases +`{exp_name}` encodes key hyperparameters to avoid folder collisions, e.g.: +- **image_jepa**: `resnet_vicreg_proj_bs256_ep300_ph2048_po2048_std1.0_cov80.0` +- **video_jepa**: `resnet_bs64_lr0.001_std10.0_cov100.0` +- **ac_video_jepa**: `impala_cov8_std16_simt12_idm1` -Libraries added to eb-jepa [must have their own test cases](/tests/). To run the tests: `uv run pytest tests/` +
+πŸ–₯️ SLURM Launcher (optional) + +| Command | Description | +|---------|-------------| +| `--example {name}` | Choose: `image_jepa`, `video_jepa`, `ac_video_jepa` | +| `--fname {path}` | Run the sweep specified in the config at `{path}` | +| `--single` | Launch single job (dev mode) | +| `--sweep {name}` | Custom sweep name | +| `--array-parallelism {N}` | Limits the maximum number of concurrent jobs to `N` | +| `--full-sweep` | Full hyperparameter sweep from config | +| `--use-wandb-sweep` | Enable wandb sweep UI | + +```bash +# 3 seeds with wandb averaging (recommended) +python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml + +# Custom sweep name +python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml --sweep my_experiment + +# Single job +python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml --single + +# Full hyperparameter sweep +python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml --full-sweep + +# With wandb sweep UI for hyperparameter analysis +python -m examples.launch_sbatch --example image_jepa --fname examples/image_jepa/cfgs/default.yaml --use-wandb-sweep +``` + +Replace `image_jepa` with `ac_video_jepa` or `video_jepa` for other examples. + +**Full Sweep Configuration:** The `--full-sweep` flag reads the `sweep.param_grid` section from the example's YAML config file (e.g., `examples/image_jepa/cfgs/default.yaml`). Without this flag, only a 3-seed sweep is launched. To customize sweep parameters, edit the `sweep` section in the config: + +```yaml +# Example: examples/image_jepa/cfgs/default.yaml +sweep: + param_grid: + loss.cov_coeff: [0.1, 1.0, 10.0, 100.0] + loss.std_coeff: [1.0, 10.0] + meta.seed: [1, 1000, 10000] +``` + +### Wandb Seed Averaging + +Runs with the same hyperparameters but different seeds share the same wandb run name, enabling automatic averaging: + +1. Go to wandb web UI β†’ Runs table +2. Click **"Group by"** β†’ select **"Name"** + β†’ Groups runs with identical hyperparameters (different seeds) together + +To filter runs from a specific sweep: +3. Click **"Filter"** β†’ **"Group"** β†’ select your sweep name + +For detailed wandb sweep analysis (parallel coordinates, hyperparameter importance): +1. Use `--use-wandb-sweep` flag when launching +2. Go to wandb web UI β†’ left pane β†’ **"Sweeps"** β†’ click your sweep name + +**SLURM Configuration:** To customize SLURM parameters (partition, account, memory, etc.), edit the `SLURM_DEFAULTS` dictionary at the top of `examples/launch_sbatch.py`. + +
+ +## πŸ§ͺ Running test cases + +Libraries added to eb_jepa [must have their own test cases](/tests/). To run the tests: + +```bash +# With uv sync installation +uv run pytest tests/ +# With conda + uv installation (no .venv created) +pytest tests/ +``` -## Development +## πŸ‘©β€πŸ’» Development -- The uv package comes with `black` and `isort`, which must be run before adding any file in this repo. The continous integration will check the linting of the PRs and new files. -- Every PR should be reviewed by folks tagged at [CODEOWNERS](docs/CODEOWNERS). +- The dev dependencies include `black` and `isort`, which must be run before contributing to this repo. +## πŸ“„ License -## License -EB JEPA is Apache licensed, as found in the [LICENSE](LICENSE.md) file. +EB-JEPA is Apache licensed. See [LICENSE](LICENSE.md). diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/docs/archi-schema-eb-jepa.png b/docs/archi-schema-eb-jepa.png new file mode 100644 index 0000000..dc1f25d Binary files /dev/null and b/docs/archi-schema-eb-jepa.png differ diff --git a/eb_jepa/__init__.py b/eb_jepa/__init__.py deleted file mode 100644 index ba850f2..0000000 --- a/eb_jepa/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Energy-Based JEPA package - -__version__ = "0.1.0" diff --git a/eb_jepa/architectures.py b/eb_jepa/architectures.py index 43dce0b..78d9aab 100644 --- a/eb_jepa/architectures.py +++ b/eb_jepa/architectures.py @@ -7,12 +7,10 @@ from eb_jepa.nn_utils import TemporalBatchMixin, init_module_weights -###################################################### -# Basic architectural modules - -# a simple 3D convnet with 2 layers. class conv3d2(nn.Sequential): + """Simple 3D convnet with 2 layers.""" + def __init__(self, in_d, h_d, out_d, tk, ts, sk, ss, pad): super(conv3d2, self).__init__( nn.Conv3d( @@ -38,6 +36,8 @@ def __init__(self, in_d, h_d, out_d, tk, ts, sk, ss, pad): class ResidualBlock(nn.Module): + """Standard residual block with skip connection.""" + def __init__(self, in_channels, out_channels, stride=1): super(ResidualBlock, self).__init__() @@ -77,7 +77,7 @@ def forward(self, x): class ResNet5(TemporalBatchMixin, nn.Module): """ A lightweight ResNet with 5 layers (2 blocks). - Supports both 4D (B,C,H,W) and 5D (B,C,T,H,W) inputs via TemporalBatchMixin. + Supports both 4D [B, C, H, W] and 5D [B, C, T, H, W] inputs via TemporalBatchMixin. """ def __init__(self, in_d, h_d, out_d, s1=1, s2=1, s3=1, avg_pool=False): @@ -105,6 +105,8 @@ def _forward(self, x): class SimplePredictor(nn.Module): + """Wrapper that concatenates states and actions channel-wise before prediction.""" + def __init__(self, predictor, context_length): super().__init__() self.predictor = predictor @@ -120,8 +122,8 @@ class StateOnlyPredictor(SimplePredictor): def forward(self, x, a): # action not used on purpose - prev_state = x[:, :, :-1] # (B, C, T-1, H, W) - next_state = x[:, :, 1:] # (B, C, T-1, H, W) + prev_state = x[:, :, :-1] # [B, C, T-1, H, W] + next_state = x[:, :, 1:] # [B, C, T-1, H, W] combined_xa = torch.cat((prev_state, next_state), dim=1) return self.predictor(combined_xa) @@ -130,7 +132,7 @@ class ResUNet(TemporalBatchMixin, nn.Module): """ A small UNet with residual encoder blocks and transposed-conv upsampling. Channels scale like h, 2h, 4h, 8h. Output keeps the input HxW. - Supports both 4D (B,C,H,W) and 5D (B,C,T,H,W) inputs via TemporalBatchMixin. + Supports both 4D [B, C, H, W] and 5D [B, C, T, H, W] inputs via TemporalBatchMixin. """ def __init__(self, in_d, h_d, out_d, is_rnn=False): @@ -203,6 +205,8 @@ def _forward(self, x): class Projector(nn.Module): + """MLP projector built from a spec string like '256-512-128'.""" + def __init__(self, mlp_spec): super().__init__() layers = [] @@ -220,14 +224,15 @@ def forward(self, x): class DetHead(nn.Module): + """Detection head that pools features and predicts binary maps.""" def __init__(self, in_d, h_d, out_d): super().__init__() self.head = nn.Sequential(conv3d2(in_d, h_d, out_d, 1, 1, 3, 1, "same")) self.apply(init_module_weights) - # x: output of predictor (or autoregressive) def forward(self, x): + """Forward pass on predictor output of shape (B, C, T, H, W).""" # (Batch, Feature, Time, Height, Width) # [8, 8, T, 8, 8] x = [F.adaptive_avg_pool2d(x[:, :, t], (8, 8)) for t in range(x.shape[2])] @@ -360,11 +365,13 @@ def __init__( def forward(self, x): """ - x: (bs, ch, t, w, h) - out: (bs, dim, t, 1, 1) + Args: + x: [B, C, T, H, W] + Returns: + out: [B, D, T, 1, 1] """ - # (bs, ch, t, w, h) --> (t, bs, ch, w, h) + # [B, C, T, H, W] --> [T, B, C, H, W] ( _, _, @@ -403,9 +410,10 @@ def forward(self, x): class RNNPredictor(nn.Module): + """GRU-based predictor for single-step state propagation.""" + def __init__( self, - # parent inputs hidden_size: int = 512, action_dim: Optional[int] = 2, num_layers: int = 1, @@ -427,16 +435,17 @@ def __init__( def forward(self, state, action): """ - Propagate one step forward - Parameters: - state: (bs, dim, 1, 1, 1) - action: (bs, a_dim, 1) - Output: - Output: next_state (bs, dim, 1, 1, 1) + Propagate one step forward. + + Args: + state: [B, D, 1, 1, 1] + action: [B, A, 1] + Returns: + next_state: [B, D, 1, 1, 1] """ # This only does one step - rnn_state = state.flatten(1, 4).unsqueeze(0).contiguous() # (1, bs, dim) - rnn_input = action.squeeze(-1).unsqueeze(0).contiguous() # (1, bs, a_dim) + rnn_state = state.flatten(1, 4).unsqueeze(0).contiguous() # [1, B, D] + rnn_input = action.squeeze(-1).unsqueeze(0).contiguous() # [1, B, A] next_state, _ = self.rnn(rnn_input, rnn_state) diff --git a/eb_jepa/datasets/__init__.py b/eb_jepa/datasets/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/eb_jepa/datasets/moving_mnist.py b/eb_jepa/datasets/moving_mnist.py index 2597df9..0eb5082 100644 --- a/eb_jepa/datasets/moving_mnist.py +++ b/eb_jepa/datasets/moving_mnist.py @@ -1,14 +1,21 @@ import os +from pathlib import Path import numpy as np import torch from torch.utils.data import Dataset -FILENAME = "datasets/mnist_test_seq.npy" +# Use environment variable or fall back to path relative to __file__ +# This allows the original data location to be preserved when running from copied code folder +_DEFAULT_DATASETS_DIR = Path(__file__).parent.parent.parent.absolute() / "datasets" +_DATASETS_DIR = Path(os.environ.get("EBJEPA_DSETS", str(_DEFAULT_DATASETS_DIR))) +FILENAME = str(_DATASETS_DIR / "mnist_test_seq.npy") URL = "https://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy" def load_or_download(filename: str, url: str): + # Ensure datasets directory exists + os.makedirs(os.path.dirname(filename), exist_ok=True) if not os.path.exists(filename): print(f"File '{filename}' not found. Downloading from {url}...") try: @@ -42,12 +49,10 @@ class MovingMNIST(Dataset): def __init__(self, split=None): """ Args: - transform (callable, optional): Optional transform to be applied on a sample. split (str): train or val Returns: - video (torch.Tensor) (C, T, H, W): greyscale video frames - context (torch.Tensor) (C, T, H, W): past greyscale video frames as context + video: [C, T, H, W] - greyscale video frames """ load_or_download(FILENAME, URL) self.data_path = FILENAME @@ -82,9 +87,8 @@ def __init__(self, transform=None, split=None, map_size=8): map_size (int): size of map to predict positions over Returns: - video (torch.Tensor) (C, T, H, W): greyscale video frames - context (torch.Tensor) (C, T, H, W): past greyscale video frames as context - digit_location (torch.Tensor) (T, map_size, map_size): Coarse binary heatmap for digit locations + video: [C, T, H, W] - greyscale video frames + digit_location: [T, map_size, map_size] - Coarse binary heatmap for digit locations """ super().__init__(split) self.transform = transform @@ -109,13 +113,11 @@ def __getitem__(self, idx): if __name__ == "__main__": - dset = MovingMNIST() + dset = MovingMNIST(split="val") instance = dset[10] print(f"{instance['video'].shape = }") - print(f"{instance['context'].shape = }") - dset = MovingMNISTDet() + dset = MovingMNISTDet(split="val") instance = dset[10] print(f"{instance['video'].shape = }") - print(f"{instance['context'].shape = }") print(f"{instance['digit_location'].shape = }") diff --git a/eb_jepa/datasets/two_rooms/__init__.py b/eb_jepa/datasets/two_rooms/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/eb_jepa/datasets/two_rooms/data_config.yaml b/eb_jepa/datasets/two_rooms/data_config.yaml index a748478..0ee362d 100644 --- a/eb_jepa/datasets/two_rooms/data_config.yaml +++ b/eb_jepa/datasets/two_rooms/data_config.yaml @@ -1,15 +1,22 @@ +# Two-rooms environment dataset configuration. +# These defaults are merged with overrides from train.yaml or eval.yaml. + +# Environment physics action_noise: 1 action_angle_noise: 0.2 action_step_mean: 1.0 action_step_std: 0.4 action_lower_bd: 0.2 action_upper_bd: 1.8 -batch_size: 64 -device: cpu + +# Visual appearance dot_std: 1.3 +img_size: 65 + +# Wall/door layout border_wall_loc: 5 fix_wall_batch_k: null -fix_wall: true # false to randomize both wall and door loc +fix_wall: false # false to randomize both wall and door loc fix_door_location: 18 fix_wall_location: 32 exclude_wall_train: '' @@ -21,17 +28,27 @@ door_padding: 10 wall_width: 3 door_space: 4 num_train_layouts: -1 + +# Trajectory sampling cross_wall_rate: 0.35 expert_cross_wall_rate: 0 -wall_bump_rate: 0. # 0 by default +wall_bump_rate: 0. dup_traj_rate: 0. -img_size: 65 max_step: 1 sample_length: 17 # 17 in PLDM | 90 to eval long rollout n_steps: 91 n_steps_reduce_factor: 1 -size: 100000 # 188235 * 17 ~ 3.2M frames +repeat_actions: 1 + +# Dataset size +size: 100000 val_size: 10000 train: true -repeat_actions: 1 normalize: true + +# Dataloader defaults +batch_size: 64 +num_workers: 0 +pin_mem: false +persistent_workers: false +device: cpu diff --git a/eb_jepa/datasets/two_rooms/dot_dataset.py b/eb_jepa/datasets/two_rooms/dot_dataset.py index a664ac3..7ba75c5 100644 --- a/eb_jepa/datasets/two_rooms/dot_dataset.py +++ b/eb_jepa/datasets/two_rooms/dot_dataset.py @@ -129,8 +129,6 @@ def generate_state(self, wall_locs=None, door_locs=None, size=None): + self.padding ) - ## TODO: move wall logic to wall.py - left_walls = wall_locs - self.config.wall_width // 2 right_walls = wall_locs + self.config.wall_width // 2 @@ -197,7 +195,6 @@ def generate_multistep_sample( sample = self.generate_transitions( start_location, actions, bias_angle, walls=walls ) - # TODO: fix trajectories, change wall positions return sample diff --git a/eb_jepa/datasets/two_rooms/wall_dataset.py b/eb_jepa/datasets/two_rooms/wall_dataset.py index 32031b3..ab3847c 100644 --- a/eb_jepa/datasets/two_rooms/wall_dataset.py +++ b/eb_jepa/datasets/two_rooms/wall_dataset.py @@ -683,7 +683,7 @@ def generate_state_and_actions( # cw_count = math.ceil(self.config.batch_size * self.config.cross_wall_rate) cw_count = np.random.rand() < self.config.cross_wall_rate if cw_count: - (cw_locations, cw_actions, _) = ( + cw_locations, cw_actions, _ = ( self.generate_cross_wall_state_and_actions( wall_locs=wall_locs[:cw_count], door_locs=door_locs[:cw_count], @@ -701,7 +701,7 @@ def generate_state_and_actions( bump_count = min(bump_count, 1 - modified_count) if bump_count > 0: - (bump_locations, bump_actions) = ( + bump_locations, bump_actions = ( self.generate_wall_bump_state_and_actions( wall_locs=wall_locs[ modified_count : modified_count + bump_count @@ -1036,7 +1036,7 @@ def generate_transitions( states_with_walls = self.normalizer.normalize_state(states_with_walls) locations = self.normalizer.normalize_location(locations) - # make it conform to eb-jepa codebase: + # make it conform to eb_jepa codebase: # states: B, T, C, H, W --> B, F, T, H, W states_with_walls = states_with_walls.permute(0, 2, 1, 3, 4) # also rid of the last timestep to make timesteps aligned with action diff --git a/eb_jepa/datasets/utils.py b/eb_jepa/datasets/utils.py index deb1082..32a9407 100644 --- a/eb_jepa/datasets/utils.py +++ b/eb_jepa/datasets/utils.py @@ -1,16 +1,33 @@ +from pathlib import Path + import torch +import yaml from eb_jepa.datasets.two_rooms.utils import update_config_from_yaml from eb_jepa.datasets.two_rooms.wall_dataset import WallDataset, WallDatasetConfig +DATASETS_DIR = Path(__file__).parent + + +def load_env_data_config(env_name: str, overrides: dict = None) -> dict: + """Load base data config for an environment and apply overrides.""" + config_path = DATASETS_DIR / env_name / "data_config.yaml" + with open(config_path) as f: + base_config = yaml.safe_load(f) + if overrides: + base_config.update(overrides) + return base_config + def init_data(env_name, cfg_data=None, **kwargs): """Initialize data loaders for the specified environment. + Loads base config from eb_jepa/datasets/{env_name}/data_config.yaml + and merges with any overrides from cfg_data. + Args: env_name: Name of the environment (currently only "two_rooms" is supported). - cfg_data: Configuration dictionary for the dataset. - **kwargs: Additional keyword arguments (ignored). + cfg_data: Configuration overrides for the dataset. Returns: Tuple of (train_loader, val_loader, config). @@ -18,27 +35,33 @@ def init_data(env_name, cfg_data=None, **kwargs): if env_name != "two_rooms": raise ValueError(f"Unknown env: {env_name}. Only 'two_rooms' is supported.") - config = update_config_from_yaml(WallDatasetConfig, cfg_data) + merged_cfg = load_env_data_config(env_name, cfg_data) + config = update_config_from_yaml(WallDatasetConfig, merged_cfg) + + num_workers = merged_cfg.get("num_workers", 0) + pin_mem = merged_cfg.get("pin_mem", False) + persistent_workers = merged_cfg.get("persistent_workers", False) and num_workers > 0 + dset = WallDataset(config=config) loader = torch.utils.data.DataLoader( dset, batch_size=config.batch_size, shuffle=True, - num_workers=cfg_data.get("num_workers"), - pin_memory=cfg_data.get("pin_mem"), + num_workers=num_workers, + pin_memory=pin_mem, drop_last=True, - persistent_workers=cfg_data.get("persistent_workers") - and cfg_data.get("num_workers") > 0, + persistent_workers=persistent_workers, ) + val_dset = WallDataset(config=config) val_loader = torch.utils.data.DataLoader( val_dset, batch_size=4, shuffle=False, - num_workers=cfg_data.get("num_workers"), - pin_memory=cfg_data.get("pin_mem"), + num_workers=num_workers, + pin_memory=pin_mem, drop_last=True, - persistent_workers=cfg_data.get("persistent_workers") - and cfg_data.get("num_workers") > 0, + persistent_workers=persistent_workers, ) + return loader, val_loader, config diff --git a/eb_jepa/image_decoder.py b/eb_jepa/image_decoder.py index e318de3..77ccbe1 100644 --- a/eb_jepa/image_decoder.py +++ b/eb_jepa/image_decoder.py @@ -6,7 +6,7 @@ class ImageDecoder(TemporalBatchMixin, nn.Module): """ Simple 2D convolutional decoder for reconstructing images from representations. - Supports both 4D (B,C,H,W) and 5D (B,C,T,H,W) inputs via TemporalBatchMixin. + Supports both 4D [B, C, H, W] and 5D [B, C, T, H, W] inputs via TemporalBatchMixin. """ def __init__( diff --git a/eb_jepa/jepa.py b/eb_jepa/jepa.py index ff3ee37..56714e0 100644 --- a/eb_jepa/jepa.py +++ b/eb_jepa/jepa.py @@ -1,4 +1,3 @@ - import torch import torch.nn as nn @@ -7,17 +6,11 @@ logging = get_logger(__name__) -###################################################### -# a basic JEPA class. No learning abilities -# this is for planning and inference only. -# use the full JEPA class for SSL training. class JEPAbase(nn.Module): + """Base JEPA class for planning and inference only. Use JEPA subclass for training.""" + def __init__(self, encoder, aencoder, predictor): - """ - Action-Conditioned Joint Embedding Predictive Architecture world model. - This class has no training ability. - Use the JEPA subclass for training. - """ + """Initialize JEPAbase with encoder, action encoder, and predictor.""" super().__init__() # Observation Encoder self.encoder = encoder @@ -33,164 +26,198 @@ def save(self, file): def load(self, file): self.load_state_dict(torch.load(file), weights_only=False) - # just runs the encoder on a sequence of observations - # and returns the encoder output sequence @torch.no_grad() def encode(self, observations): + """Encode a sequence of observations and return the encoder output.""" return self.encoder(observations) - # inference producing single-step predictions over all - # elements in a sequence in parallel. - @torch.no_grad() - def infer(self, observations, actions): - return self.infern(observations, actions, nsteps=1)[0] - - @torch.no_grad() - def infern(self, observations, actions, nsteps=1): - # check number of steps. - state = self.encoder(observations) - context_length = self.predictor.context_length - if actions is not None: - actions = self.action_encoder(actions) - - predi = state - preds = [] - for _ in range(nsteps): - predi = self.predictor(predi, actions)[:, :, :-1] - preds.append(predi) - predi = torch.cat((state[:, :, :context_length], predi), dim=2) - - # compute total loss here - return preds - - # TODO: refactor predictor - # perform a multi-step prediction, auto-regressively in state space. - # Predictions are performed sequentially starting from a given context of - # observations.shape[2] frames, on actions.shape[2] action time steps. - # The last prediction timestep predi[:, :, -1:] is concatenated to the - # input state for the next prediction step. - # Optionally, a context window can be used to limit the number of past actions and frames - # the predictor can attend to. - # Returns predin: a concatention of groundtruth context embeddings and predictions. - @torch.no_grad() - def unrolln(self, observations, actions, nsteps, ctxt_window_time=1): - """ - Input shape: observations: (Batch, Feature, Time, Height, Width) OR (Batch, Time, Dim) - actions: (Batch, Feature, Time, Height, Width) - Output shape: predin: (Batch, Feature, Time, Height, Width) OR (Batch, Time, Dim) - """ - if nsteps > actions.size(2): - raise NameError( - "number of prediction steps larger than length of action sequence" - ) - # Input Encoding - state = self.encoder(observations) - # Action Encoding - actions = self.action_encoder(actions) - # prediction loop through steps. - # we just run the predictor as if it were a recurrent net. - if self.single_unroll: - curr_state = state[:, :, :1] - predin = curr_state - for i in range(nsteps): - curr_action = actions[:, :, i : i + 1] - curr_state = self.predictor(curr_state, curr_action) - predin = torch.cat([predin, curr_state], dim=2) - else: - predin = state - for i in range(nsteps): - predi = self.predictor( - predin[:, :, -ctxt_window_time:], - actions[:, :, max(0, i + 1 - ctxt_window_time) : i + 1], - ) - predi = predi[:, :, -1:] # take the last time step - predin = torch.cat([predin, predi], dim=2) - return predin - - -################################################################ -# A trainable JEPA class -# with a prediction loss and an anti-collapse regularizer loss class JEPA(JEPAbase): + """Trainable JEPA with prediction loss and anti-collapse regularizer.""" + def __init__(self, encoder, aencoder, predictor, regularizer, predcost): - """ - Action-Conditioned Joint Embedding Predictive Architecture world model. - Args: - """ + """Initialize JEPA with regularizer and prediction cost in addition to base components.""" super().__init__(encoder, aencoder, predictor) - # Anti-Collapse Regularizer self.regularizer = regularizer - # prediction loss self.predcost = predcost self.ploss = 0 self.rloss = 0 - # training forward with a multi-step auto-regressive prediction - # observations is a 5d tensor containing a sequence of observations - # (Batch, Feature, Time, Height, Width) - # actions is a 5d tensor containing a sequence of actions - # (Batch, Feature, Time, Height, Width) - def forwardn(self, observations, actions, nsteps=1): - # Input Encoding + @torch.no_grad() + def infer(self, observations, actions): + """Produce single-step predictions over all sequence elements in parallel.""" + preds, _ = self.unroll( + observations, + actions, + nsteps=1, + unroll_mode="parallel", + compute_loss=False, + return_all_steps=True, + ) + return preds[0] + + def unroll( + self, + observations, + actions, + nsteps=1, + unroll_mode="parallel", + ctxt_window_time=1, + compute_loss=True, + return_all_steps=False, + ): + """Unified multi-step prediction with optional loss computation. + + This function supports both training (with loss computation) and planning/inference + (without loss, just state prediction). + + Usage examples: + - Training video_jepa: unroll(x, None, nsteps, unroll_mode="parallel", compute_loss=True) + - Training ac_video_jepa with RNN: unroll(x, a, nsteps, unroll_mode="autoregressive", + ctxt_window_time=1, compute_loss=True) + - Planning with ac_video_jepa: unroll(x, a, nsteps, unroll_mode="autoregressive", + ctxt_window_time=k, compute_loss=False) + - Inference like infern(): unroll(x, a, nsteps, unroll_mode="parallel", + compute_loss=False, return_all_steps=True) + + Predictor behavior: + - unroll_mode="parallel" (Conv predictor, is_rnn=False): + Processes all timesteps in parallel. Uses predictor.context_length to + determine how many ground truth frames to re-feed at each iteration. + Output: [B, D, T, H', W'] (same length as input, predictions replace non-context). + Best for training with full ground truth trajectory available. + + - unroll_mode="autoregressive": + Step-by-step prediction with sliding window of ctxt_window_time states. + Each step: takes last ctxt_window_time states, predicts next, appends to sequence. + Output: [B, D, T_context + nsteps, H', W'] (context + predictions appended). + Best for planning/inference where future ground truth is not available. + Note: RNN predictors (is_rnn=True) are a special case with ctxt_window_time=1. + + Args: + observations: [B, C, T, H, W] - observation sequence + For training (compute_loss=True): full trajectory with ground truth + For planning (compute_loss=False): context frames only + actions: [B, A, T_actions] - action sequence, or None for state-only prediction + T_actions >= nsteps required for autoregressive mode + nsteps: number of prediction steps + unroll_mode: "parallel" or "autoregressive" + - "parallel": Process all timesteps, refeed GT context on left + - "autoregressive": Step-by-step, append predictions on right + ctxt_window_time: Context window size for autoregressive mode. + For RNN predictors (is_rnn=True), this is effectively 1. + compute_loss: Whether to compute losses (requires ground truth observations) + return_all_steps: If True, return list of predictions at each step (like infern). + If False, return only the final predicted states. + + Returns: + Tuple of (predicted_states, losses) where: + - If return_all_steps=False: + predicted_states: [B, D, T_out, H', W'] - final predicted state sequence + - If return_all_steps=True: + predicted_states: List[Tensor] of length nsteps, each [B, D, T_out, H', W'] + - losses: None if compute_loss=False, otherwise tuple of 5 elements: + (total_loss, reg_loss, reg_loss_unweighted, reg_loss_dict, pred_loss) + """ state = self.encoder(observations) - context_length = self.predictor.context_length + context_length = getattr(self.predictor, "context_length", 0) - # VC loss - rloss, rloss_unweight, rloss_dict = self.regularizer(state, actions) + # Compute regularization loss if needed + if compute_loss: + rloss, rloss_unweight, rloss_dict = self.regularizer(state, actions) + ploss = 0.0 + else: + rloss = rloss_unweight = rloss_dict = ploss = None + # Encode actions if actions is not None: - actions = self.action_encoder(actions) + actions_encoded = self.action_encoder(actions) + else: + actions_encoded = None + + # Collect all steps if requested + all_steps = [] if return_all_steps else None - predi = state - ploss = 0.0 - if self.single_unroll: - curr_state = state[:, :, :1] # (b, d, h, w) + # Parallel mode: process all timesteps at once, refeed GT context + if unroll_mode == "parallel": + predicted_states = state + for _ in range(nsteps): + # Predict all timesteps, discard last (no target for it) + predicted_states = self.predictor(predicted_states, actions_encoded)[ + :, :, :-1 + ] + # Collect step if requested + if return_all_steps: + all_steps.append(predicted_states) + # Refeed ground truth context on the left + predicted_states = torch.cat( + (state[:, :, :context_length], predicted_states), dim=2 + ) + if compute_loss: + ploss += self.predcost(state, predicted_states) / nsteps + + # Autoregressive mode: step-by-step with sliding window + # Note: RNN predictors (is_rnn=True) are a special case with ctxt_window_time=1 + elif unroll_mode == "autoregressive": + if actions is not None and nsteps > actions.size(2): + raise ValueError( + f"nsteps ({nsteps}) larger than action sequence length ({actions.size(2)})" + ) + # For RNN predictors, force ctxt_window_time=1 + effective_ctxt_window = 1 if self.single_unroll else ctxt_window_time + + predicted_states = state[:, :, :effective_ctxt_window] for i in range(nsteps): - curr_action = actions[:, :, i : i + 1] - curr_state = self.predictor(curr_state, curr_action) - ploss += self.predcost(curr_state, state[:, :, i + 1 : i + 2]) / nsteps + # Take last ctxt_window_time states + context_states = predicted_states[:, :, -effective_ctxt_window:] + # Take corresponding actions + if actions_encoded is not None: + context_actions = actions_encoded[ + :, :, max(0, i + 1 - effective_ctxt_window) : i + 1 + ] + else: + context_actions = None + # Predict and take only last timestep + pred_step = self.predictor(context_states, context_actions)[:, :, -1:] + # Append prediction to sequence + predicted_states = torch.cat([predicted_states, pred_step], dim=2) + # Collect step if requested + if return_all_steps: + all_steps.append(predicted_states.clone()) + if compute_loss: + ploss += ( + self.predcost(pred_step, state[:, :, i + 1 : i + 2]) / nsteps + ) else: - predi = state # (b, d, t, h, w) - # If predictor treats timesteps as batch dimension, reshaping b t c h w -> (b t) c h w, - # Then receptive field of predictor is one timestep only, so it is time-causal. - for _ in range(nsteps): - # Discard latest timestep prediction since there is no - # visual embedding target for it - predi = self.predictor(predi, actions)[:, :, :-1] - # Refeed 1st context_length grountruth embedding timesteps on the left - # as context for the next call to the predictor - predi = torch.cat((state[:, :, :context_length], predi), dim=2) - ploss += self.predcost(state, predi) / nsteps - - # compute total loss here - loss = rloss + ploss - return loss, rloss, rloss_unweight, rloss_dict, ploss - - -################################################################ -# a container that contains a JEPA and a trainable prediction head. -# the prediction head can be used as a decoder: -# simply set the targets are identical to the observations. + raise ValueError(f"Unknown unroll_mode: {unroll_mode}") + + # Compute total loss and return + if compute_loss: + loss = rloss + ploss + losses = (loss, rloss, rloss_unweight, rloss_dict, ploss) + else: + losses = None + + # Return all steps or just final state + if return_all_steps: + return all_steps, losses + else: + return predicted_states, losses + + class JEPAProbe(nn.Module): + """JEPA with a trainable prediction head. The JEPA encoder is kept fixed.""" + def __init__(self, jepa, head, hcost): - """ - A JEPA probe that includes a prediction head that - can be trained supervised. - The JEPA is kept fixed. - """ + """Initialize with a frozen JEPA, prediction head, and head loss function.""" super().__init__() self.jepa = jepa - # prediction head for a supervised task self.head = head - # loss for the prediction head self.hcost = hcost - # encode a sequence through the JEPA - # run the encoded state through the head and - # return the result @torch.no_grad() def infer(self, observations): + """Encode observations through JEPA and apply the prediction head.""" state = self.jepa.encode(observations) return self.head(state) @@ -202,11 +229,9 @@ def apply_head(self, embeddings): """ return self.head(embeddings) - # forward to train the head def forward(self, observations, targets): + """Forward pass for training the head (JEPA encoder gradients are detached).""" with torch.no_grad(): state = self.jepa.encode(observations) - # run the prediction head, but do not - # backprop through the JEPA encoder output = self.head(state.detach()) return self.hcost(output, targets) diff --git a/eb_jepa/losses.py b/eb_jepa/losses.py index 36d9453..7e0d772 100644 --- a/eb_jepa/losses.py +++ b/eb_jepa/losses.py @@ -1,39 +1,22 @@ -###################################################### -# Loss functions to prevent JEPAs from collapsing - import torch import torch.nn as nn import torch.nn.functional as F -################################################################ -# Utilities - -# simple square loss def sq_loss(x, y, reduction="mean"): + """Simple square loss (MSE).""" return nn.functional.mse_loss(x, y, reduction=reduction) -###################################################### -# prediction loss functions - - -# compute square loss between two sequences -# represented as BFTHW tensors. -# a shift is necessary for the predictor cost to -# make sure the architecture is strictly causal. def square_cost_seq(state, predi): + """Square loss between two [B, C, T, H, W] sequences.""" return sq_loss(state, predi) -# square loss between BFTHW sequences. -# This is used primarily for the prediction loss module. class SquareLossSeq(nn.Module): + """Square loss over a sequence [B, C, T, H, W] (feature dim at dim 1).""" + def __init__(self, proj=None): - """ - Square cost over a sequence represented - as BFTHW. Assumes the T dimension is at dim 2 - """ super().__init__() self.proj = nn.Identity() if proj is None else proj @@ -43,137 +26,31 @@ def forward(self, state, predi): return square_cost_seq(state, predi) -###################################################### -# Variance-Covariance loss on a batch of samples - - -# VC cost function -# input x is a 2D tensor Samples*Features (i.e. (B*T*H*W)*F ) -# this computes two terms: -# 1. the absolute deviation of the mean of all the -# variables from zero. -# 2. the absolute deviation of the covariance matrix -# of x from the identity matrix. -# -# mcoeff is a scalar coefficient that attracts the means -# of the variables to zero. -# ccoeff is a Features*Features square matrix of -# coefficients for each term in the covariance matrix. -# c is the constant to which the variances (diagonal terms) are pinned. -def off_diagonal(x): - n, m = x.shape - assert n == m - return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() - - -def vc_cost_orig(x, std_coeff, cov_coeff): - """ - x: (B*T*H*W, F') - """ - x = x - x.mean(dim=0) - std_x = torch.sqrt(x.var(dim=0) + 0.0001) # (F',) - std_loss = torch.mean(F.relu(1 - std_x)) - cov_x = (x.T @ x) / (x.shape[0] - 1) # (F', F') - cov_loss = off_diagonal(cov_x).pow_(2).sum().div(x.shape[-1] ** 2 - x.shape[-1]) - loss = std_coeff * std_loss + cov_coeff * cov_loss - loss_dict = { - "std_loss": std_loss.item(), - "cov_loss": cov_loss.item(), - } - total_unweighted_loss = std_loss + cov_loss - return loss, total_unweighted_loss, loss_dict - - -# Class for VC loss. -# attracts the means of the variables to zero -# and the covariance matrix towards the identity. class VCLoss(nn.Module): + """Variance-Covariance loss attracting means to zero and covariance to identity.""" + def __init__(self, std_coeff, cov_coeff, proj=None): - """ - Variance-Covariance loss class - makes the means of the variables zero, and - makes the covariance matrix as close to identity as possible - """ super().__init__() self.std_coeff = std_coeff self.cov_coeff = cov_coeff self.proj = nn.Identity() if proj is None else proj + self.std_loss_fn = HingeStdLoss(std_margin=1.0) + self.cov_loss_fn = CovarianceLoss() def forward(self, x, actions=None): - # turn input into a samples*features 2D tensor. - # assumes feature dimension is dimension 1 (e.g. BFTHW) - # bs, f, t, h, w = x.shape - # fx = self.proj(x.permute(0, 2, 1, 3, 4).reshape(-1, f * h * w)) - x = x.transpose(0, 1).flatten(1).transpose(0, 1) # (B*T*H*W, F) - fx = self.proj(x) # (B*T*H*W, F') - return vc_cost_orig(fx, std_coeff=self.std_coeff, cov_coeff=self.cov_coeff) - - -###################################################### -# Contrastive loss on a batch of samples - + x = x.transpose(0, 1).flatten(1).transpose(0, 1) # [B*T*H*W, C] + fx = self.proj(x) # [B*T*H*W, C'] -def contrastive_cost( - x, temperature=0.1, negative_weight=1.0, subset_size=None, num_subsets=1 -): - """ - Contrastive loss function with efficient subset sampling - input x is a BFTHW array + std_loss = self.std_loss_fn(fx) + cov_loss = self.cov_loss_fn(fx) - This implements a contrastive objective that: - 1. Normalizes features to unit vectors - 2. Computes pairwise cosine similarities on random subsets - 3. Encourages samples to be diverse (pushes similar samples apart) - 4. Uses subset sampling for computational efficiency - - Args: - x: Input tensor of shape (B, F, T, H, W) - temperature: Temperature scaling parameter for similarities - negative_weight: Weight for negative pairs (diversity encouragement) - subset_size: Size of random subsets to sample (None means use all samples) - num_subsets: Number of random subsets to sample and average over - """ - # put feature dimension first - # turns x from BFTHW to FBTHW format - x = x.transpose(0, 1) - # flatten to samples*features 2D tensor - flattened_state = x.flatten(1).transpose(0, 1) # (B*T*H*W, F) - - num_samples = flattened_state.size(0) - - if subset_size is None or subset_size >= num_samples: - # Use all samples if subset_size is not specified or too large - subset_size = num_samples - num_subsets = 1 - - total_loss = 0.0 - - for _ in range(num_subsets): - # Randomly sample subset_size samples - indices = torch.randperm(num_samples, device=x.device)[:subset_size] - subset_samples = flattened_state[indices] - - # Normalize features to unit vectors - x_norm = F.normalize(subset_samples, p=2, dim=1) - - # Compute pairwise cosine similarities - similarities = torch.mm(x_norm, x_norm.t()) / temperature - - # Create mask to exclude diagonal (self-similarities) - mask = torch.eye(subset_size, device=x.device).bool() - - # Extract off-diagonal similarities (treating all as negative pairs for diversity) - off_diagonal_sims = similarities.masked_select(~mask) - - # Contrastive loss: encourage diversity by penalizing high similarities - # Using logsumexp for numerical stability - diversity_loss = torch.logsumexp(off_diagonal_sims, dim=0) - total_loss += diversity_loss - - # Average over subsets - total_loss = total_loss / num_subsets - - return negative_weight * total_loss + loss = self.std_coeff * std_loss + self.cov_coeff * cov_loss + total_unweighted_loss = std_loss + cov_loss + loss_dict = { + "std_loss": std_loss.item(), + "cov_loss": cov_loss.item(), + } + return loss, total_unweighted_loss, loss_dict class HingeStdLoss(torch.nn.Module): @@ -194,7 +71,7 @@ def __init__( def forward(self, x: torch.Tensor): """ Args: - x: Tensor[N, D] where N is number of samples, D is feature dimension + x: [N, D] where N is number of samples, D is feature dimension Returns: std_loss: Scalar tensor with the hinge loss on standard deviations """ @@ -205,32 +82,31 @@ def forward(self, x: torch.Tensor): class CovarianceLoss(torch.nn.Module): - def __init__(self, adjust_conv: bool = True): + def __init__(self): """ Penalizes off-diagonal elements of the covariance matrix to encourage feature decorrelation. - Args: - adjust_conv (bool, default=True): - If True, normalizes by (D - 1) where D is feature dimensionality. + + Normalizes by D * (D - 1) where D is feature dimensionality. """ super().__init__() - self.adjust_conv = adjust_conv + + def off_diagonal(self, x): + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() def forward(self, x: torch.Tensor): """ Args: - x: Tensor[N, D] where N is number of samples, D is feature dimension + x: [N, D] where N is number of samples, D is feature dimension """ batch_size = x.shape[0] num_features = x.shape[-1] x = x - x.mean(dim=0, keepdim=True) cov = (x.T @ x) / (batch_size - 1) # [D, D] # Calculate off-diagonal loss - diag_elements = torch.diag(cov).pow(2).sum() - cov_loss = (cov.pow(2).sum() - diag_elements) / num_features - - if self.adjust_conv: - cov_loss = cov_loss / (num_features - 1) + cov_loss = self.off_diagonal(cov).pow(2).mean() return cov_loss @@ -247,7 +123,7 @@ def __init__(self): def forward(self, x: torch.Tensor): """ Args: - x: Tensor[T, N, D] where T is time steps, N is batch size, D is feature dimension + x: [T, N, D] where T is time steps, N is batch size, D is feature dimension """ if x.shape[0] <= 1: return torch.tensor(0.0, device=x.device) @@ -268,8 +144,8 @@ def __init__(self, idm: nn.Module): def forward(self, x: torch.Tensor, actions: torch.Tensor): """ Args: - x: Tensor[T, B, D] - States across time steps - actions: Tensor[B, A, T] - Ground truth actions between consecutive states + x: [T, B, D] - States across time steps + actions: [B, A, T] - Ground truth actions between consecutive states """ if x.shape[0] <= 1 or actions is None: return torch.tensor(0.0, device=x.device) @@ -300,7 +176,6 @@ def __init__( idm_coeff: float = 0.0, idm: nn.Module = None, std_margin: float = 1, - adjust_conv: bool = True, first_t_only: bool = True, projector: nn.Module = None, spatial_as_samples: bool = False, @@ -323,7 +198,6 @@ def __init__( idm_coeff (float): Weight for inverse dynamics loss idm (nn.Module): Inverse dynamics model std_margin (float): Minimum desired std per feature - adjust_conv (bool): Normalize covariance loss by (D-1) first_t_only (bool): Use only first time slice for std/cov loss projector (nn.Module): Optional projection layer spatial_as_samples (bool): Treat spatial locations as samples @@ -344,32 +218,29 @@ def __init__( # Initialize individual loss components self.std_loss_fn = HingeStdLoss(std_margin=std_margin) - self.cov_loss_fn = CovarianceLoss(adjust_conv=adjust_conv) + self.cov_loss_fn = CovarianceLoss() self.sim_loss_fn = TemporalSimilarityLoss() self.idm_loss_fn = InverseDynamicsLoss(idm) if idm is not None else None def forward(self, x, actions=None): """ - x (Tensor[B, C, T, H, W]): - Input activations. Internally reshaped to either - (1, B, D) when `first_t_only=True` or ((TΒ·B), D) otherwise, - with D = CΒ·HΒ·W. - first_t_only (bool, default=True): - If True, only uses the first time‑slice for both std and cov loss; - if False, flattens all time‑slices into the batch dimension. + Args: + x: [B, C, T, H, W] - Input activations. Internally reshaped to either + [1, B, D] when first_t_only=True or [T*B, D] otherwise, with D=C*H*W. + actions: [B, A, T] - Optional actions for IDM loss """ b, c, t, h, w = x.shape # divergent gradient paths for x_unprojected and x_projected - x_unprojected = x.permute(2, 0, 1, 3, 4).reshape(t, b, -1) # [t, b, c*h*w] + x_unprojected = x.permute(2, 0, 1, 3, 4).reshape(t, b, -1) # [T, B, C*H*W] - x_flat = x.permute(0, 2, 3, 4, 1).reshape(-1, c) # [b*t*h*w, c] - x_proj = self.projector(x_flat) # [b*t*h*w, c_out] + x_flat = x.permute(0, 2, 3, 4, 1).reshape(-1, c) # [B*T*H*W, C] + x_proj = self.projector(x_flat) # [B*T*H*W, C_out] c_out = x_proj.shape[-1] - x_projected = x_proj.view(b, t, h, w, c_out) # [b, t, h, w, c_out] + x_projected = x_proj.view(b, t, h, w, c_out) # [B, T, H, W, C_out] x_projected_reshaped = x_projected.permute(2, 0, 1, 3, 4).reshape( t, b, -1 - ) # [t, b, c_out*h*w] + ) # [T, B, C_out*H*W] # SIM_T LOSS if self.sim_t_after_proj: @@ -388,29 +259,29 @@ def forward(self, x, actions=None): # STD and COV LOSS if self.spatial_as_samples: if self.first_t_only: - # Use only first time: [b*h*w, c_out] + # Use only first time: [B*H*W, C_out] x_for_vc = x_projected[:, 0].reshape(b * h * w, c_out) assert x_for_vc.shape == (b * h * w, c_out) else: - # Use all times: [b*t*h*w, c_out] + # Use all times: [B*T*H*W, C_out] x_for_vc = x_projected.reshape(-1, c_out) assert x_for_vc.shape == (b * t * h * w, c_out) else: x_for_vc = x_projected.permute(0, 1, 4, 2, 3).reshape( b, t, -1 - ) # [b, t, c_out*h*w] + ) # [B, T, C_out*H*W] if self.first_t_only: - # Use only first time: [b, c_out*h*w] + # Use only first time: [B, C_out*H*W] x_for_vc = x_for_vc[:, 0] assert x_for_vc.shape == (b, c_out * h * w) else: - # Use all times: [b*t, c_out*h*w] + # Use all times: [B*T, C_out*H*W] x_for_vc = x_for_vc.reshape(-1, x_for_vc.size(-1)) assert x_for_vc.shape == (b * t, c_out * h * w) - # [b*t, c_out*h*w] if first_t_only=False and spatial_as_samples=False - # or [b, c_out*h*w] if first_t_only=True and spatial_as_samples=False - # or [b*h*w, c_out] if first_t_only=True spatial_as_samples=True - # or [b*t*h*w, c_out] if first_t_only=False spatial_as_samples=True + # [B*T, C_out*H*W] if first_t_only=False and spatial_as_samples=False + # or [B, C_out*H*W] if first_t_only=True and spatial_as_samples=False + # or [B*H*W, C_out] if first_t_only=True spatial_as_samples=True + # or [B*T*H*W, C_out] if first_t_only=False spatial_as_samples=True std_loss = self.std_loss_fn(x_for_vc) cov_loss = self.cov_loss_fn(x_for_vc) @@ -432,118 +303,53 @@ def forward(self, x, actions=None): return total_weighted_loss, total_unweighted_loss, loss_dict -class ContrastiveLoss(nn.Module): - def __init__( - self, - temperature=0.1, - negative_weight=1.0, - proj=None, - subset_size=None, - num_subsets=1, - ): - """ - Contrastive loss class with efficient subset sampling - Encourages diversity among samples by penalizing high similarities +class VICRegLoss(nn.Module): + """VICReg loss combining invariance, variance (std), and covariance terms.""" - Args: - temperature: Temperature scaling parameter for similarities - negative_weight: Weight for the contrastive term - proj: Optional projection layer applied before computing loss - subset_size: Size of random subsets to sample (None means use all samples) - num_subsets: Number of random subsets to sample and average over - """ + def __init__(self, std_coeff=1.0, cov_coeff=1.0): super().__init__() - self.temperature = temperature - self.negative_weight = negative_weight - self.subset_size = subset_size - self.num_subsets = num_subsets - self.proj = nn.Identity() if proj is None else proj - - def forward(self, x): - # apply optional projection before computing contrastive cost - if self.proj is not None and not isinstance(self.proj, nn.Identity): - # apply projection to features: assumes feature dimension is dimension 1 (BFTHW) - fx = self.proj(x.transpose(0, 1).flatten(1).transpose(0, 1)) - # reshape back to BFTHW format - b, f, t, h, w = x.shape - projected_f = fx.size(1) # new feature dimension after projection - fx = fx.transpose(0, 1).view(projected_f, b, t, h, w).transpose(0, 1) - return contrastive_cost( - fx, - self.temperature, - self.negative_weight, - self.subset_size, - self.num_subsets, - ) - else: - return contrastive_cost( - x, - self.temperature, - self.negative_weight, - self.subset_size, - self.num_subsets, - ) - - -###################################################### -# VICReg loss function - + self.std_coeff = std_coeff + self.cov_coeff = cov_coeff + self.std_loss_fn = HingeStdLoss(std_margin=1.0) + self.cov_loss_fn = CovarianceLoss() -class VICRegLoss(nn.Module): - """VICReg loss module. - - Args: - sim_loss_weight: Weight for similarity (invariance) loss - var_loss_weight: Weight for variance loss - cov_loss_weight: Weight for covariance loss - """ - - def __init__(self, var_loss_weight=1.0, cov_loss_weight=1.0): - super().__init__() - self.var_loss_weight = var_loss_weight - self.cov_loss_weight = cov_loss_weight - def forward(self, z1, z2): """Compute VICReg loss. - + Args: - z1: First projection tensor (batch_size, features) - z2: Second projection tensor (batch_size, features) - + z1: [B, D] - First projection tensor + z2: [B, D] - Second projection tensor + Returns: - tuple: (total_loss, sim_loss, var_loss, cov_loss) + dict with keys: loss, invariance_loss, var_loss, cov_loss """ - batch_size = z1.size(0) - # Invariance loss (similarity) sim_loss = F.mse_loss(z1, z2) - - # Variance loss - z1_std = torch.sqrt(z1.var(dim=0) + 1e-4) - z2_std = torch.sqrt(z2.var(dim=0) + 1e-4) - var_loss = torch.mean(F.relu(1 - z1_std)) + torch.mean(F.relu(1 - z2_std)) - - # Covariance loss - z1_centered = z1 - z1.mean(dim=0) - z2_centered = z2 - z2.mean(dim=0) - z1_cov = torch.mm(z1_centered.T, z1_centered) / (batch_size - 1) - z2_cov = torch.mm(z2_centered.T, z2_centered) / (batch_size - 1) - - cov_loss = (z1_cov.pow(2).sum() - z1_cov.diagonal().pow(2).sum()) / (z1_cov.size(0)**2 - z1_cov.size(0)) + \ - (z2_cov.pow(2).sum() - z2_cov.diagonal().pow(2).sum()) / (z2_cov.size(0)**2 - z2_cov.size(0)) - - total_loss = sim_loss + self.var_loss_weight * var_loss + self.cov_loss_weight * cov_loss - - return {"loss": total_loss, "invariance_loss": sim_loss, "var_loss": var_loss, "cov_loss": cov_loss} + + # Variance loss (applied to both views and summed) + var_loss = self.std_loss_fn(z1) + self.std_loss_fn(z2) + + # Covariance loss (applied to both views and summed) + cov_loss = self.cov_loss_fn(z1) + self.cov_loss_fn(z2) + + total_loss = sim_loss + self.std_coeff * var_loss + self.cov_coeff * cov_loss + + return { + "loss": total_loss, + "invariance_loss": sim_loss, + "var_loss": var_loss, + "cov_loss": cov_loss, + } ###################################################### -# BCS (Batched Characteristic Slicing) loss for LE-JEPA +# BCS (Batched Characteristic Slicing) loss for SIGReg def all_reduce(x, op): """All-reduce operation for distributed training.""" import torch.distributed as dist + if dist.is_available() and dist.is_initialized(): op = dist.ReduceOp.__dict__[op] dist.all_reduce(x, op=op) @@ -569,8 +375,8 @@ def epps_pulley(x, t_min=-3, t_max=3, n_points=10): class BCS(nn.Module): - """BCS (Batched Characteristic Slicing) loss for LE-JEPA.""" - + """BCS (Batched Characteristic Slicing) loss for SIGReg.""" + def __init__(self, num_slices=256, lmbd=10.0): super().__init__() self.num_slices = num_slices @@ -593,4 +399,3 @@ def forward(self, z1, z2): invariance_loss = F.mse_loss(z1, z2).mean() total_loss = invariance_loss + self.lmbd * bcs return {"loss": total_loss, "bcs_loss": bcs, "invariance_loss": invariance_loss} - diff --git a/eb_jepa/nn_utils.py b/eb_jepa/nn_utils.py index 8666b83..e524d2a 100644 --- a/eb_jepa/nn_utils.py +++ b/eb_jepa/nn_utils.py @@ -15,7 +15,9 @@ def init_module_weights(m, std: float = 0.02): m: PyTorch module to initialize std: Standard deviation for truncated normal initialization (default: 0.02) """ - if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear)): + if isinstance( + m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear) + ): nn.init.trunc_normal_(m.weight, std=std) if m.bias is not None: nn.init.constant_(m.bias, 0) @@ -26,21 +28,21 @@ class TemporalBatchMixin: Mixin class that handles automatic temporal batching for 4D/5D tensors. This mixin provides a unified forward() method that: - - For 5D tensors (B, C, T, H, W): flattens temporal dim, applies _forward(), restores shape - - For 4D tensors (B, C, H, W): directly applies _forward() + - For 5D tensors [B, C, T, H, W]: flattens temporal dim, applies _forward(), restores shape + - For 4D tensors [B, C, H, W]: directly applies _forward() Subclasses must implement _forward(self, x) for 4D tensors. """ def _forward(self, x): """ - Process 4D tensor (B, C, H, W). Must be implemented by subclasses. + Process 4D tensor [B, C, H, W]. Must be implemented by subclasses. Args: - x: Input tensor of shape (B, C, H, W) + x: Input tensor of shape [B, C, H, W] Returns: - Output tensor of shape (B, C_out, H_out, W_out) + Output tensor of shape [B, C_out, H_out, W_out] """ raise NotImplementedError("Subclasses must implement _forward()") @@ -49,12 +51,15 @@ def forward(self, x): Forward pass supporting both 4D and 5D tensors. Args: - x: Input tensor of shape (B, C, H, W) or (B, C, T, H, W) + x: Input tensor of shape [B, C, H, W] or [B, C, T, H, W] Returns: Output tensor with same batch and temporal dimensions as input """ - assert x.ndim in [4, 5], "Supports only 4D (B,C,H,W) or 5D (B,C,T,H,W) tensors" + assert x.ndim in [ + 4, + 5, + ], "Supports only 4D [B, C, H, W] or 5D [B, C, T, H, W] tensors" if x.ndim == 5: b = x.shape[0] x = rearrange(x, "b c t h w -> (b t) c h w") diff --git a/eb_jepa/planning.py b/eb_jepa/planning.py index db3fe4a..ab74ecb 100644 --- a/eb_jepa/planning.py +++ b/eb_jepa/planning.py @@ -1,27 +1,27 @@ import os import time from abc import ABC, abstractmethod -from io import BytesIO from typing import Callable, List, NamedTuple, Optional -import cv2 -import imageio -import matplotlib.pyplot as plt import numpy as np import pandas as pd -import seaborn as sns import torch from einops import rearrange from omegaconf import OmegaConf -from PIL import Image from tqdm import tqdm from eb_jepa.logging import get_logger -from eb_jepa.visualize_samples import save_gif, save_gif_HWC, show_images, to3channels +from eb_jepa.vis_utils import ( + analyze_distances, + create_comparison_gif, + plot_losses, + save_decoded_frames, + save_gif, + show_images, +) logger = get_logger(__name__) -FIGSIZE_BASE = (4.0, 3.0) planner_name_map = { "cem": "CEMPlanner", "mppi": "MPPIPlanner", @@ -163,140 +163,6 @@ def main_unroll_eval( return results -def create_comparison_gif( - gt_seq, - pred_seq_true, - pred_seq_random, - gt_dec=None, - save_path="comparison.gif", - fps=15, -): - """ - Inputs: - - gt_seq: Ground truth sequence of shape (B, T, H, W, C), uint8, [0, 255] - - gt_dec: Decoded ground truth sequence of shape (B, T, H, W, C), uint8, [0, 255] - - pred_seq_true: Decoded predictions using true actions of shape (B, T, H, W, C), uint8, [0, 255] - - pred_seq_random: Decoded predictions using random actions of shape (B, T, H, W, C), uint8, [0, 255] - Create a three-column visualization: - - Left: Ground truth observations - - Middle: Decoded predictions using true actions - - Right: Decoded predictions using random actions - - Display min(batch_size, 4) rows of sequences. - """ - b = gt_seq.shape[0] - num_rows = min(b, 4) - - if gt_dec is not None: - seq_length = min( - gt_seq.shape[1], - gt_dec.shape[1], - pred_seq_true.shape[1], - pred_seq_random.shape[1], - ) - else: - seq_length = min( - gt_seq.shape[1], pred_seq_true.shape[1], pred_seq_random.shape[1] - ) - - img_height, img_width = gt_seq.shape[2], gt_seq.shape[3] - - # Determine number of columns (3 or 4 depending on if gt_dec is provided) - num_cols = 4 if gt_dec is not None else 3 - padding = 0 - title_height = 15 - - titles = ["GT"] - if gt_dec is not None: - titles.append("Dec GT") - titles.extend(["GT Act", "Rand Act"]) - - frames = [] - for t in range(seq_length): - # Create a black canvas - canvas_height = title_height + num_rows * (img_height + padding) + padding - canvas_width = num_cols * (img_width + padding) + padding - canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8) - - # Add column titles - for col, title in enumerate(titles): - col_x = padding + col * (img_width + padding) + img_width // 2 - # Get text size for proper centering - (text_width, _), _ = cv2.getTextSize( - title, cv2.FONT_HERSHEY_SIMPLEX, 0.3, 1 - ) - cv2.putText( - canvas, - title, - (col_x - text_width // 2, title_height - 5), - cv2.FONT_HERSHEY_SIMPLEX, - 0.3, - (255, 255, 255), - 1, - cv2.LINE_AA, - ) - for row in range(num_rows): - # Base y-coordinate for this row - base_y = title_height + padding + row * (img_height + padding) - - # Ground truth (first column) - gt_frame = to3channels(gt_seq[row, t]) # Shape should be (H, W, C) - canvas[base_y : base_y + img_height, padding : padding + img_width] = ( - gt_frame - ) - - col_idx = 1 - - # Decoded ground truth (optional second column) - if gt_dec is not None: - gt_dec_frame = to3channels(gt_dec[row, t]) - col_x = padding + col_idx * (img_width + padding) - canvas[base_y : base_y + img_height, col_x : col_x + img_width] = ( - gt_dec_frame - ) - col_idx += 1 - - # Prediction with true actions - pred_true_frame = to3channels(pred_seq_true[row, t]) - col_x = padding + col_idx * (img_width + padding) - canvas[base_y : base_y + img_height, col_x : col_x + img_width] = ( - pred_true_frame - ) - col_idx += 1 - - # Prediction with random actions - pred_random_frame = to3channels(pred_seq_random[row, t]) - col_x = padding + col_idx * (img_width + padding) - canvas[base_y : base_y + img_height, col_x : col_x + img_width] = ( - pred_random_frame - ) - - # Add timestep indicator in the bottom right corner - (text_width, text_height), _ = cv2.getTextSize( - f"t={t}", cv2.FONT_HERSHEY_SIMPLEX, 0.3, 1 - ) - text_x = canvas_width - text_width - 10 # 10px margin from right edge - text_y = canvas_height - 10 # 10px margin from bottom edge - cv2.putText( - canvas, - f"t={t}", - (text_x, text_y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.3, - (255, 255, 255), - 1, - cv2.LINE_AA, - ) - - frames.append(canvas) - - # Save as GIF - imageio.mimsave(save_path, frames, fps=fps, loop=0) - logger.info(f" βœ“ Saved comparison GIF: {os.path.basename(save_path)}") - - return frames - - ### Main planning eval loop ### def main_eval( plan_cfg, @@ -337,7 +203,6 @@ def main_eval( os.makedirs(ep_plan_vis_dir, exist_ok=True) if plan_cfg.task_specification.goal_source == "dset": - # TODO obs_slice, a, loc, _, _ = next(iter(loader)) # obs, init_loc = obs_slice[0], loc[0] # goal_img, goal_loc = obs_slice[-1], loc[-1] # [C, H, W] uint8 tensor @@ -431,23 +296,33 @@ def main_eval( successes.append(success) distances.append(state_dist) - coord_diffs, _repr_diffs = agent.analyze_distances( - episode_observations[-1], - episode_infos[-1], - str(ep_folder / "agent"), - ) - agent.plot_losses( - prev_losses, - prev_elite_losses_mean, - prev_elite_losses_std, - work_dir=ep_folder, - ) - save_path = f"{ep_folder}/agent_steps_{'succ' if success else 'fail'}.gif" + if plan_cfg.logging.get("optional_plots", True): + analyze_distances( + episode_observations[-1], + episode_infos[-1], + str(ep_folder / "agent"), + goal_position=agent.goal_position, + goal_state=agent.goal_state, + normalizer=agent.normalizer, + model=agent.model, + objective=agent.objective, + device=agent.device, + ) + plot_losses( + prev_losses, + prev_elite_losses_mean, + prev_elite_losses_std, + work_dir=ep_folder, + num_act_stepped=agent.num_act_stepped, + ) + save_path = f"{ep_folder}/agent_steps_{'succ' if success else 'fail'}.gif" save_gif( episode_observations[-1], save_path=save_path, show_frame_numbers=True, fps=20, + init_frame=observations[0], + goal_frame=goal_img, ) logger.info(f"GIF saved to {save_path}") episode_end_time = time.time() # Add this line @@ -533,31 +408,39 @@ def set_goal(self, goal_state, goal_position=None): def unroll(self, obs_init, actions, repeat_batch=True): """ - Called by self.planner.cost_function() - actions: B A T - obs_init: B C T H W + Unroll the model for planning. + + Args: + obs_init: [B, C, T, H, W] + actions: [B, A, T] + + Returns: + predicted_states: [B, D, T, H, W] """ batch_size = actions.shape[0] nsteps = actions.shape[2] if repeat_batch: obs_init = obs_init.repeat(batch_size, 1, 1, 1, 1) - # unroll_time_start = time.time() - predicted_states = self.model.unrolln( + predicted_states, _ = self.model.unroll( obs_init, actions, - nsteps, + nsteps=nsteps, + unroll_mode="autoregressive", ctxt_window_time=self.plan_cfg["ctxt_window_time"] if self.plan_cfg else 1, + compute_loss=False, + return_all_steps=False, ) - # logging.info(f"unroll time: {time.time() - unroll_time_start:.4f}s") return predicted_states def decode_loc_to_pixel(self, predicted_encs, wall_x=None, door_y=None): """ Decode the predicted encodings into frames. + Args: - predicted_encs: Tensor of shape (B, D, T, H, W) + predicted_encs: [B, D, T, H, W] + Returns: - np.array of shape (B, T, H, W, C) on cpu for visualization. + np.array of shape [B, T, H, W, C] on cpu for visualization. """ assert self.loc_prober is not None B, D, T, H, W = predicted_encs.shape @@ -580,137 +463,6 @@ def act(self, obs, steps_left=None, t0=False, plan_vis_path=None): self._prev_elite_losses_std = planning_result.prev_elite_losses_std return planning_result.actions[: self.num_act_stepped] # T, A - def plot_losses( - self, losses, elite_losses_mean, elite_losses_std, work_dir, frameskip=1 - ): - """ - Input: - prev_losses, List[Tensor, size= (n_opt_steps, n_losses)]. - For now, n_losses = 1. - """ - losses = torch.stack(losses, dim=0).detach().cpu().numpy() - elite_losses_mean = torch.stack(elite_losses_mean, dim=0).detach().cpu().numpy() - elite_losses_std = torch.stack(elite_losses_std, dim=0).detach().cpu().numpy() - n_timesteps, n_opt_steps, n_losses = losses.shape - sns.set_theme() - for i in range(n_losses): - total_plots = min(16, n_timesteps) - rows = 1 - cols = int(np.ceil(total_plots / rows)) - fig_width = FIGSIZE_BASE[0] * cols - fig_height = FIGSIZE_BASE[1] * rows - plt.figure(figsize=(fig_width, fig_height), dpi=300) - steps = np.linspace(0, n_timesteps - 1, total_plots, dtype=int) - for j, step in enumerate(steps): - ax = plt.subplot(rows, cols, j + 1) - if n_opt_steps > 1: - sns.lineplot(data=losses[step, :, i]) - sns.lineplot(data=elite_losses_mean[step, :, i]) - ax.fill_between( - range(n_opt_steps), - elite_losses_mean[step, :, i] - elite_losses_std[step, :, i], - elite_losses_mean[step, :, i] + elite_losses_std[step, :, i], - alpha=0.3, - ) - else: - ax.bar( - 0, losses[step, 0, i] - ) # Plot a bar chart if only one opt step - ax.bar(0, elite_losses_mean[step, 0, i]) - ax.errorbar( - 0, - elite_losses_mean[step, 0, i], - yerr=elite_losses_std[step, 0, i], - fmt="none", - capsize=5, - ) - ax.set_title(f"Episode step {step * frameskip * self.num_act_stepped}") - ax.tick_params(axis="both") - ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True)) - plt.tight_layout() - plt.savefig(work_dir / f"losses_{i}.pdf", bbox_inches="tight") - plt.close() - - def analyze_distances( - self, - obses, - infos, - plot_prefix, - ): - """ - Input: - obses: Tensor: [B, c, h, w] with B = env.max_episode_steps + 1 - """ - # TODO: have a more general signal than dot_position, which is specific to dot envs, - # called proprioception, allowed by an - # env wrapper wrapping all possible envs. - coords = torch.stack([x["dot_position"] for x in infos]).unsqueeze(1) # B 1 c - distances = ( - torch.norm( - coords[..., -1, :3] - self.goal_position[:3].unsqueeze(0), dim=-1 - ) - .detach() - .cpu() - ) - sns.set_theme() - FIGSIZE = (4.0, 3.0) - self.plot_distances(distances, plot_prefix + "_distances.pdf", figsize=FIGSIZE) - - # Normalizer takes [.. c h w] - all_states = ( - self.normalizer.normalize_state( - torch.cat([obses, self.goal_state.unsqueeze(0)]) - ) - .unsqueeze(-3) - .to(self.device) - ) # B c 1 h w - # The encoder takes batch of single states, of dim [len_ep, C, H, W] so no temporal dependency - all_encs = self.model.encode(all_states) # B d 1 h w - diffs = self.compute_embed_differences(all_encs).detach().cpu() - self.plot_distances( - diffs, - plot_prefix + "_rep_distance_visual.pdf", - figsize=FIGSIZE, - xlabel="Timesteps", - ylabel="Rep distance to goal", - ) - - all_encs_excluded = all_encs[:-1] - all_objectives = self.objective(all_encs_excluded).detach().cpu() - self.plot_distances( - all_objectives, - plot_prefix + "_objectives.pdf", - figsize=FIGSIZE, - xlabel="Timesteps", - ylabel="Objective values", - ) - - return distances, diffs - - def plot_distances( - self, - data, - plot_prefix="", - figsize=(4.0, 3.0), - xlabel="Timesteps", - ylabel="Distance to goal", - ): - plt.figure(figsize=figsize, dpi=300) - sns.lineplot(data=data) - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.tight_layout() - plt.savefig(plot_prefix, bbox_inches="tight") - plt.close() - - def compute_embed_differences(self, all_encs): - """ - Input: all_encs: - visual: [T D 1 H W] - """ - sq_diff = (all_encs[:-1] - all_encs[-1:]) ** 2 - return sq_diff.mean(dim=tuple(range(1, all_encs.ndim))) - ### Planning objectives to minimize ### class ReprTargetDistMPCObjective: @@ -728,10 +480,11 @@ def __init__( def __call__(self, encodings: torch.Tensor, keepdims: bool = False) -> torch.Tensor: """ Args: - encodings: tensor (B D T H W) - target_enc: tensor (B D T H W) + encodings: [B, D, T, H, W] + keepdims: if True, return [B, T], else return [B] + Returns: - diff: tensor, (B T) or (B) if sum_all_diffs or not keepdims + diff: [B, T] else [B] if sum_all_diffs or not keepdims """ if self.sum_all_diffs: keepdims = True @@ -781,117 +534,6 @@ def cost_function( predicted_encs = self.unroll(obs_init, actions) return self.objective(predicted_encs) - def save_decoded_frames( - self, pred_frames_over_iterations, costs, plan_vis_path, overlay=True - ): - # costs: List[float] of length iterations - # pred_frames_over_iterations: List[(T, H, W, C)] of length iterations - if pred_frames_over_iterations is not None and plan_vis_path is not None: - pass - - frames = [] - global_min_cost = np.min(costs) - global_max_cost = np.max(costs) - # Pre-calculate the normalized positions for all costs - all_normalized_costs = [] - for i in range(len(costs)): - # For each iteration, normalize all costs seen so far - current_costs = costs[: i + 1] - if len(current_costs) > 1: - # Normalize using global min/max for consistent scaling - normalized = (current_costs - global_min_cost) / ( - global_max_cost - global_min_cost + 1e-10 - ) - all_normalized_costs.append(normalized) - else: - all_normalized_costs.append( - np.array([0.5]) - ) # Default for single value - - for i, pred_frames in enumerate(pred_frames_over_iterations): - # pred_frames.shape: (T, H, W, C) - if overlay: - overlay_frames = [] - for frame_idx, frame in enumerate(pred_frames): - # Create a copy of the frame to draw on - frame_with_overlay = frame.copy() - - # Get frame dimensions - h, w = frame.shape[0], frame.shape[1] - - # Calculate scale factors for dimensions - scale_factor = min(h, w) / 1000 # Base scale on 500px reference - font_scale = max(0.3, scale_factor * 0.5) - line_thickness = max(1, int(scale_factor)) - margin = int(h * 0.02) # 2% of height - - # Get normalized costs for this iteration - current_costs = costs[: i + 1] - if len(current_costs) > 1: - normalized_costs = all_normalized_costs[i] - # Map to pixel space (top is low cost, bottom is high cost) - top_margin = int(h * 0.05) # 5% from top - bottom_margin = int(h * 0.05) # 5% from bottom - y_positions = (1 - normalized_costs) * ( - h - top_margin - bottom_margin - ) + top_margin - - # Add text showing the iteration number - cv2.putText( - frame_with_overlay, - f"Iter {i+1}", - (margin, margin + int(h * 0.1)), - cv2.FONT_HERSHEY_SIMPLEX, - font_scale, - (200, 200, 200), - line_thickness, - ) - - overlay_frames.append(frame_with_overlay) - frames.extend(overlay_frames) - else: - plt.clf() - plt.figure(figsize=(10, 10)) - plt.plot(costs[: i + 1]) - plt.title(f"Iteration {i}") - plt.xlabel("Iteration") - plt.ylabel("Loss") - plt.xlim(0, len(costs)) - plt.ylim(min(costs), max(costs)) - buf = BytesIO() - plt.savefig(buf, format="png") - buf.seek(0) - img = Image.open(buf) - img = np.array(img) - img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - img = cv2.resize(img, (256, 256)) - - combined_frames = [] - for frame in pred_frames: - frame = cv2.resize(frame, (256, 256)) - combined_frame = np.concatenate((img, frame), axis=1) - combined_frames.append(combined_frame) - frames.extend(combined_frames) - # Save GIF - filename = f"{plan_vis_path}.gif" - save_gif_HWC(frames, filename, fps=30) - logger.info(f"Plan decoding video saved to {plan_vis_path}") - - # Save last iteration frames as PDF - last_pred_frames = pred_frames_over_iterations[-1] - pdf_filename = f"{plan_vis_path}_last_frames.pdf" - n_frames = len(last_pred_frames) - show_images( - last_pred_frames.transpose(0, 3, 1, 2), - nrow=n_frames, - titles=None, - save_path=pdf_filename, - close_fig=True, - first_channel_only=False, - clamp=False, - ) - logger.info(f"Last iteration frames saved to {pdf_filename}") - ### Specific planning optimizers ### class CEMPlanner(Planner): @@ -1003,7 +645,7 @@ def plan( pred_frames_over_iterations.append(pred_frames.squeeze(0)) # [T H W 3]: uint 8 in [0, 255] if self.decode_each_iteration: - self.save_decoded_frames(pred_frames_over_iterations, losses, plan_vis_path) + save_decoded_frames(pred_frames_over_iterations, losses, plan_vis_path) # Return the first action(s) a = mean @@ -1025,7 +667,6 @@ def __init__( plan_length: int = 15, action_dim: int = 2, max_std: float = 2, - min_std: float = 0.05, num_elites: int = 64, temperature: float = 0.005, max_norms: Optional[List[float]] = None, @@ -1041,7 +682,6 @@ def __init__( self.action_dim = action_dim self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.max_std = max_std - self.min_std = min_std self.num_elites = num_elites self.temperature = temperature self.max_norms = max_norms @@ -1137,7 +777,7 @@ def plan( pred_frames_over_iterations.append(pred_frames.squeeze(0)) # [T H W 3]: uint 8 in [0, 255] if self.decode_each_iteration: - self.save_decoded_frames(pred_frames_over_iterations, losses, plan_vis_path) + save_decoded_frames(pred_frames_over_iterations, losses, plan_vis_path) # Select action score = score.cpu().numpy() actions = elite_actions[ diff --git a/eb_jepa/schedulers.py b/eb_jepa/schedulers.py index 9229cfb..c4d681a 100644 --- a/eb_jepa/schedulers.py +++ b/eb_jepa/schedulers.py @@ -1,8 +1,7 @@ -import torch from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR -class CosineWithWarmup(torch.optim.lr_scheduler._LRScheduler): +class CosineWithWarmup: """ A learning rate scheduler that combines linear warmup followed by cosine annealing. @@ -37,10 +36,15 @@ def __init__( self.scheduler = SequentialLR( optimizer, schedulers=[warmup, cosine], milestones=[self.warmup_steps] ) - super().__init__(optimizer, last_epoch) def step(self): self.scheduler.step() def get_last_lr(self): return self.scheduler.get_last_lr() + + def state_dict(self): + return self.scheduler.state_dict() + + def load_state_dict(self, state_dict): + self.scheduler.load_state_dict(state_dict) diff --git a/examples/ac_video_jepa/heads.py b/eb_jepa/state_decoder.py similarity index 56% rename from examples/ac_video_jepa/heads.py rename to eb_jepa/state_decoder.py index 3806de8..2f7e862 100644 --- a/examples/ac_video_jepa/heads.py +++ b/eb_jepa/state_decoder.py @@ -2,9 +2,7 @@ class MLPXYHead(nn.Module): - """ - A head to recover the xy location from features - """ + """A head to recover the xy location from features.""" def __init__(self, input_shape, normalizer=None): # input_shape = (C, H, W) super().__init__() @@ -15,18 +13,17 @@ def __init__(self, input_shape, normalizer=None): # input_shape = (C, H, W) def forward(self, x): """ - Input: - x: (bs, c, t, h, w) - Output: - pred: (bs, 2, t) + Args: + x: [B, C, T, H, W] + Returns: + pred: [B, 2, T] """ bs, c, t, h, w = x.shape - # (bs, c, t, 1, 1) --> (bs * t, c, 1, 1) - x = x.permute(0, 2, 1, 3, 4) # (bs, t, c, 1, 1) - x = x.reshape(bs * t, c, h, w) # (bs * t, c, 1, 1) + x = x.permute(0, 2, 1, 3, 4) # [B, T, C, H, W] + x = x.reshape(bs * t, c, h, w) # [B*T, C, H, W] - x = x.squeeze(-1).squeeze(-1) # (bs * t, c, 1, 1) --> (bs * t, c) + x = x.squeeze(-1).squeeze(-1) # [B*T, C] pred = self.mlp(x) diff --git a/eb_jepa/training_utils.py b/eb_jepa/training_utils.py new file mode 100644 index 0000000..81a33b0 --- /dev/null +++ b/eb_jepa/training_utils.py @@ -0,0 +1,419 @@ +import os +import random +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from omegaconf import DictConfig, OmegaConf + +from eb_jepa.logging import get_logger + +logger = get_logger(__name__) + + +def setup_device(device: str = "auto") -> torch.device: + """Set up the compute device. Options: 'auto', 'cuda', or 'cpu'.""" + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + logger.info(f"Using device: {device}") + return device + + +def setup_seed(seed: int) -> None: + """Set random seeds for Python, NumPy, and PyTorch for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + logger.info(f"Random seed set to {seed}") + + +def setup_wandb( + project: str, + config: Union[Dict, DictConfig], + run_dir: Union[str, Path], + run_name: Optional[str] = None, + resume: bool = True, + tags: Optional[List[str]] = None, + group: Optional[str] = None, + enabled: bool = True, + sweep_id: Optional[str] = None, +): + """Initialize W&B with safe resume (preserves existing run metadata on resume).""" + # Respect WANDB_DISABLED environment variable (used by wandb itself) + if os.environ.get("WANDB_DISABLED", "").lower() in ("true", "1", "yes"): + logger.info("W&B logging disabled via WANDB_DISABLED environment variable") + return None + + if not enabled: + logger.info("W&B logging disabled") + return None + + import wandb + + run_dir = Path(run_dir) + run_dir.mkdir(parents=True, exist_ok=True) + run_id_file = run_dir / "wandb_run_id.txt" + + # Convert OmegaConf to dict if needed + if isinstance(config, DictConfig): + config = OmegaConf.to_container(config, resolve=True) + + # Handle wandb sweep registration via environment variables + # This is how wandb associates runs with sweeps + if sweep_id: + os.environ["WANDB_SWEEP_ID"] = sweep_id + logger.info(f"Registering run with wandb sweep: {sweep_id}") + if tags: + tags = list(tags) + [f"sweep_{sweep_id}"] + else: + tags = [f"sweep_{sweep_id}"] + + # Check if we should resume an existing run + if resume and run_id_file.exists(): + with open(run_id_file, "r") as f: + existing_run_id = f.read().strip() + + # For sweep runs, use environment variables for resume + if sweep_id: + os.environ["WANDB_RUN_ID"] = existing_run_id + os.environ["WANDB_RESUME"] = "allow" + wandb_config = { + "project": project, + "dir": str(run_dir), + "config": config, + } + if run_name: + wandb_config["name"] = run_name + if tags: + wandb_config["tags"] = tags + if group: + wandb_config["group"] = group + run = wandb.init(**wandb_config) + logger.info(f"Resumed W&B run: {existing_run_id} in sweep {sweep_id}") + return run + + # SAFE RESUME: Only pass id and resume flag - do NOT pass name/config/tags + # This prevents overwriting existing run metadata on W&B + wandb_config = { + "project": project, + "dir": str(run_dir), + "id": existing_run_id, + "resume": "must", # "must" = fail if run doesn't exist (safer than "allow") + } + if group: + wandb_config["group"] = group + + try: + run = wandb.init(**wandb_config) + logger.info( + f"Resumed W&B run: {existing_run_id} (existing config preserved)" + ) + return run + except wandb.errors.UsageError: + # Run doesn't exist anymore on W&B, create new one + logger.warning(f"W&B run {existing_run_id} not found, creating new run") + run_id_file.unlink() # Remove stale run ID file + + # NEW RUN: Pass all configuration + wandb_config = { + "project": project, + "dir": str(run_dir), + "config": config, + } + if run_name: + wandb_config["name"] = run_name + if tags: + wandb_config["tags"] = tags + if group: + wandb_config["group"] = group + + run = wandb.init(**wandb_config) + with open(run_id_file, "w") as f: + f.write(run.id) + logger.info(f"Created W&B run: {run.id}") + + return run + + +def save_checkpoint( + path: Union[str, Path], + model: nn.Module, + optimizer: Optional[optim.Optimizer] = None, + scheduler: Optional[Any] = None, + epoch: int = 0, + step: int = 0, + scaler: Optional[Any] = None, + **extra_state, +) -> None: + """Save a training checkpoint (model, optimizer, scheduler, scaler, extra_state).""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + checkpoint = { + "epoch": epoch, + "step": step, + "model_state_dict": model.state_dict(), + } + + if optimizer is not None: + checkpoint["optimizer_state_dict"] = optimizer.state_dict() + if scheduler is not None: + checkpoint["scheduler_state_dict"] = scheduler.state_dict() + if scaler is not None: + checkpoint["scaler_state_dict"] = scaler.state_dict() + + checkpoint.update(extra_state) + + torch.save(checkpoint, path) + logger.info(f"Saved checkpoint: {path}") + + +def load_checkpoint( + path: Union[str, Path], + model: nn.Module, + optimizer: Optional[optim.Optimizer] = None, + scheduler: Optional[Any] = None, + scaler: Optional[Any] = None, + device: Optional[torch.device] = None, + strict: bool = True, +) -> Dict[str, Any]: + """Load a training checkpoint. Returns dict with epoch, step, and extra_state. + + The returned 'epoch' is the epoch to resume training from (0-indexed). + If no checkpoint exists, returns epoch=0 to start fresh. + If a checkpoint exists with epoch=N, returns epoch=N+1 to resume from the next epoch. + """ + path = Path(path) + if not path.exists(): + logger.warning(f"Checkpoint not found: {path}") + return {"epoch": 0, "step": 0, "resumed": False} + + map_location = device if device else "cpu" + checkpoint = torch.load(path, map_location=map_location, weights_only=False) + + # Handle compiled model state dicts + state_dict = checkpoint.get("model_state_dict", {}) + state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()} + + model.load_state_dict(state_dict, strict=strict) + logger.info(f"Loaded model state from: {path}") + + if optimizer is not None and "optimizer_state_dict" in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + logger.info("Restored optimizer state") + + if scheduler is not None and "scheduler_state_dict" in checkpoint: + scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + logger.info("Restored scheduler state") + + if scaler is not None and "scaler_state_dict" in checkpoint: + scaler.load_state_dict(checkpoint["scaler_state_dict"]) + logger.info("Restored scaler state") + + return { + "epoch": checkpoint.get("epoch", 0) + 1, # Resume from next epoch + "step": checkpoint.get("step", 0), + "resumed": True, + **{ + k: v + for k, v in checkpoint.items() + if k + not in [ + "model_state_dict", + "optimizer_state_dict", + "scheduler_state_dict", + "scaler_state_dict", + "epoch", + "step", + ] + }, + } + + +def load_config( + config_path: Union[str, Path], + cli_overrides: Optional[Dict[str, Any]] = None, + quiet: bool = False, +) -> DictConfig: + """Load YAML config with optional dot-notation overrides (e.g., 'model.lr': 0.001).""" + config_path = Path(config_path) + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + cfg = OmegaConf.load(config_path) + if not quiet: + logger.info(f"Loaded config from {config_path}") + + if cli_overrides: + # Convert dot notation to nested dict + override_dict = {} + for key, value in cli_overrides.items(): + keys = key.split(".") + current = override_dict + for k in keys[:-1]: + current = current.setdefault(k, {}) + current[keys[-1]] = value + + cfg = OmegaConf.merge(cfg, OmegaConf.create(override_dict)) + if not quiet: + logger.info(f"Applied {len(cli_overrides)} config overrides") + + return cfg + + +def get_checkpoints_dir() -> Path: + """Get the base checkpoints directory from EBJEPA_CKPTS env variable.""" + return Path(os.environ.get("EBJEPA_CKPTS", "checkpoints")) + + +def get_unified_experiment_dir( + example_name: str, + sweep_name: str, + exp_name: str, + seed: int, + base_dir: Union[str, Path, None] = None, + create: bool = True, +) -> Path: + """Create experiment dir: {base_dir}/{example_name}/{sweep_name}/{exp_name}_seed{seed}.""" + if base_dir is None: + base_dir = get_checkpoints_dir() + + # Convert to absolute path to avoid issues when cwd changes (e.g., after os.chdir) + exp_dir = ( + Path(base_dir) / example_name / sweep_name / f"{exp_name}_seed{seed}" + ).absolute() + + if create: + exp_dir.mkdir(parents=True, exist_ok=True) + + return exp_dir + + +def get_default_sweep_name() -> str: + return datetime.now().strftime("sweep_%Y-%m-%d_%H-%M") + + +def get_default_dev_name() -> str: + return datetime.now().strftime("dev_%Y-%m-%d_%H-%M") + + +def get_exp_name(example_name: str, cfg) -> str: + """Get short experiment name encoding key hyperparameters (seed appended separately).""" + if example_name == "image_jepa": + proj = "proj" if cfg.model.use_projector else "noproj" + parts = [ + cfg.model.type, + cfg.loss.type, + proj, + f"bs{cfg.data.batch_size}", + f"ep{cfg.optim.epochs}", + ] + if cfg.model.use_projector: + parts.append(f"ph{cfg.model.proj_hidden_dim}") + parts.append(f"po{cfg.model.proj_output_dim}") + if cfg.loss.type == "vicreg": + parts.append(f"std{cfg.loss.std_coeff}") + parts.append(f"cov{cfg.loss.cov_coeff}") + elif cfg.loss.type == "bcs": + parts.append(f"lmbd{cfg.loss.lmbd}") + return "_".join(str(p) for p in parts) + elif example_name == "video_jepa": + return ( + f"resnet_bs{cfg.data.batch_size}" + f"_lr{cfg.optim.lr}" + f"_std{cfg.loss.std_coeff}" + f"_cov{cfg.loss.cov_coeff}" + ) + elif example_name == "ac_video_jepa": + return ( + f"{cfg.model.encoder_architecture}" + f"_cov{cfg.model.regularizer.cov_coeff}" + f"_std{cfg.model.regularizer.std_coeff}" + f"_simt{cfg.model.regularizer.get('sim_coeff_t')}" + f"_idm{cfg.model.regularizer.get('idm_coeff')}" + ) + else: + return "exp" + + +def format_metrics(metrics: Dict[str, float], precision: int = 4) -> str: + """Format metrics dict as 'loss=0.1234 | acc=95.12'.""" + parts = [] + for k, v in metrics.items(): + if isinstance(v, float): + parts.append(f"{k}={v:.{precision}f}") + else: + parts.append(f"{k}={v}") + return " | ".join(parts) + + +def log_epoch( + epoch: int, + metrics: Dict[str, float], + total_epochs: Optional[int] = None, + elapsed_time: Optional[float] = None, +) -> None: + """Log epoch summary: πŸ“Š [Epoch 001/100] metric1=val1 | metric2=val2 | time=123.4s.""" + if total_epochs: + prefix = f"[Epoch {epoch:03d}/{total_epochs}]" + else: + prefix = f"[Epoch {epoch:03d}]" + + metrics_str = format_metrics(metrics) + + if elapsed_time is not None: + logger.info(f"πŸ“Š {prefix} {metrics_str} | time={elapsed_time:.1f}s") + else: + logger.info(f"πŸ“Š {prefix} {metrics_str}") + + +def log_model_info(model: nn.Module, param_counts: Dict[str, int]) -> None: + """Log model structure and parameter counts.""" + logger.info(f"🧠 Model:\n{model}") + param_str = " | ".join(f"{k}={v:,}" for k, v in param_counts.items()) + logger.info(f"πŸ”’ Parameters: {param_str}") + + +def log_data_info( + dataset_name: str, + num_batches: int, + batch_size: int, + train_samples: Optional[int] = None, + val_samples: Optional[int] = None, +) -> None: + """Log dataset information.""" + if train_samples is not None and val_samples is not None: + logger.info( + f"πŸ“¦ Data: {dataset_name} | {num_batches} batches x {batch_size} samples | " + f"train={train_samples:,} | val={val_samples:,}" + ) + else: + logger.info( + f"πŸ“¦ Data: {dataset_name} | {num_batches} batches x {batch_size} samples" + ) + + +def log_config(cfg: Union[Dict, DictConfig], title: str = "Run Configuration") -> None: + """Log configuration in a readable format.""" + logger.info("=" * 60) + logger.info(f"βš™οΈ {title}:") + logger.info("=" * 60) + + if isinstance(cfg, DictConfig): + cfg = OmegaConf.to_container(cfg, resolve=True) + + for section, values in cfg.items(): + if isinstance(values, dict): + for key, value in values.items(): + logger.info(f" {section}.{key}={value}") + else: + logger.info(f" {section}={values}") + logger.info("=" * 60) diff --git a/eb_jepa/vis_utils.py b/eb_jepa/vis_utils.py new file mode 100644 index 0000000..293c87f --- /dev/null +++ b/eb_jepa/vis_utils.py @@ -0,0 +1,776 @@ +import os +from typing import List, Optional, Union + +import cv2 +import imageio +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +import torch + +from eb_jepa.logging import get_logger + +logger = get_logger(__name__) + +FIGSIZE_BASE = (4.0, 3.0) + + +# ============================================================================= +# Frame Processing Primitives +# ============================================================================= + + +def to_numpy(frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray: + """Convert a tensor or array to numpy array.""" + if isinstance(frame, torch.Tensor): + return frame.detach().cpu().numpy() + return np.asarray(frame) + + +def to_uint8(frame: np.ndarray) -> np.ndarray: + """Convert frame to uint8, handling both [0,1] and [0,255] ranges.""" + if frame.dtype == np.uint8: + return frame + if frame.max() <= 1.0: + return (frame * 255).astype(np.uint8) + return frame.astype(np.uint8) + + +def to_hwc(frame: np.ndarray) -> np.ndarray: + """Convert frame from (C, H, W) to (H, W, C) format if needed.""" + if ( + frame.ndim == 3 + and frame.shape[0] in [1, 2, 3] + and frame.shape[0] < frame.shape[1] + ): + return frame.transpose(1, 2, 0) + return frame + + +def expand_channels(frame: np.ndarray, target_channels: int = 3) -> np.ndarray: + """Expand frame to target number of channels (default 3 for RGB).""" + if frame.ndim == 2: # Grayscale (H, W) + return np.stack([frame] * target_channels, axis=-1) + if frame.ndim == 3 and frame.shape[-1] < target_channels: + h, w, c = frame.shape + expanded = np.zeros((h, w, target_channels), dtype=frame.dtype) + expanded[..., :c] = frame + return expanded + return frame + + +def prepare_frame(frame: Union[torch.Tensor, np.ndarray, None]) -> Optional[np.ndarray]: + """Convert any frame format to numpy uint8 (H, W, C) with 3 channels.""" + if frame is None: + return None + frame = to_numpy(frame) + frame = to_hwc(frame) + frame = to_uint8(frame) + frame = expand_channels(frame) + return frame + + +def add_border( + frame: np.ndarray, color: tuple = (255, 0, 0), width: int = 2 +) -> np.ndarray: + """Add a colored border around a frame.""" + bordered = frame.copy() + bordered[:width, :] = color + bordered[-width:, :] = color + bordered[:, :width] = color + bordered[:, -width:] = color + return bordered + + +def add_text_overlay( + frame: np.ndarray, + text: str, + position: str = "top_right", + color: tuple = (255, 255, 255), + font_scale: float = None, + thickness: int = None, +) -> np.ndarray: + """Add text overlay to a frame with auto-scaled font size.""" + h, w = frame.shape[:2] + scale_factor = min(h, w) / 1000 + if font_scale is None: + font_scale = max(0.2, scale_factor * 0.5) + if thickness is None: + thickness = max(1, int(scale_factor)) + margin = int(h * 0.02) + + (text_width, text_height), _ = cv2.getTextSize( + text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness + ) + + if position == "top_right": + text_x = w - text_width - margin + text_y = text_height + margin + elif position == "top_left": + text_x = margin + text_y = text_height + margin + elif position == "bottom_right": + text_x = w - text_width - margin + text_y = h - margin + else: # bottom_left + text_x = margin + text_y = h - margin + + cv2.putText( + frame, + text, + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + color, + thickness, + cv2.LINE_AA, + ) + return frame + + +# ============================================================================= +# Sequence Processing +# ============================================================================= + + +def frames_to_list(frames) -> List[List[np.ndarray]]: + """ + Convert various frame input formats to List[List[np.ndarray]]. + + Supported inputs: + - List of (H, W, C) arrays -> single sequence + - (T, H, W, C) array -> single sequence + - (B, T, H, W, C) array -> B sequences + - List of Lists -> B sequences + """ + if isinstance(frames, np.ndarray): + if frames.ndim == 4: # (T, H, W, C) + return [list(frames)] + elif frames.ndim == 5: # (B, T, H, W, C) + return [list(frames[b]) for b in range(frames.shape[0])] + raise ValueError(f"Unsupported array shape: {frames.shape}") + + if isinstance(frames, list): + if len(frames) == 0: + raise ValueError("Empty frames list") + first = frames[0] + if isinstance(first, (np.ndarray, torch.Tensor)): + first_np = to_numpy(first) if isinstance(first, torch.Tensor) else first + if first_np.ndim == 3: # List of (H, W, C) + return [ + [to_numpy(f) if isinstance(f, torch.Tensor) else f for f in frames] + ] + elif first_np.ndim == 4: # List of (T, H, W, C) + return [list(seq) for seq in frames] + raise ValueError(f"Unsupported frame shape: {first_np.shape}") + elif isinstance(first, list): + return frames + raise ValueError(f"Unsupported frame type: {type(first)}") + + raise ValueError(f"Unsupported frames type: {type(frames)}") + + +def select_frame_indices( + total: int, num_frames: int = None, indices: List[int] = None +) -> List[int]: + """Select evenly-spaced frame indices or use provided indices.""" + if indices is not None: + return list(indices) + if num_frames is None or num_frames >= total: + return list(range(total)) + return np.linspace(0, total - 1, num_frames, dtype=int).tolist() + + +# ============================================================================= +# GIF/Video Saving +# ============================================================================= + + +def save_gif( + tensor: torch.Tensor, + save_path: str, + fps: int = 10, + show_frame_numbers: bool = False, + init_frame=None, + goal_frame=None, + upscale_factor: int = 2, +): + """ + Save a (T, C, H, W) tensor as GIF and PDF with horizontal unrolling. + + Args: + tensor: Tensor of shape (T, C, H, W) uint8 + save_path: Path to save the GIF file + fps: Frames per second for the GIF + show_frame_numbers: Whether to overlay frame numbers on GIF frames + init_frame: Optional initial state frame for PDF (with red border) + goal_frame: Optional goal state frame for PDF (with red border) + upscale_factor: Factor to upscale frames for better text readability in GIF + """ + total_frames = tensor.shape[0] + images = [] + images_original = [] + + for i in range(total_frames): + img = prepare_frame(tensor[i]) + images_original.append(img) + + if show_frame_numbers: + h, w = img.shape[:2] + img_upscaled = cv2.resize( + img, + (w * upscale_factor, h * upscale_factor), + interpolation=cv2.INTER_NEAREST, + ) + # Use larger font for better readability + img_upscaled = add_text_overlay( + img_upscaled, + f"Frame {i+1}/{total_frames}", + "top_right", + font_scale=0.5, + thickness=2, + ) + images.append(img_upscaled) + else: + images.append(img) + + imageio.mimsave(save_path, images, fps=fps, loop=0) + + # Also save as PDF with horizontal unrolling (using matplotlib text overlay) + pdf_path = save_path.replace(".gif", "_unroll.pdf") + frame_labels = ( + [f"{i+1}/{total_frames}" for i in range(total_frames)] + if show_frame_numbers + else None + ) + save_gif_as_pdf_unroll( + images_original, + pdf_path, + num_frames=min(8, total_frames), + figsize_per_frame=(0.8, 0.8), + init_frame=init_frame, + goal_frame=goal_frame, + frame_labels=frame_labels, + ) + + +def save_gif_HWC(frames_list: List, save_path: str, fps: int = 10): + """Save a list of (H, W, C) frames as a GIF.""" + images = [prepare_frame(f) for f in frames_list] + imageio.mimsave(save_path, images, fps=fps, loop=0) + + +# ============================================================================= +# PDF Unrolling +# ============================================================================= + + +def save_gif_as_pdf_unroll( + frames, + save_path: str, + num_frames: int = None, + frame_indices: List[int] = None, + row_labels: List[str] = None, + title: str = None, + figsize_per_frame: tuple = (1.0, 1.0), + dpi: int = 300, + init_frame=None, + goal_frame=None, + frame_labels: List[str] = None, +): + """ + Save frames as a PDF with horizontal unrolling. + + Args: + frames: Frames in various formats (see frames_to_list for supported formats) + save_path: Path to save the PDF + num_frames: Number of evenly-spaced frames to include + frame_indices: Specific frame indices (overrides num_frames) + row_labels: Labels for each row + title: Optional figure title + figsize_per_frame: Size per frame in inches + dpi: Resolution + init_frame: Initial frame with red border (left side) + goal_frame: Goal frame with red border (right side) + frame_labels: Labels for each frame (e.g., "1/10", "2/10", ...) rendered as + high-resolution matplotlib text overlay + """ + sequences = frames_to_list(frames) + num_rows = len(sequences) + total_frames_per_seq = len(sequences[0]) + + selected_indices = select_frame_indices( + total_frames_per_seq, num_frames, frame_indices + ) + + # Prepare init/goal frames + init_prepared = prepare_frame(init_frame) + goal_prepared = prepare_frame(goal_frame) + if init_prepared is not None: + init_prepared = add_border(init_prepared) + if goal_prepared is not None: + goal_prepared = add_border(goal_prepared) + + # Calculate columns + num_episode_cols = len(selected_indices) + has_init = init_prepared is not None + has_goal = goal_prepared is not None + num_cols = num_episode_cols + int(has_init) + int(has_goal) + + # Create figure + fig_width = figsize_per_frame[0] * num_cols + fig_height = figsize_per_frame[1] * num_rows + if title: + fig_height += 0.3 + + fig, axes = plt.subplots( + num_rows, num_cols, figsize=(fig_width, fig_height), dpi=dpi, squeeze=False + ) + plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1) + + for row_idx, sequence in enumerate(sequences): + col_offset = 0 + + # Init frame + if has_init: + axes[row_idx, 0].imshow(init_prepared) + axes[row_idx, 0].axis("off") + axes[row_idx, 0].set_aspect("equal") + col_offset = 1 + + # Episode frames + for col_idx, frame_idx in enumerate(selected_indices): + ax = axes[row_idx, col_offset + col_idx] + frame = prepare_frame(sequence[frame_idx]) + ax.imshow(frame, cmap="gray" if frame.ndim == 2 else None) + ax.axis("off") + ax.set_aspect("equal") + + # Use matplotlib text for high-resolution overlay (row label on first frame) + if col_idx == 0 and row_labels and row_idx < len(row_labels): + ax.text( + 0.02, + 0.98, + row_labels[row_idx], + transform=ax.transAxes, + fontsize=8, + color="white", + verticalalignment="top", + horizontalalignment="left", + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.1", facecolor="black", alpha=0.5), + ) + + # Add frame label using matplotlib text (high-resolution) + if frame_labels and frame_idx < len(frame_labels): + ax.text( + 0.98, + 0.98, + f"Frame {frame_labels[frame_idx]}", + transform=ax.transAxes, + fontsize=6, + color="white", + verticalalignment="top", + horizontalalignment="right", + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.1", facecolor="black", alpha=0.5), + ) + + # Goal frame + if has_goal: + axes[row_idx, num_cols - 1].imshow(goal_prepared) + axes[row_idx, num_cols - 1].axis("off") + axes[row_idx, num_cols - 1].set_aspect("equal") + + if title: + fig.suptitle(title, fontsize=12, y=1.02) + + plt.savefig(save_path, bbox_inches="tight", dpi=dpi, format="pdf", pad_inches=0.0) + plt.close(fig) + logger.info(f"PDF unroll figure saved to {save_path}") + return save_path + + +# ============================================================================= +# Image Grid Display +# ============================================================================= + + +def show_images( + tensor, + nrow: int = 4, + titles: List[str] = None, + labels: List[str] = None, + save_path: str = None, + dpi: int = 150, + close_fig: bool = True, + first_channel_only: bool = True, + clamp: bool = True, +): + """ + Display and optionally save a grid of images from a PyTorch tensor. + + Args: + tensor: Input tensor of shape (B, C, H, W) or (B, T, C, H, W) + nrow: Number of images per row + titles: List of titles for each image + labels: List of labels for each image + save_path: Path to save figure + dpi: Resolution + close_fig: Whether to close figure after saving + first_channel_only: Keep only first channel + clamp: Clamp values to [0, 1] + """ + tensor = to_numpy(tensor) + + if tensor.ndim == 5: + tensor = tensor[:, 0] + if tensor.ndim == 4 and first_channel_only: + tensor = tensor[:, 0:1] + if clamp: + tensor = np.clip(tensor, 0, 1) + + batch_size = tensor.shape[0] + ncol = min(nrow, batch_size) + nrow_actual = (batch_size + ncol - 1) // ncol + + fig, axes = plt.subplots( + nrow_actual, ncol, figsize=(ncol * 2, nrow_actual * 2), dpi=dpi + ) + if nrow_actual == 1 and ncol == 1: + axes = [[axes]] + + for i, ax in enumerate(axes.flat): + if i >= batch_size: + ax.axis("off") + continue + img = tensor[i].squeeze() + if img.ndim == 3 and img.shape[0] < 3: + img = expand_channels(img.transpose(1, 2, 0)) + ax.imshow(img, cmap="gray" if img.ndim == 2 else None) + ax.axis("off") + if titles: + ax.set_title(titles[i], fontsize=10) + if labels: + ax.text( + 0.5, + -0.15, + labels[i], + ha="center", + va="center", + transform=ax.transAxes, + fontsize=8, + ) + + plt.tight_layout() + if save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=dpi) + if not close_fig or not save_path: + plt.show() + if close_fig: + plt.close(fig) + + +# ============================================================================= +# Comparison Visualizations +# ============================================================================= + + +def create_comparison_gif( + gt_seq, + pred_seq_true, + pred_seq_random, + gt_dec=None, + save_path: str = "comparison.gif", + fps: int = 15, + upscale_factor: int = 2, +): + """ + Create a comparison GIF visualization with multiple sequences. + + Args: + gt_seq: [B, T, H, W, C] Ground truth sequence + gt_dec: [B, T, H, W, C] Decoded ground truth (optional) + pred_seq_true: [B, T, H, W, C] Predictions with true actions + pred_seq_random: [B, T, H, W, C] Predictions with random actions + save_path: Output path + fps: Frames per second + upscale_factor: Factor to upscale frames for better text readability + """ + b = gt_seq.shape[0] + num_rows = min(b, 4) + + seqs = [gt_seq, pred_seq_true, pred_seq_random] + if gt_dec is not None: + seqs.insert(1, gt_dec) + seq_length = min(s.shape[1] for s in seqs) + + img_height, img_width = gt_seq.shape[2], gt_seq.shape[3] + num_cols = len(seqs) + + # Upscaled dimensions for better text rendering + up_img_height = img_height * upscale_factor + up_img_width = img_width * upscale_factor + title_height = 30 * upscale_factor # Scale title area proportionally + + titles = ["GT"] + if gt_dec is not None: + titles.append("Dec GT") + titles.extend(["GT Act", "Rand Act"]) + + frames = [] + for t in range(seq_length): + canvas = np.zeros( + (title_height + num_rows * up_img_height, num_cols * up_img_width, 3), + dtype=np.uint8, + ) + + # Column titles with larger font + font_scale = 0.4 * upscale_factor + thickness = max(1, upscale_factor) + for col, title in enumerate(titles): + col_x = col * up_img_width + up_img_width // 2 + (tw, _), _ = cv2.getTextSize( + title, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness + ) + cv2.putText( + canvas, + title, + (col_x - tw // 2, title_height - 10), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + (255, 255, 255), + thickness, + cv2.LINE_AA, + ) + + # Frames (upscaled) + for row in range(num_rows): + base_y = title_height + row * up_img_height + for col, seq in enumerate( + seqs + if gt_dec is None + else [gt_seq, gt_dec, pred_seq_true, pred_seq_random] + ): + frame = prepare_frame(seq[row, t]) + frame_upscaled = cv2.resize( + frame, + (up_img_width, up_img_height), + interpolation=cv2.INTER_NEAREST, + ) + col_x = col * up_img_width + canvas[ + base_y : base_y + up_img_height, col_x : col_x + up_img_width + ] = frame_upscaled + + # Timestep indicator with larger font + add_text_overlay(canvas, f"t={t}", "bottom_right", font_scale=1.0, thickness=2) + frames.append(canvas) + + imageio.mimsave(save_path, frames, fps=fps, loop=0) + logger.info(f" βœ“ Saved comparison GIF: {os.path.basename(save_path)}") + + # Also save PDF with GT and GT Act rows + pdf_path = save_path.replace(".gif", "_unroll.pdf") + pdf_sequences = [ + [prepare_frame(gt_seq[0, t]) for t in range(seq_length)], + [prepare_frame(pred_seq_true[0, t]) for t in range(seq_length)], + ] + row_labels = ["GT", "GT Act"] + + save_gif_as_pdf_unroll( + pdf_sequences, + pdf_path, + num_frames=min(8, seq_length), + figsize_per_frame=(0.8, 0.8), + row_labels=row_labels, + ) + + return frames + + +def save_decoded_frames( + pred_frames_over_iterations: List, + costs: List[float], + plan_vis_path: str, + overlay: bool = True, +): + """ + Save decoded frames from planning iterations as a GIF. + + Args: + pred_frames_over_iterations: List of (T, H, W, C) arrays + costs: List of costs per iteration + plan_vis_path: Path prefix for outputs + overlay: Whether to add iteration overlay + """ + if pred_frames_over_iterations is None or plan_vis_path is None: + return + + frames = [] + for i, pred_frames in enumerate(pred_frames_over_iterations): + for frame in pred_frames: + frame_copy = frame.copy() + if overlay: + add_text_overlay(frame_copy, f"Iter {i+1}", "top_left", (200, 200, 200)) + frames.append(frame_copy) + + save_gif_HWC(frames, f"{plan_vis_path}.gif", fps=30) + logger.info(f"Plan decoding video saved to {plan_vis_path}") + + # Save last iteration as PDF + last_frames = pred_frames_over_iterations[-1] + save_gif_as_pdf_unroll( + list(last_frames), + f"{plan_vis_path}_unroll.pdf", + num_frames=min(8, len(last_frames)), + figsize_per_frame=(0.8, 0.8), + ) + + +# ============================================================================= +# Analysis & Plotting +# ============================================================================= + + +def plot_distances( + data, + save_path: str, + figsize: tuple = (4.0, 3.0), + xlabel: str = "Timesteps", + ylabel: str = "Distance to goal", +): + """Plot a line chart and save to file.""" + plt.figure(figsize=figsize, dpi=300) + sns.lineplot(data=data) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.tight_layout() + plt.savefig(save_path, bbox_inches="tight") + plt.close() + + +def compute_embed_differences(all_encs: torch.Tensor) -> torch.Tensor: + """Compute MSE differences from goal (last encoding).""" + sq_diff = (all_encs[:-1] - all_encs[-1:]) ** 2 + return sq_diff.mean(dim=tuple(range(1, all_encs.ndim))) + + +def analyze_distances( + obses: torch.Tensor, + infos: List[dict], + plot_prefix: str, + goal_position: torch.Tensor, + goal_state: torch.Tensor, + normalizer, + model, + objective, + device: torch.device, +): + """Analyze distances between observations and goal, generate plots.""" + coords = torch.stack( + [ + ( + torch.as_tensor(x["dot_position"]) + if not isinstance(x["dot_position"], torch.Tensor) + else x["dot_position"] + ) + for x in infos + ] + ).unsqueeze(1) + + distances = ( + torch.norm(coords[..., -1, :3] - goal_position[:3].unsqueeze(0), dim=-1) + .detach() + .cpu() + ) + + sns.set_theme() + figsize = (4.0, 3.0) + plot_distances(distances, plot_prefix + "_distances.pdf", figsize=figsize) + + all_states = ( + normalizer.normalize_state(torch.cat([obses, goal_state.unsqueeze(0)])) + .unsqueeze(-3) + .to(device) + ) + all_encs = model.encode(all_states) + diffs = compute_embed_differences(all_encs).detach().cpu() + + plot_distances( + diffs, + plot_prefix + "_rep_distance_visual.pdf", + figsize=figsize, + xlabel="Timesteps", + ylabel="Rep distance to goal", + ) + + all_objectives = objective(all_encs[:-1]).detach().cpu() + plot_distances( + all_objectives, + plot_prefix + "_objectives.pdf", + figsize=figsize, + xlabel="Timesteps", + ylabel="Objective values", + ) + + return distances, diffs + + +def plot_losses( + losses: List[torch.Tensor], + elite_losses_mean: List[torch.Tensor], + elite_losses_std: List[torch.Tensor], + work_dir, + num_act_stepped: int = 1, + frameskip: int = 1, +): + """Plot losses over optimization steps.""" + if not losses: + return + + losses_arr = torch.stack(losses, dim=0).detach().cpu().numpy() + elite_mean_arr = torch.stack(elite_losses_mean, dim=0).detach().cpu().numpy() + elite_std_arr = torch.stack(elite_losses_std, dim=0).detach().cpu().numpy() + n_timesteps, n_opt_steps, n_losses = losses_arr.shape + + sns.set_theme() + for i in range(n_losses): + total_plots = min(16, n_timesteps) + cols = int(np.ceil(total_plots)) + fig_width = FIGSIZE_BASE[0] * cols + fig_height = FIGSIZE_BASE[1] + + plt.figure(figsize=(fig_width, fig_height), dpi=300) + steps = np.linspace(0, n_timesteps - 1, total_plots, dtype=int) + + for j, step in enumerate(steps): + ax = plt.subplot(1, cols, j + 1) + if n_opt_steps > 1: + sns.lineplot(data=losses_arr[step, :, i]) + sns.lineplot(data=elite_mean_arr[step, :, i]) + ax.fill_between( + range(n_opt_steps), + elite_mean_arr[step, :, i] - elite_std_arr[step, :, i], + elite_mean_arr[step, :, i] + elite_std_arr[step, :, i], + alpha=0.3, + ) + else: + ax.bar(0, losses_arr[step, 0, i]) + ax.bar(0, elite_mean_arr[step, 0, i]) + ax.errorbar( + 0, + elite_mean_arr[step, 0, i], + yerr=elite_std_arr[step, 0, i], + fmt="none", + capsize=5, + ) + + ax.set_title(f"Step {step * frameskip * num_act_stepped}") + ax.tick_params(axis="both") + ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True)) + + plt.tight_layout() + plt.savefig(work_dir / f"losses_{i}.pdf", bbox_inches="tight") + plt.close() diff --git a/eb_jepa/visualize_samples.py b/eb_jepa/visualize_samples.py deleted file mode 100644 index 95642e0..0000000 --- a/eb_jepa/visualize_samples.py +++ /dev/null @@ -1,232 +0,0 @@ -import cv2 -import imageio -import matplotlib.pyplot as plt -import numpy as np -import torch - - -def show_images( - tensor, - nrow=4, - titles=None, - labels=None, - save_path=None, - dpi=150, - close_fig=True, - first_channel_only=True, - clamp=True, -): - """ - Display and optionally save a grid of images from a PyTorch tensor - Args: - tensor: Input tensor of shape (B, C, H, W) or (B, T, C, H, W) - nrow: Number of images per row in the grid - titles: List of titles for each image - labels: List of labels for each image - save_path: Path to save figure (None to disable saving) - dpi: Resolution for saved figure - close_fig: Whether to close figure after saving/displaying - """ - # Convert to CPU and detach from computation graph - if isinstance(tensor, torch.Tensor): - tensor = tensor.detach().cpu() - - # Handle 5D tensors (batch, time, channel, height, width) - if tensor.ndim == 5: - tensor = tensor[:, 0] # Use first frame for visualization - - # Add channel dimension handling - if tensor.ndim == 4 and first_channel_only: - tensor = tensor[:, 0:1] # Keep only first channel - - # Convert to numpy and denormalize (assuming [0,1] range) - if clamp: - tensor = tensor.clamp(0, 1).numpy() - - # Create plot - batch_size = tensor.shape[0] - ncol = min(nrow, batch_size) - nrow = (batch_size + ncol - 1) // ncol - - fig, axes = plt.subplots(nrow, ncol, figsize=(ncol * 2, nrow * 2), dpi=dpi) - if nrow == 1 and ncol == 1: - axes = [[axes]] # Ensure 2D array for single image - - for i, ax in enumerate(axes.flat): - if i >= batch_size: - ax.axis("off") - continue - img = tensor[i].squeeze() - if img.ndim == 3 and img.shape[0] < 3: - h, w = img.shape[1], img.shape[2] - rgb_img = np.zeros((h, w, 3)) - for c in range(min(img.shape[0], 3)): - rgb_img[..., c] = img[c] - img = rgb_img.astype(np.uint8) - ax.imshow(img) - else: - ax.imshow(img, cmap="gray" if img.ndim == 2 else None) - ax.axis("off") - if titles: - ax.set_title(titles[i], fontsize=10) - if labels: - ax.text( - 0.5, - -0.15, - labels[i], - ha="center", - va="center", - transform=ax.transAxes, - fontsize=8, - ) - - plt.tight_layout() - - # Save figure if path specified - if save_path: - plt.savefig(save_path, bbox_inches="tight", dpi=dpi) - - # Display figure if not closed - if not close_fig or not save_path: - plt.show() - - # Clean up - if close_fig: - plt.close(fig) - - -def save_gif(tensor, save_path, fps=10, show_frame_numbers=False): - """ - tensor of shape (T, C, H, W) uint8 - """ - - # Save each frame as an image - images = [] - total_frames = tensor.shape[0] - for i in range(total_frames): - img = tensor[i].numpy().astype(np.uint8) - if img.ndim == 3 and img.shape[0] < 3: # Handle fewer than 3 channels - h, w = img.shape[1], img.shape[2] - rgb_img = np.zeros((h, w, 3)) - for c in range(min(img.shape[0], 3)): - rgb_img[..., c] = img[c] - img = rgb_img.astype(np.uint8) - else: - img = img.squeeze() - if show_frame_numbers: - # Calculate scale factors based on image dimensions - h, w = img.shape[0], img.shape[1] - scale_factor = min(h, w) / 1000 # Base scale on 500px reference - font_scale = max(0.2, scale_factor * 0.5) - thickness = max(1, int(scale_factor)) - margin = int(h * 0.02) # 3% of height - - # Add frame number text - text = f"Frame {i+1}/{total_frames}" - - # Get text size to position it at top right with margin - (text_width, text_height), _ = cv2.getTextSize( - text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness - ) - text_x = w - text_width - margin - text_y = text_height + margin - - cv2.putText( - img, - text, - (text_x, text_y), - cv2.FONT_HERSHEY_SIMPLEX, - font_scale, - (255, 255, 255), - thickness, - cv2.LINE_AA, - ) - - images.append(img) - - # Save as GIF - imageio.mimsave(save_path, images, fps=fps, loop=0) - - -def to3channels(input_array, channel_dim=2): - """ - Convert a tensor or numpy array with fewer than 3 channels to 3 channels by adding zeros. - Preserves the input dtype and type (numpy array or PyTorch tensor). - - Args: - input_array: Input tensor or numpy array of shape ... C ... at idx channel_dim - where C can be < 3. - channel_dim: Dimension index where channels are located (default is 2 for (H, W, C) format) - - Returns: - A tensor or numpy array of the same type and dtype as input_array, but with 3 channels. - """ - is_torch_tensor = isinstance(input_array, torch.Tensor) - ndim = input_array.ndim - if ndim < channel_dim + 1: - raise ValueError(f"Input must have at least {channel_dim + 1} dimensions.") - shape = list(input_array.shape) - if shape[channel_dim] >= 3: - return input_array - new_shape = shape.copy() - new_shape[channel_dim] = 3 - - if is_torch_tensor: - new_array = torch.zeros( - new_shape, dtype=input_array.dtype, device=input_array.device - ) - else: - new_array = np.zeros(new_shape, dtype=input_array.dtype) - - # Create a dynamic slice for the channel dimension - indices = [slice(None)] * ndim - indices[channel_dim] = slice(0, shape[channel_dim]) - new_array[tuple(indices)] = input_array - - return new_array - - -def save_gif_HWC(frames_list, save_path, fps=10): - """ - Save a list of image tensors as a GIF. - - Args: - frames_list: List of tensors, each with shape (H, W, C) where C can be < 3 - save_path: Path to save the GIF - fps: Frames per second - """ - images = [] - - for frame in frames_list: - # Convert to numpy if it's a tensor - if isinstance(frame, torch.Tensor): - img = frame.detach().cpu().numpy() - else: - img = np.array(frame) - - # Ensure uint8 type - if img.dtype != np.uint8: - if img.max() <= 1.0: - img = (img * 255).astype(np.uint8) - else: - img = img.astype(np.uint8) - - # Handle case where C < 3 - if img.ndim == 3 and img.shape[2] < 3: - h, w = img.shape[0], img.shape[1] - rgb_img = np.zeros((h, w, 3), dtype=np.uint8) - for c in range(min(img.shape[2], 3)): - rgb_img[..., c] = img[..., c] - images.append(rgb_img) - # Handle grayscale case (H, W) - elif img.ndim == 2: - h, w = img.shape - rgb_img = np.stack([img, img, img], axis=2) - images.append(rgb_img) - else: - images.append(img) - - # Save as GIF - imageio.mimsave(save_path, images, fps=fps, loop=0) - - diff --git a/examples/__init__.py b/examples/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/examples/ac_video_jepa/README.md b/examples/ac_video_jepa/README.md index 6fdf422..c99b5c6 100644 --- a/examples/ac_video_jepa/README.md +++ b/examples/ac_video_jepa/README.md @@ -51,7 +51,7 @@ $$\mathcal{L}_{\mathrm{var}} = \frac{1}{HD} \sum^H_{t=0} \sum^D_{j=0} \mathrm{ma $$C(Z_t) = \frac{1}{N-1}(Z_t-\bar{Z_t})^\top(Z_t-\bar{Z_t}), \ \bar{Z} = \frac{1}{N} \sum^N_{b=1} Z_{t,b}$$ -$$\mathcal{L}_{\mathrm{cov}} = \frac{1}{H} \sum^{H}_{t=0} \frac{1}{D} \sum_{i \neq j} [C(Z_t)]^2_{i,j}$$ +$$\mathcal{L}_{\mathrm{cov}} = \frac{1}{H} \sum^{H}_{t=0} \frac{1}{D(D-1)} \sum_{i \neq j} [C(Z_t)]^2_{i,j}$$ $$\mathcal{L}_{\mathrm{IDM}} = \sum^H_{t=0} \frac{1}{N} \sum^N_{b=0} \| a_{t,b} - \mathrm{MLP}(Z_{(t,b)}, Z_{(t+1,b)}) \|^2_2$$ @@ -72,25 +72,24 @@ We study two setups: ### Usage ```bash -# Train a model +# Train a model locally python -m examples.ac_video_jepa.main \ - --fname examples/ac_video_jepa/cfgs/.yaml + --fname examples/ac_video_jepa/cfgs/train.yaml -# Launch sbatch single training -python -m examples.ac_video_jepa.launch_sbatch \ - --fname examples/ac_video_jepa/cfgs/.yaml \ +# Launch 3 seeds with automatic wandb averaging (recommended) +python -m examples.launch_sbatch --example ac_video_jepa + +# Launch 3 seeds with custom sweep name +python -m examples.launch_sbatch --example ac_video_jepa --sweep my_experiment +``` + +See the main [README](../../README.md) for wandb seed averaging and sweep UI instructions. -# Run planning evaluation of specific checkpoint +```bash +# Run planning evaluation of a trained model python -m examples.ac_video_jepa.main \ - --fname examples/ac_video_jepa/cfgs/.yaml \ --meta.model_folder /path/to/trained/model \ - --meta.plan_eval_only_mode True - -# Run train hyperparameter sweep -python -m examples.ac_video_jepa.launch_sbatch \ - --sweep /sweep/common/folder \ - --fname examples/ac_video_jepa/cfgs/.yaml \ - --use_wandb_sweep + --meta.eval_only_mode True ``` ## Evaluation @@ -151,7 +150,7 @@ The unrolling of 90 actions by our best model is illustrated in the below figure | *Random wall train and eval* | ### Planning -In all the below tables, we first obtain success rates as an average over $N=20$ planning episodes. For each model, we launch 3 training seeds, oveer which we average success rate. To account for variability across a single run, we also average the success rate of the last 3 training epochs. We display the **std over 3 seeds and the 3 last epoch checkpoints** for all below sections. +In all the below tables, we first obtain success rates as an average over $N=20$ planning episodes. For each model, we launch 3 training seeds, over which we average success rate. To account for variability across a single run, we also average the success rate of the last 3 training epochs. We display the **std over 3 seeds and the 3 last epoch checkpoints** for all below sections. Our best model gets 97% Success in the Random Wall setup. @@ -159,8 +158,8 @@ Our best model gets 97% Success in the Random Wall setup. |-------------------|---------|------------------| | Impala - RNN| MPPI | $97 \pm 2$ | -#### Visualisation -The below figure shows a successful planning episode with the MPPI planner, a model trained on the fix wall and evaluated on the same wall position. +#### Visualization +The below figure shows a successful planning episode with the MPPI planner, a model trained on the fixed wall and evaluated on the same wall position. | Planning Episode | Task Definition | |------------------|-----------------| @@ -208,13 +207,15 @@ Key insights: ## Experiment tracking We encourage to use the extensive integration of wandb logging in this `ac_video_jepa` example. -To reproduce the below plot, launch a training sweep with the below values for the regularization loss coefficients, hardcoding them in `examples/ac_video_jepa/launch_sbatch.py` and launching +To reproduce the below plot, launch a full hyperparameter sweep with the `--full-sweep` flag: ``` python -m examples.ac_video_jepa.launch_sbatch \ - --sweep /sweep/common/folder \ + --sweep \ --fname examples/ac_video_jepa/cfgs/train.yaml \ - --use_wandb_sweep + --full-sweep \ + --use-wandb-sweep ``` +This sweeps over the following regularization loss coefficients and seeds: | $\beta$ | $\alpha$ | $\delta$ | $\omega$ | |---------|---------|---------|--------| diff --git a/examples/ac_video_jepa/assets/train_plan_schema_crop.png b/examples/ac_video_jepa/assets/train_plan_schema_crop.png index 0287ef2..2c29685 100644 Binary files a/examples/ac_video_jepa/assets/train_plan_schema_crop.png and b/examples/ac_video_jepa/assets/train_plan_schema_crop.png differ diff --git a/examples/ac_video_jepa/cfgs/eval.yaml b/examples/ac_video_jepa/cfgs/eval.yaml index 6112a7c..9335617 100644 --- a/examples/ac_video_jepa/cfgs/eval.yaml +++ b/examples/ac_video_jepa/cfgs/eval.yaml @@ -6,44 +6,4 @@ env: level: normal data: env_name: two_rooms - action_noise: 1 - action_angle_noise: 0.2 - action_step_mean: 1.0 - action_step_std: 0.4 - action_lower_bd: 0.2 - action_upper_bd: 1.8 - batch_size: 64 - device: cpu - dot_std: 1.3 - border_wall_loc: 5 - fix_wall_batch_k: - fix_wall: false # false to randomize both wall and door loc - fix_door_location: 18 - fix_wall_location: 32 - exclude_wall_train: '' - exclude_door_train: '' - only_wall_val: '' - only_door_val: '' - wall_padding: 20 - door_padding: 10 - wall_width: 3 - door_space: 4 - num_train_layouts: -1 - cross_wall_rate: 0.35 - dup_traj_rate: 0. - expert_cross_wall_rate: 0 - wall_bump_rate: 0.0 - img_size: 65 - max_step: 1 - sample_length: 17 # 17 in PLDM | 90 to eval long rollout - n_steps: 91 - n_steps_reduce_factor: 1 - size: 100000 - val_size: 10000 - train: true - repeat_actions: 1 - normalize: true - # new - num_workers: 0 - pin_mem: false - persistent_workers: false + # Eval-specific overrides (base config in eb_jepa/datasets/two_rooms/data_config.yaml) diff --git a/examples/ac_video_jepa/cfgs/train.yaml b/examples/ac_video_jepa/cfgs/train.yaml index 9b6c48e..cf4d493 100644 --- a/examples/ac_video_jepa/cfgs/train.yaml +++ b/examples/ac_video_jepa/cfgs/train.yaml @@ -3,65 +3,25 @@ logging: wandb_group: wandb_sweep: log_every: 10 - exp_suffix: 26-01-14 + exp_suffix: 26-01-15 save_every_n_epochs: 1 - tqdm_silent: true + tqdm_silent: false meta: - quick_debug: true - plan_eval_only_mode: false - light_eval_only_mode: false - ckpt_dir: checkpoints model_folder: load_model: true - load_checkpoint: # e-9.pth.tar enable_plan_eval: true eval_every_itr: -1 light_eval_freq: 50 - start_probing_after: 0 seed: 1 data: env_name: two_rooms - action_noise: 1 - action_angle_noise: 0.2 - action_step_mean: 1.0 - action_step_std: 0.4 - action_lower_bd: 0.2 - action_upper_bd: 1.8 + # Training-specific overrides (base config in eb_jepa/datasets/two_rooms/data_config.yaml) batch_size: 384 - device: cpu - dot_std: 1.3 - border_wall_loc: 5 - fix_wall_batch_k: - fix_wall: false # false to randomize both wall and door loc - fix_door_location: 18 - fix_wall_location: 32 - exclude_wall_train: '' - exclude_door_train: '' - only_wall_val: '' - only_door_val: '' - wall_padding: 20 - door_padding: 10 - wall_width: 3 - door_space: 4 - num_train_layouts: -1 - cross_wall_rate: 0.35 - dup_traj_rate: 0. - expert_cross_wall_rate: 0 - wall_bump_rate: 0.0 - img_size: 65 - max_step: 1 - sample_length: 17 # 17 in PLDM | 90 to eval long rollout - n_steps: 91 - n_steps_reduce_factor: 1 - size: 100000 - val_size: 10000 - train: true - repeat_actions: 1 - normalize: true - # new num_workers: 16 pin_mem: true persistent_workers: true +training: + use_amp: true dtype: bfloat16 model: compile: true @@ -93,3 +53,14 @@ optim: eval: plan_cfg_path: examples/ac_video_jepa/cfgs/planning_mppi.yaml # eb_jepa/planning_mppi.yaml eval_cfg_path: examples/ac_video_jepa/cfgs/eval.yaml + +# --- Parameters used when running with --full-sweep +# By default, only the parameters specified above are used +sweep: + # Sweep grid for AC video JEPA hyperparameter search + param_grid: + model.regularizer.cov_coeff: [8, 12] + model.regularizer.std_coeff: [8, 16] + model.regularizer.sim_coeff_t: [8, 12, 16] + model.regularizer.idm_coeff: [1, 2] + meta.seed: [1, 1000, 10000] diff --git a/examples/ac_video_jepa/eval.py b/examples/ac_video_jepa/eval.py new file mode 100644 index 0000000..f7587f8 --- /dev/null +++ b/examples/ac_video_jepa/eval.py @@ -0,0 +1,95 @@ +""" +Evaluation utilities for action-conditioned Video JEPA. +""" + +import os +from pathlib import Path + +import torch +import yaml + +from eb_jepa.logging import get_logger +from eb_jepa.planning import main_eval, main_unroll_eval + +logger = get_logger(__name__) + + +@torch.no_grad() +def launch_plan_eval( + jepa, + env_creator, + folder, + epoch, + global_step, + suffix="", + num_eval_episodes=10, + loader=None, + prober=None, + plan_cfg=None, +): + """Evaluate the planning capabilities of the trained JEPA model.""" + logger.info(f"Planning eval: epoch={epoch} step={global_step}") + jepa.eval() + folder = Path(folder) + eval_folder = folder / "plan_eval" / f"step-{global_step}{suffix}" + os.makedirs(eval_folder, exist_ok=True) + + if plan_cfg is not None: + plan_cfg_file = eval_folder / "plan_config.yaml" + with open(plan_cfg_file, "w") as f: + yaml.dump(plan_cfg, f) + + eval_results = main_eval( + plan_cfg=plan_cfg, + model=jepa, + env_creator=env_creator, + eval_folder=eval_folder, + num_episodes=num_eval_episodes, + loader=loader, + prober=prober, + ) + logger.info( + f" success_rate={eval_results['success_rate']:.2f} | mean_dist={eval_results['mean_state_dist']:.4f}" + ) + jepa.train() + + return eval_results + + +@torch.no_grad() +def launch_unroll_eval( + jepa, + env_creator, + folder, + epoch, + global_step, + suffix="", + loader=None, + prober=None, + cfg=None, +): + """Evaluate the unrolling (prediction) capabilities of the trained JEPA model.""" + jepa.eval() + logger.info(f"Unroll eval: epoch={epoch} step={global_step}") + folder = Path(folder) + eval_folder = folder / "unroll_eval" / f"step-{global_step}{suffix}" + os.makedirs(eval_folder, exist_ok=True) + eval_results = main_unroll_eval( + jepa, + env_creator, + eval_folder, + loader=loader, + prober=prober, + cfg=cfg, + ) + steps = [0, 1, 2, 3] + mean_values = " | ".join( + [f"t{i}={eval_results[f'val_rollout/mean_mse/{i}']:.2f}" for i in steps] + ) + std_values = " | ".join( + [f"{i}: {eval_results[f'val_rollout/std_mse/{i}']:.2f}" for i in steps] + ) + logger.info(f"Unroll eval - mean_mse: {mean_values} | std_mse: {std_values}") + jepa.train() + + return eval_results diff --git a/examples/ac_video_jepa/launch_sbatch.py b/examples/ac_video_jepa/launch_sbatch.py deleted file mode 100644 index b296621..0000000 --- a/examples/ac_video_jepa/launch_sbatch.py +++ /dev/null @@ -1,382 +0,0 @@ -import argparse -import importlib -import os -import shutil -from pathlib import Path - -import submitit - -from examples.ac_video_jepa.main import get_experiment_folder, load_override_cfg - - -def copy_code_folder(code_folder): - """Copy the code folder to the experiment directory, ignoring unnecessary files.""" - ignore_patterns = [ - "__pycache__", - ".vscode", - ".git", - "core", - "mnist_test_seq.npy", - "uv.lock", - "Makefile", - ] - ignore_paths = [ - "traces", - "docs", - ".pytest_cache", - "logs", - ".venv", - "eb_jepa.egg-info", - ] - - def ignore_func(path, names): - ignore_list = list(ignore_patterns) - for ignore_path in ignore_paths: - if ignore_path in names: - ignore_list.append(ignore_path) - return ignore_list - - if not os.path.exists(code_folder): - shutil.copytree(".", code_folder, ignore=ignore_func) - - -def launch_job(fname: str, **kwargs): - """Launch a single training job with the given config and overrides.""" - cfg = load_override_cfg(fname, kwargs) - folder = get_experiment_folder( - cfg, cfg.data, quick_debug=cfg.meta.get("quick_debug") - ) - os.makedirs(folder, exist_ok=True) - - code_folder = os.path.join(folder, "code") - copy_code_folder(code_folder) - print(f"Changing to code folder: {code_folder}") - os.chdir(code_folder) - executor = submitit.AutoExecutor( - folder=os.path.join(folder, "job_%j"), slurm_max_num_timeout=20 - ) - - executor.update_parameters( - name="AC_JEPA", - slurm_mem_per_gpu="55G", - cpus_per_task=16, - timeout_min=24 * 60, - slurm_partition="learn", - slurm_additional_parameters={ - "nodes": 1, - "ntasks-per-node": 1, - "gpus-per-node": 1, - "qos": "explore", - "account": "fair_amaia_cw_video", - }, - ) - job = executor.submit(run_experiment, cfg, folder) - - print(f"Submitted job {job.job_id}") - print(f"Experiment folder: {folder}") - - return job - - -def run_experiment(cfg, folder=None): - print(f"Current working directory: {os.getcwd()}") - return importlib.import_module("examples.ac_video_jepa.main").main( - cfg=cfg, folder=folder - ) - - -def launch_sweep( - fname: str, param_grid: dict, array_parallelism: int, **base_overrides -): - """ - Launch a parameter sweep using submitit batch submission. - - Args: - config_path: Path to the base config file - param_grid: Dictionary of parameter names to lists of values - **base_overrides: Base configuration overrides to apply to all jobs - """ - from itertools import product - - param_names = list(param_grid.keys()) - param_values_list = list(param_grid.values()) - all_combinations = list(product(*param_values_list)) - - if not all_combinations: - print("No parameter combinations to sweep") - return [] - - base_cfg = load_override_cfg(fname, base_overrides) - common_ckpt_dir = Path(base_cfg.meta.ckpt_dir) - sweep_logs_dir = common_ckpt_dir / "sweep_slurm_logs" - sweep_logs_dir.mkdir(parents=True, exist_ok=True) - - sweep_code_folder = common_ckpt_dir / "code" - copy_code_folder(str(sweep_code_folder)) - print(f"Changing to code folder: {sweep_code_folder}") - os.chdir(sweep_code_folder) - - executor = submitit.AutoExecutor( - folder=str(sweep_logs_dir), slurm_max_num_timeout=20 - ) - - executor.update_parameters( - name="AC_JEPA_sweep", - slurm_mem_per_gpu="55G", - cpus_per_task=16, - timeout_min=24 * 60, - slurm_partition="learn", - slurm_array_parallelism=array_parallelism, - slurm_additional_parameters={ - "nodes": 1, - "ntasks-per-node": 1, - "gpus-per-node": 1, - "qos": "explore", - "account": "fair_amaia_cw_video", - }, - ) - jobs = [] - with executor.batch(): - for i, values in enumerate(all_combinations): - param_overrides = dict(zip(param_names, values)) - final_overrides = {**base_overrides, **param_overrides} - cfg = load_override_cfg(fname, final_overrides) - folder = get_experiment_folder( - cfg, cfg.data, quick_debug=cfg.meta.get("quick_debug") - ) - os.makedirs(folder, exist_ok=True) - print(f"Submitting task {i}: {param_overrides}") - print(f" -> folder: {folder}") - job = executor.submit(run_experiment, cfg, folder) - jobs.append(job) - - print(f"Submitted {len(jobs)} jobs in batch") - print(f"Sweep logs directory: {sweep_logs_dir}") - - return jobs - - -def create_wandb_sweep_config(param_grid: dict, method: str = "grid"): - """ - Create a wandb sweep configuration from a parameter grid. - - Args: - param_grid: Dictionary of parameter names to lists of values - method: Sweep method ("grid", "random", or "bayes") - - Returns: - Dictionary representing the wandb sweep configuration - """ - sweep_config = { - "method": method, - "metric": { - "goal": "maximize", - "name": "success_rate", - }, - "parameters": {}, - } - - for param_name, param_values in param_grid.items(): - if isinstance(param_values, list): - sweep_config["parameters"][param_name] = {"values": param_values} - elif isinstance(param_values, dict): - sweep_config["parameters"][param_name] = param_values - - return sweep_config - - -def launch_wandb_sweep( - fname: str, - param_grid: dict, - method: str = "grid", - array_parallelism: int = 256, - **base_overrides, -): - """ - Launch a wandb sweep using submitit. Each SLURM job = 1 training run. - - Creates a wandb sweep for tracking/visualization, then launches individual - SLURM jobs with deterministic run IDs for proper resuming after requeues. - - Args: - fname: Path to the base config file - param_grid: Dictionary of parameter names to lists of values - method: Sweep method ("grid", "random", or "bayes") - array_parallelism: Number of jobs to run in parallel - **base_overrides: Base configuration overrides to apply to all jobs - """ - from itertools import product - - import wandb - - # Load base config - base_cfg = load_override_cfg(fname, base_overrides) - project_name = "eb-jepa-ac" - - # Create wandb sweep configuration - sweep_config = create_wandb_sweep_config(param_grid, method) - - # Initialize the sweep (creates it on wandb servers) - sweep_id = wandb.sweep(sweep_config, project=project_name) - print(f"Created wandb sweep with ID: {sweep_id}") - print( - f"View sweep at: https://fairwandb.org/{wandb.api.default_entity}/{project_name}/sweeps/{sweep_id}" - ) - - # Generate all parameter combinations (same as regular sweep) - param_names = list(param_grid.keys()) - param_values_list = list(param_grid.values()) - all_combinations = list(product(*param_values_list)) - - if not all_combinations: - print("No parameter combinations to sweep") - return sweep_id, [] - - # Set up directories - common_ckpt_dir = Path(base_cfg.meta.ckpt_dir) - sweep_logs_dir = common_ckpt_dir / "wandb_sweep_slurm_logs" - sweep_logs_dir.mkdir(parents=True, exist_ok=True) - - sweep_code_folder = common_ckpt_dir / "code" - copy_code_folder(str(sweep_code_folder)) - print(f"Changing to code folder: {sweep_code_folder}") - os.chdir(sweep_code_folder) - - # Set up submitit executor - executor = submitit.AutoExecutor( - folder=str(sweep_logs_dir), slurm_max_num_timeout=20 - ) - - executor.update_parameters( - name="AC_JEPA_wandb_sweep", - slurm_mem_per_gpu="55G", - cpus_per_task=16, - timeout_min=24 * 60, - slurm_partition="learn", - slurm_array_parallelism=array_parallelism, - slurm_additional_parameters={ - "nodes": 1, - "ntasks-per-node": 1, - "gpus-per-node": 1, - "qos": "explore", - "account": "fair_amaia_cw_video", - }, - ) - - # Launch one SLURM job per training run (maintaining 1:1 convention) - jobs = [] - with executor.batch(): - for i, values in enumerate(all_combinations): - param_overrides = dict(zip(param_names, values)) - final_overrides = { - **base_overrides, - **param_overrides, - "logging.wandb_sweep": True, - "logging.wandb_sweep_id": sweep_id, # Pass sweep_id as config param - } - cfg = load_override_cfg(fname, final_overrides) - folder = get_experiment_folder( - cfg, cfg.data, quick_debug=cfg.meta.get("quick_debug") - ) - os.makedirs(folder, exist_ok=True) - - print(f"Submitting job {i}: {param_overrides}") - print(f" -> folder: {folder}") - job = executor.submit(run_experiment, cfg, folder) - jobs.append(job) - - print(f"Submitted {len(jobs)} jobs (1 job = 1 training run)") - print(f"Sweep logs directory: {sweep_logs_dir}") - print(f"Sweep ID: {sweep_id}") - - return sweep_id, jobs - - -if __name__ == "__main__": - """ - Single run - python examples/ac_video_jepa/launch_sbatch.py --fname examples/ac_video_jepa/cfgs/train.yaml --optim.lr 0.0005 --model.regularizer.sim_coeff_t 0.75 - Sweep with common directory name (like run_exp.sh) - python examples/ac_video_jepa/launch_sbatch.py --sweep randwall_imp_fixeval_sweep - python examples/ac_video_jepa/launch_sbatch.py --sweep rand_conv3d_bs128_sweep --fname examples/ac_video_jepa/cfgs/train_conv3d.yaml --array_parallelism 1200 - python examples/ac_video_jepa/launch_sbatch.py --sweep rand_imp_bs384AdamW_losssweep --fname examples/ac_video_jepa/cfgs/train.yaml --array_parallelism 1200 --use_wandb_sweep - """ - parser = argparse.ArgumentParser(description="Submitit launcher for AC Video JEPA") - parser.add_argument( - "--fname", - default="examples/ac_video_jepa/cfgs/train.yaml", - help="Path to config file", - ) - parser.add_argument( - "--sweep", - type=str, - help="Name for the sweep (sets meta.ckpt_dir and logging.wandb_group)", - ) - parser.add_argument( - "--array_parallelism", - type=int, - default=256, - help="Number of jobs to run in parallel for the sweep", - ) - parser.add_argument( - "--use_wandb_sweep", - action="store_true", - help="Use wandb sweep instead of submitit sweep", - ) - parser.add_argument( - "--sweep_method", - type=str, - default="grid", - choices=["grid", "random", "bayes"], - help="Wandb sweep method (grid, random, or bayes)", - ) - parser.add_argument("--model.regularizer.cov_coeff", type=float) - parser.add_argument("--model.regularizer.std_coeff", type=float) - parser.add_argument("--model.regularizer.sim_coeff_t", type=float) - parser.add_argument("--model.regularizer.idm_coeff", type=float) - parser.add_argument("--model.regularizer.spatial_as_samples", type=bool) - parser.add_argument("--optim.lr", type=float) - parser.add_argument("--meta.quick_debug", type=bool) - args = parser.parse_args() - overrides = { - k: v - for k, v in vars(args).items() - if v is not None - and k - not in [ - "fname", - "sweep", - "array_parallelism", - "use_wandb_sweep", - "sweep_method", - ] - } - if args.sweep is not None: - overrides["meta.ckpt_dir"] = ( - f"/checkpoint/amaia/video/basileterv/experiment/eb_jepa/{args.sweep}" - ) - overrides["logging.wandb_group"] = args.sweep - param_grid = { - "model.regularizer.cov_coeff": [8, 12], - "model.regularizer.std_coeff": [8, 16], - "model.regularizer.sim_coeff_t": [8, 12, 16], - "model.regularizer.idm_coeff": [1, 2], - "meta.seed": [1, 1000, 10000], - } - if overrides: - print(f"Base overrides: {overrides}") - - if args.use_wandb_sweep: - sweep_id, jobs = launch_wandb_sweep( - args.fname, - param_grid, - method=args.sweep_method, - array_parallelism=args.array_parallelism, - **overrides, - ) - else: - jobs = launch_sweep( - args.fname, param_grid, args.array_parallelism, **overrides - ) - else: - job = launch_job(args.fname, **overrides) diff --git a/examples/ac_video_jepa/main.py b/examples/ac_video_jepa/main.py index 169adc2..7f75cf4 100644 --- a/examples/ac_video_jepa/main.py +++ b/examples/ac_video_jepa/main.py @@ -1,19 +1,18 @@ import copy import os -import random from pathlib import Path from time import time -from typing import Optional -import numpy as np +import fire import torch import torch.nn as nn +import wandb import yaml from omegaconf import OmegaConf from torch.amp import GradScaler, autocast from torch.optim import AdamW +from tqdm import tqdm -import wandb from eb_jepa.architectures import ( ImpalaEncoder, InverseDynamicsModel, @@ -24,190 +23,114 @@ from eb_jepa.jepa import JEPA, JEPAProbe from eb_jepa.logging import get_logger from eb_jepa.losses import SquareLossSeq, VC_IDM_Sim_Regularizer -from eb_jepa.planning import main_eval, main_unroll_eval from eb_jepa.schedulers import CosineWithWarmup -from examples.ac_video_jepa.heads import MLPXYHead +from eb_jepa.state_decoder import MLPXYHead +from eb_jepa.training_utils import ( + get_default_dev_name, + get_exp_name, + get_unified_experiment_dir, + load_checkpoint, + load_config, + log_config, + log_data_info, + log_epoch, + log_model_info, + save_checkpoint, + setup_device, + setup_seed, + setup_wandb, +) +from examples.ac_video_jepa.eval import launch_plan_eval, launch_unroll_eval logger = get_logger(__name__) -def clean_state_dict(state_dict): - return {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()} - - -@torch.no_grad() -def launch_plan_eval( - jepa, - env_creator, - folder, - epoch, - global_step, - suffix="", - num_eval_episodes=10, - loader=None, - prober=None, - plan_cfg=None, -): - """Evaluate the planning capabilities of the trained JEPA model.""" - logger.info(f"🎯 Planning eval | epoch={epoch} step={global_step}") - jepa.eval() - eval_folder = folder / "plan_eval" / f"step-{global_step}{suffix}" - os.makedirs(eval_folder, exist_ok=True) - - if plan_cfg is not None: - plan_cfg_file = eval_folder / "plan_config.yaml" - with open(plan_cfg_file, "w") as f: - yaml.dump(plan_cfg, f) - - eval_results = main_eval( - plan_cfg=plan_cfg, - model=jepa, - env_creator=env_creator, - eval_folder=eval_folder, - num_episodes=num_eval_episodes, - loader=loader, - prober=prober, - ) - logger.info( - f" βœ“ success_rate={eval_results['success_rate']:.2f} | mean_dist={eval_results['mean_state_dist']:.4f}" - ) - jepa.train() - - return eval_results - - -@torch.no_grad() -def launch_unroll_eval( - jepa, - env_creator, - folder, - epoch, - global_step, - suffix="", - loader=None, - prober=None, - cfg=None, -): - """Evaluate the unrolling (prediction) capabilities of the trained JEPA model.""" - jepa.eval() - logger.info(f"πŸ“Š Unroll eval | epoch={epoch} step={global_step}") - eval_folder = folder / "unroll_eval" / f"step-{global_step}{suffix}" - os.makedirs(eval_folder, exist_ok=True) - eval_results = main_unroll_eval( - jepa, - env_creator, - eval_folder, - loader=loader, - prober=prober, - cfg=cfg, - ) - steps = [0, 1, 2, 3] - mean_values = " | ".join( - [f"t{i}={eval_results[f'val_rollout/mean_mse/{i}']:.2f}" for i in steps] - ) - std_values = " | ".join( - [f"{i}: {eval_results[f'val_rollout/std_mse/{i}']:.2f}" for i in steps] - ) - logger.info(f"Unroll eval - mean_mse: {mean_values} | std_mse: {std_values}") - jepa.train() - - return eval_results - - -def load_override_cfg(fname: str, kwargs_dict: Optional[dict] = None): - """Load configuration from a YAML file and optionally override with a dictionary.""" - assert fname and os.path.exists(fname), f"Config file not found: {fname}" - with open(fname, "r") as f: - cfg = OmegaConf.create(yaml.safe_load(f)) - print(f"Loaded config from {fname}") - if kwargs_dict: - override_dict = {} - for arg_name, arg_value in kwargs_dict.items(): - keys = arg_name.split(".") - current = override_dict - for key in keys[:-1]: - current = current.setdefault(key, {}) - current[keys[-1]] = arg_value - cfg = OmegaConf.merge(cfg, OmegaConf.create(override_dict)) - return cfg - - -def get_experiment_folder_without_seed(cfg, data_config, quick_debug=False): - """Generate the experiment folder path without seed for wandb aggregation.""" - if cfg.meta.get("model_folder"): - return Path(cfg.meta.model_folder) - exp_name = f"{'deb_' if quick_debug else ''}{cfg.model.encoder_architecture}_encsk{cfg.model.encoder_skip_connections}_prsk{cfg.model.predictor_skip_connections}_cov{cfg.model.regularizer.cov_coeff}_std{cfg.model.regularizer.std_coeff}_simt{cfg.model.regularizer.get('sim_coeff_t')}_idm{cfg.model.regularizer.get('idm_coeff')}_sp{cfg.model.regularizer.spatial_as_samples}_useproj{cfg.model.regularizer.use_proj}_idmproj{cfg.model.regularizer.idm_after_proj}_simtproj{cfg.model.regularizer.sim_t_after_proj}_1stt{cfg.model.regularizer.get('first_t_only')}_roll{cfg.model.train_rollout}_dstc{cfg.model.dstc}_henc{cfg.model.henc}_hpre{cfg.model.hpre}_lr{cfg.optim.lr}_wd{cfg.optim.weight_decay}_bs{data_config.batch_size}_samplen{data_config.sample_length}_size{data_config.size}_cw{data_config.cross_wall_rate}_wb{data_config.wall_bump_rate}_duptraj{data_config.dup_traj_rate}_fixw{data_config.fix_wall}_{cfg.logging['exp_suffix'] if cfg.logging.get('exp_suffix') else ''}" - return cfg.meta.ckpt_dir / Path(exp_name) - - -def get_experiment_folder(cfg, data_config, quick_debug=False): - """Generate the experiment folder path based on configuration.""" - if cfg.meta.get("model_folder"): - return Path(cfg.meta.model_folder) - exp_name_base = get_experiment_folder_without_seed(cfg, data_config, quick_debug) - exp_name = f"{os.path.basename(exp_name_base)}_seed{cfg.meta.seed}" - return cfg.meta.ckpt_dir / Path(exp_name) - - -def main( +def run( fname: str = "examples/ac_video_jepa/cfgs/train.yaml", cfg=None, folder=None, - **kwargs, + **overrides, ): """ Train an action-conditioned Video JEPA model. Args: fname: Path to the YAML config file. - cfg: Pre-loaded config object (optional, overrides fname). + cfg: Pre-loaded config object (optional, overrides config file). folder: Experiment folder path (optional, auto-generated if not provided). - **kwargs: Config overrides in dot notation (e.g., model.henc=64). + **overrides: Config overrides in dot notation (e.g., model.henc=64). """ if cfg is None: - cfg = load_override_cfg(fname, kwargs) - quick_debug = cfg.meta.get("quick_debug", False) + cfg = load_config(fname, overrides if overrides else None) + + # Create experiment directory using unified structure (if not provided) if folder is None: - folder = get_experiment_folder(cfg, cfg.data, quick_debug=quick_debug) - os.makedirs(folder, exist_ok=True) + if cfg.meta.get("model_folder"): + folder = Path(cfg.meta.model_folder) + folder_name = folder.name + exp_name = folder_name.rsplit("_seed", 1)[0] + else: + sweep_name = get_default_dev_name() + exp_name = get_exp_name("ac_video_jepa", cfg) + folder = get_unified_experiment_dir( + example_name="ac_video_jepa", + sweep_name=sweep_name, + exp_name=exp_name, + seed=cfg.meta.seed, + ) + else: + folder = Path(folder) + folder_name = folder.name + exp_name = folder_name.rsplit("_seed", 1)[0] - if quick_debug: - cfg.logging.log_wandb = False - cfg.meta.eval_every_itr = 2 - cfg.meta.light_eval_freq = 2 - cfg.model.compile = False - cfg.data.num_workers = 0 - cfg.data.batch_size = 4 - cfg.logging.tqdm_silent = False - - if cfg.meta.light_eval_only_mode: - cfg.meta.light_eval_freq = 2 - cfg.logging.log_wandb = False - cfg.data.batch_size = 4 - cfg.logging.tqdm_silent = False - - train_jepa = True - train_probe = True - if cfg.meta.get("light_eval_only_mode") or cfg.meta.get("plan_eval_only_mode"): - train_jepa, train_probe, cfg.logging.log_wandb = False, False, False + os.makedirs(folder, exist_ok=True) loader, val_loader, data_config = init_data( env_name=cfg.data.env_name, cfg_data=dict(cfg.data) ) - logger.info( - f"πŸ“¦ Data: {len(loader)} batches Γ— {data_config.batch_size} samples" + + # -- SETUP + setup_device("auto") + setup_seed(cfg.meta.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # -- WANDB + wandb_run = setup_wandb( + project="eb_jepa", + config={ + "example": "ac_video_jepa", + **OmegaConf.to_container(cfg, resolve=True), + }, + run_dir=folder, + run_name=exp_name, + tags=[f"seed_{cfg.meta.seed}", "ac_video_jepa"], + group=cfg.logging.get("wandb_group"), + enabled=cfg.logging.get("log_wandb", False), + sweep_id=cfg.logging.get("wandb_sweep_id"), ) - # Set seed - torch.manual_seed(cfg.meta.seed) - np.random.seed(cfg.meta.seed) - random.seed(cfg.meta.seed) - dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16} - dtype = dtype_map.get(cfg.data.dtype.lower(), torch.float32) - mixed_precision = dtype != torch.float32 + log_data_info( + cfg.data.env_name, + len(loader), + data_config.batch_size, + train_samples=data_config.size, + val_samples=data_config.val_size, + ) - # -- ENV - if cfg.meta.get("enable_plan_eval"): + # Mixed precision setup + dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16} + dtype = dtype_map.get(cfg.training.get("dtype", "float16").lower(), torch.float16) + use_amp = cfg.training.get("use_amp", True) + scaler = GradScaler(device.type, enabled=use_amp) + logger.info(f"Using AMP with {dtype=}" if use_amp else f"AMP disabled") + + # -- ENV (for plan/unroll eval) + enable_eval = cfg.meta.get("enable_plan_eval", False) + env_creator = None + plan_cfg = None + num_eval_episodes = 10 + + if enable_eval: if cfg.meta.eval_every_itr <= 0: cfg.meta.eval_every_itr = len(loader) with open(cfg.eval.plan_cfg_path, "r") as f: @@ -218,13 +141,7 @@ def main( _, _, env_config = init_data( env_name=cfg.data.env_name, cfg_data=dict(eval_cfg_dict.get("data", {})) ) - num_eval_episodes = ( - eval_cfg_dict.get("meta", {}).get("num_eval_episodes", 10) - if not quick_debug - else 1 - ) - if quick_debug: - eval_cfg_dict["env"]["n_allowed_steps"] = 20 + num_eval_episodes = eval_cfg_dict.get("meta", {}).get("num_eval_episodes", 10) def env_creator(): from eb_jepa.datasets.two_rooms.env import DotWall @@ -235,79 +152,16 @@ def env_creator(): **cfg_eval_env, ) - # -- LOGGING + # -- SAVE CONFIG + latest_ckpt_path = folder / "latest.pth.tar" + steps_per_epoch = data_config.size // data_config.batch_size + total_steps = cfg.optim.epochs * steps_per_epoch config_path = folder / "config.yaml" with open(config_path, "w") as f: OmegaConf.save(cfg, config_path) print(f"Saved complete config to {config_path}") - latest_ckpt_path = folder / "latest.pth.tar" - steps_per_epoch = data_config.size // data_config.batch_size - total_steps = cfg.optim.epochs * steps_per_epoch - - logger.info(f"βš™οΈ Config: {cfg}") - - if cfg.logging.get("log_wandb"): - project_name = "eb-jepa-ac" if not quick_debug else "eb-jepa-ac-debug" - wandb_run_id_file = os.path.join(folder, "wandb_run_id.txt") - - # Create a base experiment name without seed for aggregation - exp_name_base = get_experiment_folder_without_seed( - cfg, cfg.data, quick_debug=cfg.meta.get("quick_debug") - ) - exp_name_base_str = os.path.basename(exp_name_base) - - wandb_config = { - "project": project_name, - "dir": folder, - "name": exp_name_base_str, - "tags": [f"seed_{cfg.meta.seed}"], - "config": OmegaConf.to_container(cfg, resolve=True), - } - - if cfg.logging.wandb_group: - wandb_config["group"] = cfg.logging.wandb_group - - if cfg.logging.get("wandb_sweep") and cfg.logging.get("wandb_sweep_id"): - wandb_config["tags"].append(f"sweep_{cfg.logging.wandb_sweep_id}") - logger.info(f"W&B sweep: {cfg.logging.wandb_sweep_id}") - - # Check for existing run to resume - if os.path.exists(wandb_run_id_file): - with open(wandb_run_id_file, "r") as f: - wandb_run_id = f.read().strip() - # Set environment variables to control both sweep association and run resuming - # WANDB_SWEEP_ID associates the run with the sweep - # WANDB_RUN_ID forces the specific run ID for resuming - # WANDB_RESUME enables resume mode - os.environ["WANDB_SWEEP_ID"] = cfg.logging.wandb_sweep_id - os.environ["WANDB_RUN_ID"] = wandb_run_id - os.environ["WANDB_RESUME"] = "allow" - wandb.init(**wandb_config) - logger.info(f"Resuming W&B run {wandb_run_id} (sweep)") - else: - # First run: set WANDB_SWEEP_ID to associate with sweep - os.environ["WANDB_SWEEP_ID"] = cfg.logging.wandb_sweep_id - wandb.init(**wandb_config) - with open(wandb_run_id_file, "w") as f: - f.write(wandb.run.id) - logger.info(f"Created W&B run {wandb.run.id} (sweep)") - else: - if os.path.exists(wandb_run_id_file): - with open(wandb_run_id_file, "r") as f: - wandb_run_id = f.read().strip() - wandb_config.update({"id": wandb_run_id, "resume": "allow"}) - wandb.init(**wandb_config) - logger.info(f"Resuming W&B run {wandb_run_id}") - else: - wandb.init(**wandb_config) - with open(wandb_run_id_file, "w") as f: - f.write(wandb.run.id) - logger.info(f"Created W&B run {wandb.run.id}") - # -- MODEL - device = "cuda" if torch.cuda.is_available() else "cpu" - scaler = GradScaler() if device == "cuda" and mixed_precision else None test_input = torch.rand( ( 1, @@ -340,20 +194,20 @@ def env_creator(): ) else: projector = None - logger.info(f"🧠 Encoder output: {tuple(test_output.shape)}") + logger.info(f"Encoder output: {tuple(test_output.shape)}") idm = InverseDynamicsModel( state_dim=h * w * (projector.out_dim if cfg.model.regularizer.idm_after_proj else f), - hidden_dim=256, # You can adjust this based on your needs - action_dim=2, # Number of action dimensions in your environment + hidden_dim=256, + action_dim=2, ).to(device) regularizer = VC_IDM_Sim_Regularizer( cov_coeff=cfg.model.regularizer.cov_coeff, std_coeff=cfg.model.regularizer.std_coeff, sim_coeff_t=cfg.model.regularizer.sim_coeff_t, idm_coeff=cfg.model.regularizer.get("idm_coeff", 0.1), - idm=idm, # Pass the IDM model reference + idm=idm, first_t_only=cfg.model.regularizer.get("first_t_only"), projector=projector, spatial_as_samples=cfg.model.regularizer.spatial_as_samples, @@ -362,13 +216,13 @@ def env_creator(): ) ploss = SquareLossSeq() jepa = JEPA(encoder, aencoder, predictor, regularizer, ploss).to(device) - logger.info(jepa) + # Log model structure and parameters encoder_params = sum(p.numel() for p in encoder.parameters()) predictor_params = sum(p.numel() for p in predictor.parameters()) - logger.info( - f"πŸ”’ Parameters: encoder={encoder_params:,} | predictor={predictor_params:,}" - ) + log_model_info(jepa, {"encoder": encoder_params, "predictor": predictor_params}) + + log_config(cfg) # -- PROBER xy_head = MLPXYHead( @@ -381,304 +235,247 @@ def env_creator(): hcost=nn.MSELoss(), ) - jepa_optimizer, jepa_scheduler = None, None - if train_jepa: - jepa_optimizer = AdamW( - jepa.parameters(), lr=cfg.optim.lr, - weight_decay=cfg.optim.get("weight_decay", 1e-6), - ) - jepa_scheduler = CosineWithWarmup(jepa_optimizer, total_steps, warmup_ratio=0.1) + jepa_optimizer = AdamW( + jepa.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.get("weight_decay", 1e-6), + ) + jepa_scheduler = CosineWithWarmup(jepa_optimizer, total_steps, warmup_ratio=0.1) - probe_optimizer, probe_scheduler = None, None - if train_probe: - probe_optimizer = AdamW(xy_head.parameters(), lr=1e-3, weight_decay=1e-5) - probe_scheduler = CosineWithWarmup(probe_optimizer, total_steps, warmup_ratio=0.1) + probe_optimizer = AdamW(xy_head.parameters(), lr=1e-3, weight_decay=1e-5) + probe_scheduler = CosineWithWarmup(probe_optimizer, total_steps, warmup_ratio=0.1) # -- LOAD CKPT - start_epoch = 1 + start_epoch = 0 + ckpt_info = {} if cfg.meta.load_model: + checkpoint_path = folder / cfg.meta.get("load_checkpoint", "latest.pth.tar") + ckpt_info = load_checkpoint( + checkpoint_path, jepa, jepa_optimizer, jepa_scheduler, device=device + ) + start_epoch = ckpt_info.get("epoch", 0) + if "xy_head_state_dict" in ckpt_info: + xy_head.load_state_dict(ckpt_info["xy_head_state_dict"]) - def load_checkpoint(): - if cfg.meta.get("load_checkpoint"): - checkpoint_path = folder / cfg.meta.load_checkpoint - else: - checkpoint_path = latest_ckpt_path - if checkpoint_path.exists(): - checkpoint = torch.load(checkpoint_path, weights_only=False) - msg = jepa.load_state_dict( - clean_state_dict(checkpoint["jepa_state_dict"]) - ) - logger.info(f"Loaded JEPA with message: {msg}") - msg = xy_head.load_state_dict( - clean_state_dict(checkpoint["xy_head_state_dict"]) - ) - logger.info(f"Loaded XY head with message: {msg}") - start_epoch = checkpoint["epoch"] - if "jepa_optimizer_state_dict" in checkpoint and jepa_optimizer: - jepa_optimizer.load_state_dict( - checkpoint["jepa_optimizer_state_dict"] - ) - if "jepa_scheduler_state_dict" in checkpoint and jepa_scheduler: - jepa_scheduler.load_state_dict( - checkpoint["jepa_scheduler_state_dict"] - ) - - if "probe_optimizer_state_dict" in checkpoint and probe_optimizer: - probe_optimizer.load_state_dict( - checkpoint["probe_optimizer_state_dict"] - ) - if "probe_scheduler_state_dict" in checkpoint and probe_scheduler: - probe_scheduler.load_state_dict( - checkpoint["probe_scheduler_state_dict"] - ) - logger.info(f"πŸ“‚ Loaded checkpoint: epoch={start_epoch}") - else: - logging.warning(f"Checkpoint not found at {checkpoint_path}") - start_epoch = 1 - return start_epoch - - start_epoch = load_checkpoint() - # to allow for light eval from ckpt that finished training - if cfg.meta.get("light_eval_only_mode"): - start_epoch -= 1 # Compile if torch.cuda.is_available() and cfg.model.compile: - logger.info("⚑ Compiling model with torch.compile") + logger.info("βœ… Compiling model with torch.compile") jepa = torch.compile(jepa) - # -- TRAINING LOOP - if not cfg.meta.get("plan_eval_only_mode"): - # index epochs starting at epoch 1 - for epoch in range(start_epoch, cfg.optim.epochs): - epoch_start_time = time() - for idx, (x, a, loc, _, _) in enumerate(loader): - itr_start_time = time() - if quick_debug or cfg.meta.get("light_eval_only_mode"): - if idx > 3: - break - global_step = (epoch - 1) * len(loader) + idx - - x = x.to(device) - a = a.to(device) - loc = loc.to(device) - total_loss = torch.tensor(0.0, device=device) - - # Calculate JEPA loss if training JEPA - if train_jepa: - if jepa_optimizer: - jepa_optimizer.zero_grad() - with autocast( - enabled=mixed_precision, device_type=device, dtype=dtype - ): - jepa_loss, regl, regl_unweight, regldict, pl = jepa.forwardn( - x, a, nsteps=cfg.model.nsteps - ) - total_loss += jepa_loss - # Mixed precision backward pass - if scaler is not None: - scaler.scale(jepa_loss).backward() - if cfg.optim.get("grad_clip_enc") and cfg.optim.get( - "grad_clip_pred" - ): - scaler.unscale_(jepa_optimizer) - encoder_grad_norm = torch.nn.utils.clip_grad_norm_( - jepa.encoder.parameters(), cfg.optim.grad_clip_enc - ) - predictor_grad_norm = torch.nn.utils.clip_grad_norm_( - jepa.predictor.parameters(), cfg.optim.grad_clip_pred - ) - scaler.step(jepa_optimizer) - scaler.update() - else: - jepa_loss.backward() - if cfg.optim.get("grad_clip_enc") and cfg.optim.get( - "grad_clip_pred" - ): - encoder_grad_norm = torch.nn.utils.clip_grad_norm_( - jepa.encoder.parameters(), cfg.optim.grad_clip_enc - ) - predictor_grad_norm = torch.nn.utils.clip_grad_norm_( - jepa.predictor.parameters(), cfg.optim.grad_clip_pred - ) - jepa_optimizer.step() - if jepa_scheduler: - jepa_scheduler.step() - else: - encoder_grad_norm = None - predictor_grad_norm = None - # Still compute for logging, but detach to avoid gradients - with torch.no_grad(): - jepa_loss, regl, regl_unweight, regldict, pl = jepa.forwardn( - x, a, nsteps=cfg.model.nsteps - ) - - # Calculate probe loss if training probe and after start point - if train_probe and global_step >= cfg.meta.start_probing_after: - if probe_optimizer: - probe_optimizer.zero_grad() - with autocast( - enabled=mixed_precision, device_type=device, dtype=dtype - ): - xy_loss = xy_prober( - observations=x[:, :, :1], - targets=loc[:, :, :1], - ) - xy_loss = loader.dataset.normalizer.unnormalize_mse(xy_loss) - total_loss += xy_loss - - if probe_optimizer: - if scaler is not None: - scaler.scale(xy_loss).backward() - scaler.step(probe_optimizer) - scaler.update() - else: - xy_loss.backward() - probe_optimizer.step() - probe_scheduler.step() - else: - # Still compute for logging, but detach to avoid gradients - with torch.no_grad(): - xy_loss = xy_prober(x[:, :, :1], loc[:, :, :1]) - xy_loss = loader.dataset.normalizer.unnormalize_mse(xy_loss) - itr_time = time() - itr_start_time - if global_step % cfg.logging.log_every == 0: - log_data = { - "train/total_loss": total_loss.item(), - "train/reg_loss": regl.item(), - "train/reg_loss_unweight": regl_unweight.item(), - "train/pred_loss": pl.item(), - "train/probe_loss": xy_loss.item(), - "global_step": global_step, - "epoch": epoch, - "itr_time": itr_time, - } - # Log optimization metrics - if jepa_optimizer is not None: - log_data["optim/jepa_lr"] = jepa_optimizer.param_groups[0]["lr"] - if probe_optimizer is not None: - log_data["optim/probe_lr"] = probe_optimizer.param_groups[0][ - "lr" - ] - if encoder_grad_norm is not None: - log_data["train/encoder_grad_norm"] = encoder_grad_norm.item() - if predictor_grad_norm is not None: - log_data["train/predictor_grad_norm"] = ( - predictor_grad_norm.item() - ) - for loss_name, loss_value in regldict.items(): - log_data[f"train/regl/{loss_name}"] = loss_value - logger.info( - f"[E{epoch:03d}|S{global_step:05d}] loss={total_loss.item():.3f} " - f"reg={regl.item():.3f} pred={pl.item():.4f} probe={xy_loss.item():.3f}" - ) - - if cfg.logging.get("log_wandb"): - wandb.log(log_data, step=global_step) - - # Planning eval - if ( - cfg.meta.get("enable_plan_eval") - and not cfg.meta.get("light_eval_only_mode") - and (global_step + 1) % cfg.meta.eval_every_itr == 0 - and global_step > 0 - ): - eval_results = launch_plan_eval( - jepa, - env_creator, - folder, - epoch, - global_step, - suffix="", - num_eval_episodes=num_eval_episodes, - loader=val_loader, - prober=xy_prober, - plan_cfg=plan_cfg, - ) - - if cfg.logging.get("log_wandb"): - wandb.log(eval_results, step=global_step) - - # Light eval - if ( - global_step + 1 - ) % cfg.meta.light_eval_freq == 0 and global_step > 0: - eval_results = launch_unroll_eval( - jepa, - env_creator, - folder, - epoch, - global_step, - suffix=( - "-light-only" - if cfg.meta.get("light_eval_only_mode") - else "" - ), - loader=val_loader, - prober=xy_prober, - cfg=cfg, - ) - - if cfg.logging.get("log_wandb"): - wandb.log(eval_results, step=global_step) - epoch_time = time() - epoch_start_time - if cfg.logging.get("log_wandb"): - wandb.log( - {"epoch": epoch, "epoch_time": epoch_time}, - step=epoch * len(loader), - ) - if not cfg.meta.get("light_eval_only_mode"): - - def save_checkpoint(): - # Save checkpoint at the end of the epoch if it's time - save_dict = { - "epoch": epoch, - "jepa_state_dict": jepa.state_dict(), - "xy_head_state_dict": xy_head.state_dict(), - } - if jepa_optimizer: - save_dict["jepa_optimizer_state_dict"] = ( - jepa_optimizer.state_dict() - ) - if jepa_scheduler: - save_dict["jepa_scheduler_state_dict"] = ( - jepa_scheduler.state_dict() - ) - if probe_optimizer: - save_dict["probe_optimizer_state_dict"] = ( - probe_optimizer.state_dict() - ) - if probe_scheduler: - save_dict["probe_scheduler_state_dict"] = ( - probe_scheduler.state_dict() - ) - torch.save( - clean_state_dict(save_dict), - latest_ckpt_path, - ) - if epoch % cfg.logging.save_every_n_epochs == 0: - checkpoint_path = folder / f"e-{epoch}.pth.tar" - torch.save(save_dict, checkpoint_path) - print(f"Checkpoint saved at {checkpoint_path}") - - save_checkpoint() - else: - logger.info("🎯 Plan evaluation mode (skipping training)") - global_step = start_epoch * len(loader) - launch_plan_eval( + # -- EVAL ONLY MODE + if cfg.meta.get("eval_only_mode", False): + if not enable_eval: + raise ValueError("eval_only_mode requires enable_plan_eval=True") + logger.info("Running evaluation only (no training)") + eval_results = launch_unroll_eval( jepa, env_creator, folder, start_epoch, - global_step, - suffix=eval_cfg_dict["meta"].get("plan_suffix"), - num_eval_episodes=num_eval_episodes, - loader=val_loader, - prober=xy_prober, - plan_cfg=plan_cfg, + ckpt_info.get("step", 0), + "_eval_only", + val_loader, + xy_prober, + cfg, + ) + eval_results.update( + launch_plan_eval( + jepa, + env_creator, + folder, + start_epoch, + global_step=ckpt_info.get("step", 0), + suffix="_eval_only", + num_eval_episodes=num_eval_episodes, + loader=val_loader, + prober=xy_prober, + plan_cfg=plan_cfg, + ) + ) + logger.info( + f"Evaluation complete. Success rate: {eval_results['success_rate']:.2%}" + ) + return eval_results + + # -- TRAINING LOOP + for epoch in range(start_epoch, cfg.optim.epochs): + epoch_start_time = time() + pbar = tqdm( + enumerate(loader), + total=len(loader), + desc=f"Epoch {epoch}/{cfg.optim.epochs - 1}", + disable=cfg.logging.get("tqdm_silent", False), ) + for idx, (x, a, loc, _, _) in pbar: + itr_start_time = time() + global_step = epoch * len(loader) + idx + + x = x.to(device) + a = a.to(device) + loc = loc.to(device) + total_loss = torch.tensor(0.0, device=device) + + # Calculate JEPA loss + jepa_optimizer.zero_grad() + with autocast(device.type, enabled=use_amp, dtype=dtype): + _, (jepa_loss, regl, regl_unweight, regldict, pl) = jepa.unroll( + x, + a, + nsteps=cfg.model.nsteps, + unroll_mode="autoregressive", + ctxt_window_time=1, + compute_loss=True, + return_all_steps=False, + ) + total_loss += jepa_loss + + # Mixed precision backward pass + scaler.scale(jepa_loss).backward() + if cfg.optim.get("grad_clip_enc") and cfg.optim.get("grad_clip_pred"): + scaler.unscale_(jepa_optimizer) + torch.nn.utils.clip_grad_norm_( + jepa.encoder.parameters(), cfg.optim.grad_clip_enc + ) + torch.nn.utils.clip_grad_norm_( + jepa.predictor.parameters(), cfg.optim.grad_clip_pred + ) + scaler.step(jepa_optimizer) + scaler.update() + jepa_scheduler.step() + + # Calculate probe loss + probe_optimizer.zero_grad() + with autocast(device.type, enabled=use_amp, dtype=dtype): + xy_loss = xy_prober( + observations=x[:, :, :1], + targets=loc[:, :, :1], + ) + xy_loss = loader.dataset.normalizer.unnormalize_mse(xy_loss) + total_loss += xy_loss + + scaler.scale(xy_loss).backward() + scaler.step(probe_optimizer) + scaler.update() + probe_scheduler.step() + + # Update progress bar + pbar.set_postfix( + { + "loss": f"{total_loss.item():.4f}", + "reg": f"{regl.item():.4f}", + "pred": f"{pl.item():.4f}", + } + ) + itr_time = time() - itr_start_time + if global_step % cfg.logging.log_every == 0: + log_data = { + "train/total_loss": total_loss.item(), + "train/reg_loss": regl.item(), + "train/reg_loss_unweight": regl_unweight.item(), + "train/pred_loss": pl.item(), + "train/probe_loss": xy_loss.item(), + "global_step": global_step, + "epoch": epoch, + "itr_time": itr_time, + "optim/jepa_lr": jepa_optimizer.param_groups[0]["lr"], + "optim/probe_lr": probe_optimizer.param_groups[0]["lr"], + } + for loss_name, loss_value in regldict.items(): + log_data[f"train/regl/{loss_name}"] = loss_value + + if cfg.logging.get("log_wandb"): + wandb.log(log_data, step=global_step) + + # Planning eval (only if eval is enabled) + if ( + enable_eval + and (global_step + 1) % cfg.meta.eval_every_itr == 0 + and global_step > 0 + ): + eval_results = launch_plan_eval( + jepa, + env_creator, + folder, + epoch, + global_step, + suffix="", + num_eval_episodes=num_eval_episodes, + loader=val_loader, + prober=xy_prober, + plan_cfg=plan_cfg, + ) -if __name__ == "__main__": - import fire + if cfg.logging.get("log_wandb"): + wandb.log(eval_results, step=global_step) + + # Light eval (only if eval is enabled) + if ( + enable_eval + and (global_step + 1) % cfg.meta.light_eval_freq == 0 + and global_step > 0 + ): + eval_results = launch_unroll_eval( + jepa, + env_creator, + folder, + epoch, + global_step, + suffix="", + loader=val_loader, + prober=xy_prober, + cfg=cfg, + ) - fire.Fire(main) + if cfg.logging.get("log_wandb"): + wandb.log(eval_results, step=global_step) + + epoch_time = time() - epoch_start_time + + # Log epoch summary + log_epoch( + epoch, + { + "loss": total_loss.item(), + "reg": regl.item(), + "pred": pl.item(), + "probe": xy_loss.item(), + }, + total_epochs=cfg.optim.epochs, + elapsed_time=epoch_time, + ) + + if cfg.logging.get("log_wandb"): + wandb.log( + {"epoch": epoch, "epoch_time": epoch_time}, + step=epoch * len(loader), + ) + + # Save checkpoint + save_checkpoint( + latest_ckpt_path, + model=jepa, + optimizer=jepa_optimizer, + scheduler=jepa_scheduler, + epoch=epoch, + step=global_step, + xy_head_state_dict=xy_head.state_dict(), + probe_optimizer_state_dict=probe_optimizer.state_dict(), + probe_scheduler_state_dict=probe_scheduler.state_dict(), + ) + if epoch % cfg.logging.save_every_n_epochs == 0: + save_checkpoint( + folder / f"e-{epoch}.pth.tar", + model=jepa, + optimizer=jepa_optimizer, + scheduler=jepa_scheduler, + epoch=epoch, + step=global_step, + xy_head_state_dict=xy_head.state_dict(), + probe_optimizer_state_dict=probe_optimizer.state_dict(), + probe_scheduler_state_dict=probe_scheduler.state_dict(), + ) + + +if __name__ == "__main__": + fire.Fire(run) diff --git a/examples/ac_video_jepa/run_exp.sh b/examples/ac_video_jepa/run_exp.sh deleted file mode 100755 index 1eebfb8..0000000 --- a/examples/ac_video_jepa/run_exp.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name=AC-JEPA -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=16 -#SBATCH --mem=55G -#SBATCH --time=24:00:00 -#SBATCH --partition=learn -#SBATCH --signal=B:CONT@60 -#SBATCH --requeue -#SBATCH --output=logs/ac_jepa/%A_%a.out -#SBATCH --error=logs/ac_jepa/%A_%a.err -#SBATCH --array=0-71 - -# ============================================================================= -# Action-Conditioned Video JEPA - Hyperparameter Sweep -# ============================================================================= -# This script launches a grid search over regularization coefficients. -# Adjust the arrays below to customize your sweep. -# -# Usage: -# sbatch examples/ac_video_jepa/run_exp.sh -# -# Single run (without SLURM): -# python -m examples.ac_video_jepa.main --model.regularizer.cov_coeff 8 -# ============================================================================= - -# Grid search parameters -COV_COEFFS=(8 12) -STD_COEFFS=(8 16) -SIM_COEFFS=(8 12 16) -IDM_COEFFS=(1 2) -SEEDS=(1 1000 10000) - -# Setup environment -chmod a+x ~/.bashrc -PS1='$ ' -source ~/.bashrc - -# Change to project root -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" -cd "$PROJECT_ROOT" - -# Create logs directory if it doesn't exist -mkdir -p logs/ac_jepa - -# Generate all parameter combinations -combinations=() -for cov in "${COV_COEFFS[@]}"; do - for std in "${STD_COEFFS[@]}"; do - for sim in "${SIM_COEFFS[@]}"; do - for idm in "${IDM_COEFFS[@]}"; do - for seed in "${SEEDS[@]}"; do - combinations+=("($cov, $std, $sim, $idm, $seed)") - done - done - done - done -done - -# Get the combination for this array task -combination="${combinations[$SLURM_ARRAY_TASK_ID]}" -cov=$(echo "$combination" | awk -F '[(), ]+' '{print $2}') -std=$(echo "$combination" | awk -F '[(), ]+' '{print $3}') -sim=$(echo "$combination" | awk -F '[(), ]+' '{print $4}') -idm=$(echo "$combination" | awk -F '[(), ]+' '{print $5}') -seed=$(echo "$combination" | awk -F '[(), ]+' '{print $6}') - -echo "==============================================" -echo "Running AC Video JEPA experiment:" -echo " cov_coeff=$cov" -echo " std_coeff=$std" -echo " sim_coeff_t=$sim" -echo " idm_coeff=$idm" -echo " seed=$seed" -echo " task_id=$SLURM_ARRAY_TASK_ID" -echo "==============================================" - -# Run the experiment -python -m examples.ac_video_jepa.main \ - --model.regularizer.cov_coeff=${cov} \ - --model.regularizer.std_coeff=${std} \ - --model.regularizer.sim_coeff_t=${sim} \ - --model.regularizer.idm_coeff=${idm} \ - --meta.seed=${seed} diff --git a/examples/image_jepa/README.md b/examples/image_jepa/README.md index 3bba8c1..902ce0a 100644 --- a/examples/image_jepa/README.md +++ b/examples/image_jepa/README.md @@ -1,6 +1,6 @@ ## Self-Supervised Representation Learning from Unlabeled Images -This example demonstrates how to train a Joint Embedding Predictive Architecture (JEPA) on unlabeled images. The model learns representations from individual frames of the CIFAR 10 dataset and is evaluated using linear probing for image classification. +This example demonstrates how to train a Joint Embedding Predictive Architecture (JEPA) on unlabeled images. More precisely, methods studied here are JEAs as there is no predictor. The model learns representations from individual frames of the CIFAR 10 dataset and is evaluated using linear probing for image classification. ![Image JEPA Architecture](assets/arch_figure.png) @@ -15,7 +15,7 @@ This example demonstrates how to train a Joint Embedding Predictive Architecture The Image JEPA consists of: - **Encoder**: ResNet18/Transformer backbone that processes individual images -- **Regularizer**: Variance-Covariance (VC) or LeJEPA loss to prevent representation collapse +- **Regularizer**: Variance-Covariance (VC) or SIGReg loss to prevent representation collapse - **Projector**: Learned MLP Projector (loss is computed on the projected subspace) ## Usage @@ -25,102 +25,88 @@ The Image JEPA consists of: #### 1. ResNet + VICReg Loss ```bash -python main.py \ - --model_type resnet \ - --loss_type vicreg \ - --var_loss_weight 1.0 \ - --cov_loss_weight 80.0 \ - --batch_size 256 \ - --epochs 300 +python -m examples.image_jepa.main --fname examples/image_jepa/cfgs/default.yaml ``` -#### 2. ResNet + LE-JEPA (SIGReg) Loss +#### 2. ResNet + SIGReg (SIGReg) Loss ```bash -python main.py \ - --model_type resnet \ - --loss_type bcs \ - --lmbd 10.0 \ - --batch_size 256 \ - --epochs 300 +python -m examples.image_jepa.main --fname examples/image_jepa/cfgs/sigreg.yaml ``` #### 3. Vision Transformer + VICReg Loss ```bash -python main.py \ - --model_type vit_s \ - --patch_size 2 \ - --loss_type vicreg \ - --sim_loss_weight 25.0 \ - --var_loss_weight 25.0 \ - --cov_loss_weight 1.0 \ - --batch_size 256 \ - --epochs 300 +python -m examples.image_jepa.main --fname examples/image_jepa/cfgs/transformers.yaml ``` -For ViT-Base, use `--model_type vit_b` instead of `vit_s`. +For ViT-Base, add `model.type=vit_b` override: -## Results +```bash +python -m examples.image_jepa.main --fname examples/image_jepa/cfgs/transformers.yaml model.type=vit_b +``` -### Comparison: LE-JEPA vs VICReg +#### Custom Overrides -![Hyperparameter Sensitivity Comparison](assets/hyperparam_sensitivity_comparison.png) +You can override any config parameter using dot notation: -| Metric | LE-JEPA (SigReg) | VICReg | -|--------|------------------|--------| -| Best Accuracy | 90.67% | 90.95% | -| Projector Benefit | +2.5% | Variable | -| Stability | High | Lower (sensitive to hyperparams) | -| Best Projector Dims | 2048Γ—128 | 2048x2048 | +```bash +python -m examples.image_jepa.main --fname examples/image_jepa/cfgs/default.yaml optim.epochs=50 data.batch_size=128 +``` -**Conclusion:** Both methods achieve similar peak performance (~90%). LE-JEPA is more stable across hyperparameter choices, while VICReg can match performance but requires more careful tuning. +## Results + +### Comparison: SIGReg and VICReg + +We first compare the sensitivity to regularizer loss coefficients of SIGReg and VICReg. We sweep over these, training on CIFAR-10 with ResNet-18 backbone, trained for 300 epochs. -Results on CIFAR-10 with ResNet-18 backbone, trained for 300 epochs. +![Hyperparameter Sensitivity Comparison](assets/aggregated_perf_image.png) + +| Metric | SIGReg (BCS) | VICReg | +|--------|------------------|--------| +| Best Accuracy | 91.02% | 90.12% | +| Average (non-collapsed) Accuracy | 89.22% | 84.90% | +| Projector Benefit | +3.3 points | +2.9 points | +| Hyperparameters | 1 | 2 | +| Best Projector Dims | 2048Γ—128 | 2048x1024 | -### LE-JEPA (SigREG Loss) Best Configuration +**Finding:** Both methods achieve similar peak performance (~90%). While both methods can fail in certain cases, logical hyperparameter choices give performance in a similar ballpark. Performing the most naΓ―ve search over hyperparameters, SIGReg can be easier to tune due to the use of a single loss hyperparameter. -| Parameter | Value | -|-----------|-------| -| **Best Accuracy** | **90.67%** | -| batch_size | 256 | -| lmbd (Ξ») | 10.0 | -| use_projector | Yes | -| proj_hidden_dim | 2048 | -| proj_output_dim | 128 | +### Impact of regularizations -### Impact of Lambda (Ξ») +With a batch size of 256 and 1024x1024 projector we investigate the impact of regularizations for VICReg and SIGReg: -| Ξ» | Best Acc | -|---|----------| -| 1.0 | 87.23% | -| **10.0** | **90.67%** | -| 100.0 | 82.28% | +| | SIGReg | | VICReg | | +|------|------------|----------|------------|----------| +| Rank | Hyperparameters | Accuracy | Hyperparameters | Accuracy | +| 1 | $\lambda$= 10 |90.88 % | std = 1 cov = 100 | 90.12 % | +| 2 | $\lambda$= 1 | 86.94% | std = 1 cov = 10 | 89.93 % | +| 3 | $\lambda$= 100 | 80.86% | std = 10 cov = 10 | 89.2 % | +| -1 | $\lambda$= 0.1 | 27.20% | std = 100 cov = 100 | 10.00% | -**Finding:** Ξ»=10.0 is optimal. Too low (Ξ»=1) underperforms by ~3.4%, too high (Ξ»=100) underperforms by ~8.4%. +**Finding:** For both methods, performance can vary drastically between different hyperparameter choices. The main failure mode is when losses such as the invariance term (here set with a weight of 1) become insignificant which leads to a fundamentally flawed training. Logical choices offer much more stable performance. -### Impact of Projector -| Configuration | Mean | Max | -|---------------|------|-----| -| **With Projector** | **90.29%** | **90.67%** | -| No Projector | 87.69% | 88.15% | +### Impact of the projector -**Finding:** Using a projector provides **+2.5%** improvement. +Top 5 dimension combinations. For SIGReg we use Ξ»=10.0, batch_size=256 and for VICReg std=1.0,cov=100, batch_size=256: -### Projector Dimensions (proj_hidden_dim Γ— proj_output_dim) +| | SIGReg | | VICReg | | +|------|------------|----------|------------|----------| +| Rank | Dimensions | Accuracy | Dimensions | Accuracy | +| 1 | 2048 Γ— 128 | 91.02% | 2048 Γ— 1024 | 90.12% | +| 2 | 4096 Γ— 1024 | 91.00% | 4096 x 512 | 90.10% | +| 3 | 2048 Γ— 64 | 90.99% | 1024 x 1024 | 90.05% | +| 4 | 512 Γ— 256 | 90.99% | 2048 x 512 | 90.03% | +| 5 | 4096 Γ— 64 | 90.96% | 4096 x 1024 | 90.02% | +| N/A | None | 87.75% | None | 87.27% | -Top 5 dimension combinations (with Ξ»=10.0, batch_size=256): +**Finding:** Large hidden dimensions are beneficial for both methods, although all of the top-performing scenarios all offer similar performance. +VICReg tends to work better with higher dimensional output dimensions, whereas SIGReg works better with small output dimensions. +This difference does not lead to meaningful practical differences in terms of training time or memory. -| Rank | Dimensions | Accuracy | -|------|------------|----------| -| 1 | 2048 Γ— 128 | 90.67% | -| 2 | 1024 Γ— 256 | 90.65% | -| 3 | 512 Γ— 1024 | 90.61% | -| 4 | 512 Γ— 256 | 90.60% | -| 5 | 4096 Γ— 4096 | 90.56% | +Both methods have a similar drop of performance of around 2.5-3 points when not using a projector, highlighting its importance in the method's design. -**Finding:** Larger hidden dimensions (1024-2048) with smaller output dimensions (128-256) work best. The bottleneck effect (compressing representations) appears beneficial. --- @@ -131,4 +117,4 @@ Top 5 dimension combinations (with Ξ»=10.0, batch_size=256): - [Transformer Architecture](https://arxiv.org/abs/1706.03762) - [Vision Transformer Architecture](https://arxiv.org/abs/2010.11929) - [VICReg](https://arxiv.org/abs/2105.04906) -- [LeJEPA](https://arxiv.org/abs/2511.08544) \ No newline at end of file +- [LeJEPA/SIGReg](https://arxiv.org/abs/2511.08544) diff --git a/examples/image_jepa/assets/aggregated_perf_image.png b/examples/image_jepa/assets/aggregated_perf_image.png new file mode 100644 index 0000000..6573abf Binary files /dev/null and b/examples/image_jepa/assets/aggregated_perf_image.png differ diff --git a/examples/image_jepa/assets/hyperparam_sensitivity_comparison.png b/examples/image_jepa/assets/hyperparam_sensitivity_comparison.png deleted file mode 100644 index 0a3755d..0000000 Binary files a/examples/image_jepa/assets/hyperparam_sensitivity_comparison.png and /dev/null differ diff --git a/examples/image_jepa/cfgs/default.yaml b/examples/image_jepa/cfgs/default.yaml new file mode 100644 index 0000000..17b0f51 --- /dev/null +++ b/examples/image_jepa/cfgs/default.yaml @@ -0,0 +1,61 @@ +# Image JEPA (VICReg/BCS) Training Configuration +# Train a self-supervised image representation model on CIFAR-10 + +meta: + seed: 42 + device: auto # auto, cuda, or cpu + +data: + dataset: cifar10 + batch_size: 256 + num_workers: 4 + +model: + # Backbone + type: resnet # resnet, vit_s, vit_b + patch_size: 2 # For ViT only + + # Projector + use_projector: true + proj_hidden_dim: 2048 + proj_output_dim: 2048 + +loss: + type: vicreg # vicreg or bcs + # VICReg loss weights + std_coeff: 1.0 + cov_coeff: 80.0 + # BCS loss weight (only used if type=bcs) + lmbd: 10.0 + +optim: + epochs: 300 + lr: 0.3 + weight_decay: 1.0e-4 + warmup_epochs: 10 + warmup_start_lr: 3.0e-5 + min_lr: 0.0 + +logging: + log_wandb: true + wandb_group: # Set to group runs (e.g., for seed averaging) + log_every: 10 # Log every N epochs + save_every: 50 # Save checkpoint every N epochs + tqdm_silent: false # Disable tqdm progress bars + +training: + use_amp: true # Use automatic mixed precision + dtype: bfloat16 # float16 or bfloat16 + +# --- Parameters used when running with --full-sweep +# By default, only the parameters specified above are used +sweep: + # Default sweep grid for VICReg hyperparameter search + param_grid: + data.batch_size: [256] + model.use_projector: [true, false] + model.proj_hidden_dim: [64, 128, 256, 512, 1024, 2048, 4096] + model.proj_output_dim: [64, 128, 256, 512, 1024, 2048, 4096] + loss.std_coeff: [1.0, 10.0, 100.0] + loss.cov_coeff: [0.1, 1.0, 10.0, 100.0] + meta.seed: [1, 1000, 10000] diff --git a/examples/image_jepa/cfgs/sigreg.yaml b/examples/image_jepa/cfgs/sigreg.yaml new file mode 100644 index 0000000..f5028b1 --- /dev/null +++ b/examples/image_jepa/cfgs/sigreg.yaml @@ -0,0 +1,57 @@ +# Image SIGReg (BCS) Training Configuration +# Train a self-supervised image representation model on CIFAR-10 using BCS loss + +meta: + seed: 42 + device: auto # auto, cuda, or cpu + +data: + dataset: cifar10 + batch_size: 256 + num_workers: 4 + +model: + # Backbone + type: resnet # resnet, vit_s, vit_b + patch_size: 2 # For ViT only + + # Projector + use_projector: true + proj_hidden_dim: 2048 + proj_output_dim: 128 + +loss: + type: bcs # vicreg or bcs + # BCS loss weight + lmbd: 10.0 + +optim: + epochs: 300 + lr: 0.3 + weight_decay: 1.0e-4 + warmup_epochs: 10 + warmup_start_lr: 3.0e-5 + min_lr: 0.0 + +logging: + log_wandb: true + wandb_group: # Set to group runs (e.g., for seed averaging) + log_every: 10 # Log every N epochs + save_every: 50 # Save checkpoint every N epochs + tqdm_silent: false # Disable tqdm progress bars + +training: + use_amp: true # Use automatic mixed precision + dtype: bfloat16 # float16 or bfloat16 + +# --- Parameters used when running with --full-sweep +# By default, only the parameters specified above are used +sweep: + # Sweep grid for SIGReg (BCS loss) hyperparameter search + param_grid: + data.batch_size: [256] + model.use_projector: [true, false] + model.proj_hidden_dim: [64, 128, 256, 512, 1024, 2048, 4096] + model.proj_output_dim: [64, 128, 256, 512, 1024, 2048, 4096] + loss.lmbd: [0.1,1.0, 10.0, 100.0] + meta.seed: [1, 1000, 10000] diff --git a/examples/image_jepa/cfgs/transformers.yaml b/examples/image_jepa/cfgs/transformers.yaml new file mode 100644 index 0000000..9bbf422 --- /dev/null +++ b/examples/image_jepa/cfgs/transformers.yaml @@ -0,0 +1,57 @@ +# Image JEPA (VICReg) Transformers Configuration +# Train a self-supervised image representation model on CIFAR-10 with ViT architectures + +meta: + seed: 42 + device: auto # auto, cuda, or cpu + +data: + dataset: cifar10 + batch_size: 512 + num_workers: 8 + +model: + # Backbone - supports ViT architectures + type: vit_s # resnet, vit_s, vit_b + patch_size: 2 # For ViT only + + # Projector + use_projector: true + proj_hidden_dim: 2048 + proj_output_dim: 2048 + +loss: + type: vicreg # vicreg or bcs + # VICReg loss weights + std_coeff: 25.0 + cov_coeff: 1.0 + +optim: + epochs: 100 + lr: 0.3 + weight_decay: 1.0e-4 + warmup_epochs: 10 + warmup_start_lr: 3.0e-5 + min_lr: 0.0 + +logging: + log_wandb: true + wandb_group: # Set to group runs (e.g., for seed averaging) + log_every: 10 # Log every N epochs + save_every: 50 # Save checkpoint every N epochs + tqdm_silent: false # Disable tqdm progress bars + +training: + use_amp: true # Use automatic mixed precision + dtype: float16 # float16 or bfloat16 + +# --- Parameters used when running with --full-sweep +# By default, only the parameters specified above are used +sweep: + # Sweep grid for ViT architecture experiments + param_grid: + data.batch_size: [512] + optim.epochs: [50, 100, 1000] + model.use_projector: [false, true] + model.type: [resnet, vit_s, vit_b] + meta.seed: [1, 1000, 10000] diff --git a/examples/image_jepa/dataset.py b/examples/image_jepa/dataset.py index bb462a1..9da14c6 100644 --- a/examples/image_jepa/dataset.py +++ b/examples/image_jepa/dataset.py @@ -9,22 +9,22 @@ class RandomResizedCrop: """Random resized crop augmentation.""" - + def __init__(self, size, scale=(0.2, 1.0)): self.size = size self.scale = scale - + def __call__(self, img): return transforms.RandomResizedCrop(self.size, scale=self.scale)(img) class ColorJitter: """Color jitter augmentation.""" - + def __init__(self, brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, prob=0.8): self.transform = transforms.ColorJitter(brightness, contrast, saturation, hue) self.prob = prob - + def __call__(self, img): if torch.rand(1) < self.prob: return self.transform(img) @@ -33,10 +33,10 @@ def __call__(self, img): class Grayscale: """Grayscale augmentation.""" - + def __init__(self, prob=0.2): self.prob = prob - + def __call__(self, img): if torch.rand(1) < self.prob: return transforms.Grayscale(num_output_channels=3)(img) @@ -45,10 +45,10 @@ def __call__(self, img): class Solarization: """Solarization augmentation.""" - + def __init__(self, prob=0.1): self.prob = prob - + def __call__(self, img): if torch.rand(1) < self.prob: img = transforms.functional.solarize(img, threshold=128) @@ -57,10 +57,10 @@ def __call__(self, img): class HorizontalFlip: """Horizontal flip augmentation.""" - + def __init__(self, prob=0.5): self.prob = prob - + def __call__(self, img): if torch.rand(1) < self.prob: return transforms.functional.hflip(img) @@ -69,40 +69,45 @@ def __call__(self, img): def get_train_transforms(): """Get training transforms for self-supervised learning.""" - transform = transforms.Compose([ - RandomResizedCrop(32, scale=(0.2, 1.0)), - ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, prob=0.8), - Grayscale(prob=0.2), - Solarization(prob=0.1), - HorizontalFlip(prob=0.5), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) - ]) - + transform = transforms.Compose( + [ + RandomResizedCrop(32, scale=(0.2, 1.0)), + ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, prob=0.8 + ), + Grayscale(prob=0.2), + Solarization(prob=0.1), + HorizontalFlip(prob=0.5), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ] + ) + return transform def get_val_transforms(): """Get validation transforms.""" - return transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) - ]) + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ] + ) class ImageDataset(torch.utils.data.Dataset): """Custom dataset that applies augmentations multiple times to create views.""" - + def __init__(self, dataset, transform, num_crops=2): self.dataset = dataset self.transform = transform self.num_crops = num_crops - + def __len__(self): return len(self.dataset) - + def __getitem__(self, idx): - image, label = self.dataset[idx] + image, label = self.dataset[idx] views = [self.transform(image) for _ in range(self.num_crops)] return views, label - diff --git a/examples/image_jepa/eval.py b/examples/image_jepa/eval.py index bf5bdb8..58aed50 100644 --- a/examples/image_jepa/eval.py +++ b/examples/image_jepa/eval.py @@ -10,11 +10,11 @@ class LinearProbe(nn.Module): """Linear probe classifier for evaluating representations.""" - + def __init__(self, feature_dim, num_classes): super().__init__() self.classifier = nn.Linear(feature_dim, num_classes) - + def forward(self, x): return self.classifier(x) @@ -23,29 +23,28 @@ def evaluate_linear_probe(model, linear_probe, val_loader, device, use_amp=True) """Evaluate linear probe on validation set.""" model.eval() linear_probe.eval() - + total_loss = 0 correct = 0 total = 0 - + with torch.no_grad(): for data, target in val_loader: data = data.to(device, non_blocking=True) target = target.to(device, non_blocking=True) - - with autocast('cuda', enabled=use_amp): + + with autocast("cuda", enabled=use_amp): features, _ = model(data) outputs = linear_probe(features.float()) loss = F.cross_entropy(outputs, target) - + total_loss += loss.item() _, predicted = outputs.max(1) total += target.size(0) correct += predicted.eq(target).sum().item() - + accuracy = 100.0 * correct / total avg_loss = total_loss / len(val_loader) - - return accuracy, avg_loss + return accuracy, avg_loss diff --git a/examples/image_jepa/main.py b/examples/image_jepa/main.py index ab961be..b87ab9a 100644 --- a/examples/image_jepa/main.py +++ b/examples/image_jepa/main.py @@ -3,53 +3,88 @@ This script implements VICReg training on CIFAR-10 dataset using only PyTorch and torchvision. Supports both ResNet and Vision Transformer (ViT) backbones. + +Usage: + # With YAML config: + python -m examples.image_jepa.main --fname examples/image_jepa/cfgs/default.yaml + + # With config + overrides: + python -m examples.image_jepa.main --fname examples/image_jepa/cfgs/default.yaml optim.epochs=50 """ import os import time -import argparse +from pathlib import Path + +import fire import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -from torch.utils.data import DataLoader import torchvision -from torchvision.datasets import CIFAR10 +import wandb +from omegaconf import OmegaConf +from torch.amp import GradScaler, autocast from torch.optim.optimizer import required +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 from torchvision.models import VisionTransformer -import wandb from tqdm import tqdm -from torch.amp import autocast, GradScaler -from eb_jepa.losses import VICRegLoss, BCS -from examples.image_jepa.dataset import get_train_transforms, get_val_transforms, ImageDataset +from eb_jepa.logging import get_logger +from eb_jepa.losses import BCS, VICRegLoss +from eb_jepa.training_utils import ( + get_default_dev_name, + get_exp_name, + get_unified_experiment_dir, + load_checkpoint, + load_config, + log_config, + log_data_info, + log_epoch, + log_model_info, + save_checkpoint, + setup_device, + setup_seed, + setup_wandb, +) +from examples.image_jepa.dataset import ( + ImageDataset, + get_train_transforms, + get_val_transforms, +) from examples.image_jepa.eval import LinearProbe, evaluate_linear_probe +logger = get_logger(__name__) + class ResNet18(nn.Module): """ResNet-18 backbone implementation.""" - + def __init__(self): super().__init__() self.backbone = torchvision.models.resnet18() self.backbone.fc = nn.Identity() # Remove final classification layer - self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False) + self.backbone.conv1 = nn.Conv2d( + 3, 64, kernel_size=3, stride=1, padding=2, bias=False + ) self.backbone.maxpool = nn.Identity() self.features_dim = 512 - def forward(self, x): return self.backbone(x) class ImageSSL(nn.Module): """Image Self-Supervised Learning model implementation.""" - - def __init__(self, backbone, features_dim, proj_hidden_dim=2048, proj_output_dim=2048): + + def __init__( + self, backbone, features_dim, proj_hidden_dim=2048, proj_output_dim=2048 + ): super().__init__() self.backbone = backbone self.features_dim = features_dim - + # Projector self.projector = nn.Sequential( nn.Linear(features_dim, proj_hidden_dim), @@ -60,7 +95,7 @@ def __init__(self, backbone, features_dim, proj_hidden_dim=2048, proj_output_dim nn.ReLU(), nn.Linear(proj_hidden_dim, proj_output_dim), ) - + def forward(self, x): features = self.backbone(x) projections = self.projector(features) @@ -69,7 +104,7 @@ def forward(self, x): class LARS(optim.Optimizer): """LARS optimizer implementation.""" - + def __init__( self, params, @@ -142,7 +177,9 @@ def step(self, closure=None): # lars scaling + weight decay part if p.ndim != 1 or not group["exclude_bias_n_norm"]: if p_norm != 0 and g_norm != 0: - lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"]) + lars_lr = p_norm / ( + g_norm + p_norm * weight_decay + group["eps"] + ) lars_lr *= group["eta"] # clip lr @@ -172,324 +209,405 @@ def step(self, closure=None): class WarmupCosineScheduler: """Warmup cosine learning rate scheduler""" - - def __init__(self, optimizer, warmup_epochs, max_epochs, base_lr, min_lr=0.0, warmup_start_lr=3e-5): + + def __init__( + self, + optimizer, + warmup_epochs, + max_epochs, + base_lr, + min_lr=0.0, + warmup_start_lr=3e-5, + ): self.optimizer = optimizer self.warmup_epochs = warmup_epochs self.max_epochs = max_epochs self.base_lr = base_lr self.min_lr = min_lr self.warmup_start_lr = warmup_start_lr - + def step(self, epoch): if epoch < self.warmup_epochs: - lr = self.warmup_start_lr + epoch * (self.base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + lr = self.warmup_start_lr + epoch * ( + self.base_lr - self.warmup_start_lr + ) / (self.warmup_epochs - 1) else: - lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + torch.cos(torch.tensor((epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs) * 3.14159))) - - for param_group in self.optimizer.param_groups: - param_group['lr'] = lr - + lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * ( + 1 + + torch.cos( + torch.tensor( + (epoch - self.warmup_epochs) + / (self.max_epochs - self.warmup_epochs) + * 3.14159 + ) + ) + ) -def train_epoch(model, train_loader, optimizer, scheduler, linear_probe, scaler, device, epoch, loss_fn, use_amp=True): + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + + +def train_epoch( + model, + train_loader, + optimizer, + scheduler, + linear_probe, + scaler, + device, + epoch, + loss_fn, + use_amp=True, + dtype=torch.float16, + tqdm_silent=False, +): """Train for one epoch.""" model.train() linear_probe.train() - + # Dynamic loss accumulator loss_totals = {} total_linear_loss = 0 linear_correct = 0 linear_total = 0 - - pbar = tqdm(train_loader, desc=f'Epoch {epoch}') + + pbar = tqdm(train_loader, desc=f"Epoch {epoch}", disable=tqdm_silent) for batch_idx, (views, target) in enumerate(pbar): - view1, view2 = views[0].to(device, non_blocking=True), views[1].to(device, non_blocking=True) + view1, view2 = views[0].to(device, non_blocking=True), views[1].to( + device, non_blocking=True + ) target = target.to(device, non_blocking=True) - with autocast('cuda', enabled=use_amp): + with autocast(device.type, enabled=use_amp, dtype=dtype): features, z1 = model(view1) - _, z2 = model(view2) + _, z2 = model(view2) loss_dict = loss_fn(z1, z2) loss = loss_dict["loss"] - + with torch.no_grad(): features_frozen = features.detach().float() - + linear_outputs = linear_probe(features_frozen) linear_loss = F.cross_entropy(linear_outputs, target) - + _, predicted = linear_outputs.max(1) linear_correct_batch = predicted.eq(target).sum().item() - + total_loss_batch = loss + linear_loss - + optimizer.zero_grad() scaler.scale(total_loss_batch).backward() scaler.step(optimizer) scaler.update() - + # Update metrics dynamically based on loss_dict keys for key, value in loss_dict.items(): if key not in loss_totals: loss_totals[key] = 0 loss_totals[key] += value.item() total_linear_loss += linear_loss.item() - + # Update linear probe accuracy (pre-computed under autocast) linear_total += target.size(0) linear_correct += linear_correct_batch - + # Update progress bar - pbar.set_postfix({ - 'Loss': f'{loss.item():.4f}', - 'Linear': f'{linear_loss.item():.4f}', - 'Acc': f'{100.*linear_correct/linear_total:.2f}%' - }) - + pbar.set_postfix( + { + "Loss": f"{loss.item():.4f}", + "Linear": f"{linear_loss.item():.4f}", + "Acc": f"{100.*linear_correct/linear_total:.2f}%", + } + ) + # Update learning rate scheduler.step(epoch) - + # Build return dict dynamically num_batches = len(train_loader) metrics = {key: total / num_batches for key, total in loss_totals.items()} - metrics['linear_loss'] = total_linear_loss / num_batches - metrics['linear_acc'] = 100.0 * linear_correct / linear_total - + metrics["linear_loss"] = total_linear_loss / num_batches + metrics["linear_acc"] = 100.0 * linear_correct / linear_total + return metrics -def create_base_parser(description='Image SSL Training on CIFAR-10'): - """Create base argument parser with common arguments. - - This can be extended by other scripts to add their own arguments. +def run( + fname: str = "examples/image_jepa/cfgs/default.yaml", + cfg=None, + folder=None, + **overrides, +): + """ + Train an Image JEPA (VICReg/BCS) model on CIFAR-10. + + Args: + fname: Path to YAML config file + cfg: Pre-loaded config object (optional, overrides config file) + folder: Experiment folder path (optional, auto-generated if not provided) + **overrides: Config overrides in dot notation (e.g., optim.epochs=50) """ - parser = argparse.ArgumentParser(description=description) - - # Training parameters - parser.add_argument('--batch_size', type=int, default=256, help='Batch size for training') - parser.add_argument('--epochs', type=int, default=300, help='Number of training epochs') - parser.add_argument('--learning_rate', type=float, default=0.3, help='Learning rate') - parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay') - parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs') - parser.add_argument('--warmup_start_lr', type=float, default=3e-5, help='Starting learning rate for warmup') - - # Model parameters - parser.add_argument('--model_type', type=str, choices=['resnet', 'vit_s', 'vit_b'], default='resnet', help='Type of encoder') - parser.add_argument('--patch_size', type=int, default=2, help='Patch size for ViT') - parser.add_argument('--proj_hidden_dim', type=int, default=2048, help='Projector hidden dimension') - parser.add_argument('--proj_output_dim', type=int, default=2048, help='Projector output dimension') - parser.add_argument('--use_projector', type=int, default=1, help='Whether to use projector (default: True)') - - # Data parameters - parser.add_argument('--num_workers', type=int, default=4, help='Number of data loader workers') - parser.add_argument('--data_dir', type=str, default='./datasets', help='Directory to store datasets') - - # Logging parameters - parser.add_argument('--project_name', type=str, default='eb-jepa-image-ssl', help='Wandb project name') - parser.add_argument('--log_interval', type=int, default=10, help='Logging interval in epochs') - parser.add_argument('--save_interval', type=int, default=50, help='Checkpoint saving interval in epochs') - - # Other parameters - parser.add_argument('--seed', type=int, default=42, help='Random seed') - parser.add_argument('--device', type=str, default='auto', help='Device to use (auto, cuda, cpu)') - parser.add_argument('--no_amp', action='store_true', help='Disable mixed precision training (enabled by default)') - - return parser - - -def parse_args(): - """Parse command line arguments.""" - parser = create_base_parser() - - # Loss function selection - parser.add_argument('--loss_type', type=str, choices=['vicreg', 'bcs'], default='vicreg', help='Loss function type') - - # VICReg-specific loss weights - parser.add_argument('--var_loss_weight', type=float, default=1.0, help='Variance loss weight (VICReg)') - parser.add_argument('--cov_loss_weight', type=float, default=80.0, help='Covariance loss weight (VICReg)') - - # BCS-specific loss weight - parser.add_argument('--lmbd', type=float, default=10.0, help='BCS loss weight (LE-JEPA)') - - return parser.parse_args() - - -def main(): - """Main training function.""" - args = parse_args() - - # Print all hyperparameters for run identification in logs - print("=" * 60) - print("Run Configuration:") - print("=" * 60) - for key, value in sorted(vars(args).items()): - print(f" {key}={value}") - print("=" * 60) - - # Set device - if args.device == 'auto': - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # Load config + if cfg is None: + cfg = load_config(fname, overrides if overrides else None) + + # Setup using shared utilities + device = setup_device(cfg.meta.device) + setup_seed(cfg.meta.seed) + + # Create experiment directory using unified structure (if not provided) + if folder is None: + if cfg.meta.get("model_folder"): + exp_dir = Path(cfg.meta.model_folder) + folder_name = exp_dir.name + exp_name = folder_name.rsplit("_seed", 1)[0] + else: + sweep_name = get_default_dev_name() + exp_name = get_exp_name("image_jepa", cfg) + exp_dir = get_unified_experiment_dir( + example_name="image_jepa", + sweep_name=sweep_name, + exp_name=exp_name, + seed=cfg.meta.seed, + ) else: - device = torch.device(args.device) - print(f'Using device: {device}') - - # Set random seed - torch.manual_seed(args.seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(args.seed) - - wandb.init( - project=args.project_name, - config=vars(args), - name=f'{args.model_type}-{args.loss_type}-{args.seed}' + exp_dir = Path(folder) + exp_dir.mkdir(parents=True, exist_ok=True) + # Extract exp_name from folder name by removing _seed{seed} suffix + folder_name = exp_dir.name # e.g., "resnet_vicreg_seed1" + exp_name = folder_name.rsplit("_seed", 1)[0] # e.g., "resnet_vicreg" + + wandb_run = setup_wandb( + project="eb_jepa", + config={"example": "image_jepa", **OmegaConf.to_container(cfg, resolve=True)}, + run_dir=exp_dir, + run_name=exp_name, + tags=["image_jepa", f"seed_{cfg.meta.seed}"], + group=cfg.logging.get("wandb_group"), + enabled=cfg.logging.log_wandb, + sweep_id=cfg.logging.get("wandb_sweep_id"), ) - - print("Loading CIFAR-10 dataset...") + + logger.info("Loading CIFAR-10 dataset...") transform = get_train_transforms() - + + # Use EBJEPA_DSETS environment variable if set, otherwise fall back to config + data_dir = os.environ.get("EBJEPA_DSETS") + logger.info(f"Using data directory: {data_dir}") + base_train_dataset = CIFAR10( - root=args.data_dir, - train=True, - download=True, - transform=None + root=data_dir, train=True, download=True, transform=None ) - + train_dataset = ImageDataset(base_train_dataset, transform, num_crops=2) - + val_dataset = CIFAR10( - root=args.data_dir, - train=False, - download=True, - transform=get_val_transforms() + root=data_dir, train=False, download=True, transform=get_val_transforms() ) - + train_loader = DataLoader( train_dataset, - batch_size=args.batch_size, + batch_size=cfg.data.batch_size, shuffle=True, - num_workers=args.num_workers, + num_workers=cfg.data.num_workers, pin_memory=True, - drop_last=True # Avoid small batches that cause BatchNorm issues + drop_last=True, # Avoid small batches that cause BatchNorm issues ) - + val_loader = DataLoader( val_dataset, - batch_size=args.batch_size, + batch_size=cfg.data.batch_size, shuffle=False, - num_workers=args.num_workers, - pin_memory=True + num_workers=cfg.data.num_workers, + pin_memory=True, ) - + + log_data_info( + "CIFAR-10", + len(train_loader), + cfg.data.batch_size, + train_samples=len(train_dataset), + val_samples=len(val_dataset), + ) + # Initialize model - print("Initializing model...") - if args.model_type == 'resnet': + logger.info("Initializing model...") + if cfg.model.type == "resnet": backbone = ResNet18() features_dim = backbone.features_dim - elif args.model_type == 'vit_s': + elif cfg.model.type == "vit_s": features_dim = 384 - model_kwargs = dict(image_size=32, patch_size=8, hidden_dim=features_dim, num_layers=12, num_heads=6, mlp_dim=4*features_dim) + model_kwargs = dict( + image_size=32, + patch_size=8, + hidden_dim=features_dim, + num_layers=12, + num_heads=6, + mlp_dim=4 * features_dim, + ) backbone = VisionTransformer(**model_kwargs) backbone.heads = nn.Identity() - elif args.model_type == 'vit_b': + elif cfg.model.type == "vit_b": features_dim = 768 - model_kwargs = dict(image_size=32, patch_size=8, hidden_dim=features_dim, num_layers=12, num_heads=12, mlp_dim=4*features_dim) + model_kwargs = dict( + image_size=32, + patch_size=8, + hidden_dim=features_dim, + num_layers=12, + num_heads=12, + mlp_dim=4 * features_dim, + ) backbone = VisionTransformer(**model_kwargs) backbone.heads = nn.Identity() - model = ImageSSL(backbone, features_dim=features_dim, proj_hidden_dim=args.proj_hidden_dim, proj_output_dim=args.proj_output_dim) + model = ImageSSL( + backbone, + features_dim=features_dim, + proj_hidden_dim=cfg.model.proj_hidden_dim, + proj_output_dim=cfg.model.proj_output_dim, + ) - if not args.use_projector: + if not cfg.model.use_projector: model.projector = nn.Identity() - + model = model.to(device) - + + # Log model structure and parameters + encoder_params = sum(p.numel() for p in backbone.parameters()) + projector_params = ( + sum(p.numel() for p in model.projector.parameters()) + if cfg.model.use_projector + else 0 + ) + log_model_info(model, {"encoder": encoder_params, "projector": projector_params}) + + # Log configuration + log_config(cfg) + # Initialize linear probe linear_probe = LinearProbe(feature_dim=features_dim, num_classes=10).to(device) - - # Initialize mixed precision scaler - scaler = GradScaler('cuda') - - + + # Mixed precision setup + dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16} + use_amp = cfg.training.get("use_amp", True) + dtype = dtype_map.get(cfg.training.get("dtype", "float16").lower(), torch.float16) + scaler = GradScaler(device.type, enabled=use_amp) + logger.info(f"Using AMP with {dtype=}" if use_amp else f"AMP disabled") + optimizer = LARS( [ - {'params': model.parameters(), 'lr': 0.3}, - {'params': linear_probe.parameters(), 'lr': 0.1} # Initialize linear probe parameters + {"params": model.parameters(), "lr": cfg.optim.lr}, + {"params": linear_probe.parameters(), "lr": 0.1}, # Linear probe parameters ], - weight_decay=1e-4, + weight_decay=cfg.optim.weight_decay, eta=0.02, clip_lr=True, exclude_bias_n_norm=True, - momentum=0.9 + momentum=0.9, ) - + scheduler = WarmupCosineScheduler( optimizer, - warmup_epochs=args.warmup_epochs, - max_epochs=args.epochs, - base_lr=args.learning_rate, - min_lr=0.0, - warmup_start_lr=args.warmup_start_lr + warmup_epochs=cfg.optim.warmup_epochs, + max_epochs=cfg.optim.epochs, + base_lr=cfg.optim.lr, + min_lr=cfg.optim.min_lr, + warmup_start_lr=cfg.optim.warmup_start_lr, ) - + # Initialize loss function - if args.loss_type == 'vicreg': - loss_fn = VICRegLoss( - var_loss_weight=args.var_loss_weight, - cov_loss_weight=args.cov_loss_weight - ) - elif args.loss_type == 'bcs': - loss_fn = BCS(lmbd=args.lmbd) - + if cfg.loss.type == "vicreg": + loss_fn = VICRegLoss(std_coeff=cfg.loss.std_coeff, cov_coeff=cfg.loss.cov_coeff) + elif cfg.loss.type == "bcs": + loss_fn = BCS(lmbd=cfg.loss.lmbd) + + # Load checkpoint if requested + start_epoch = 0 + if cfg.meta.get("load_model"): + ckpt_path = exp_dir / cfg.meta.get("load_checkpoint", "latest.pth.tar") + ckpt_info = load_checkpoint(ckpt_path, model, optimizer, device=device) + start_epoch = ckpt_info.get("epoch", 0) + if "linear_probe_state_dict" in ckpt_info: + linear_probe.load_state_dict(ckpt_info["linear_probe_state_dict"]) + # Training loop - print("Starting training...") + logger.info(f"Starting training for {cfg.optim.epochs} epochs...") start_time = time.time() - - for epoch in range(args.epochs): + use_amp = cfg.training.get("use_amp", True) + tqdm_silent = cfg.logging.get("tqdm_silent", False) + + for epoch in range(start_epoch, cfg.optim.epochs): # Train - train_metrics = train_epoch(model, train_loader, optimizer, scheduler, linear_probe, scaler, device, epoch, - loss_fn, not args.no_amp) - + train_metrics = train_epoch( + model, + train_loader, + optimizer, + scheduler, + linear_probe, + scaler, + device, + epoch, + loss_fn, + use_amp, + dtype, + tqdm_silent, + ) + # Evaluate linear probe on validation set - val_acc, val_loss = evaluate_linear_probe(model, linear_probe, val_loader, device, not args.no_amp) - + val_acc, val_loss = evaluate_linear_probe( + model, linear_probe, val_loader, device, use_amp + ) + # Log metrics - dynamically add train_ prefix to all train_metrics keys - log_dict = {'epoch': epoch} + log_dict = {"epoch": epoch} for key, value in train_metrics.items(): - log_dict[f'train_{key}'] = value - log_dict['val_loss'] = val_loss - log_dict['val_acc'] = val_acc - log_dict['learning_rate'] = optimizer.param_groups[0]['lr'] - - wandb.log(log_dict) - - # Print progress - if epoch % args.log_interval == 0: + log_dict[f"train_{key}"] = value + log_dict["val_loss"] = val_loss + log_dict["val_acc"] = val_acc + log_dict["learning_rate"] = optimizer.param_groups[0]["lr"] + + if wandb_run: + wandb.log(log_dict) + + # Log progress + if epoch % cfg.logging.log_every == 0: elapsed = time.time() - start_time - metrics_str = ' | '.join(f'{k}: {v:.4f}' for k, v in train_metrics.items()) - print(f'Epoch {epoch:4d} | {metrics_str} | ' - f'Linear Val: {val_acc:.2f}% | ' - f'LR: {optimizer.param_groups[0]["lr"]:.6f} | ' - f'Time: {elapsed:.1f}s') - + log_epoch( + epoch, + { + "loss": train_metrics["loss"], + "val_acc": val_acc, + "lr": optimizer.param_groups[0]["lr"], + }, + total_epochs=cfg.optim.epochs, + elapsed_time=elapsed, + ) + # Save checkpoint - if epoch % args.save_interval == 0: - checkpoint = { - 'epoch': epoch, - 'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'linear_probe_state_dict': linear_probe.state_dict(), - 'scaler_state_dict': scaler.state_dict(), - 'scheduler_state_dict': scheduler, - 'loss': train_metrics['loss'], - 'linear_val_acc': val_acc - } - os.makedirs('examples/image_jepa/trained_models/', exist_ok=True) - torch.save(checkpoint, f'examples/image_jepa/trained_models/checkpoint_epoch_{epoch}.pth') - - print("Training completed!") - wandb.finish() + save_checkpoint( + exp_dir / "latest.pth.tar", + model=model, + optimizer=optimizer, + epoch=epoch, + scaler=scaler, + linear_probe_state_dict=linear_probe.state_dict(), + linear_val_acc=val_acc, + ) + if epoch % cfg.logging.save_every == 0 and epoch > 0: + save_checkpoint( + exp_dir / f"epoch_{epoch}.pth.tar", + model=model, + optimizer=optimizer, + epoch=epoch, + scaler=scaler, + linear_probe_state_dict=linear_probe.state_dict(), + linear_val_acc=val_acc, + ) + + logger.info("Training completed!") + if wandb_run: + wandb.finish() if __name__ == "__main__": - main() + fire.Fire(run) diff --git a/examples/image_jepa/run_exp.sh b/examples/image_jepa/run_exp.sh deleted file mode 100644 index dbc02fa..0000000 --- a/examples/image_jepa/run_exp.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name=VICReg-GridSearch -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=12 -#SBATCH --time=12:00:00 -#SBATCH --partition=learnfair -#SBATCH --signal=B:CONT@60 -#SBATCH --requeue -#SBATCH --output=/checkpoint/amirbar/experiments/eb_jepa/logs/%A_%a.out -#SBATCH --error=/checkpoint/amirbar/experiments/eb_jepa/logs/%A_%a.err -#SBATCH --array=0-195 - -# Grid search parameters -EPOCHS=(300) # number of epochs -BATCH_SIZES=(256 512) # batch size -USE_PROJECTOR=(0 1) # whether to use projector or not -proj_hidden_dim=(64 128 256 512 1024 2048 4096) -proj_output_dim=(64 128 256 512 1024 2048 4096) -var_loss_weight=(1.0 10.0 100.0) -cov_loss_weight=(1.0 10.0 100.0) - -chmod a+x ~/.bashrc -PS1='$ ' -source ~/.bashrc -cd "/private/home/amirbar/projects/eb_jepa_release" - -# Generate all combinations -combinations=() -for bs in "${BATCH_SIZES[@]}"; do - for epochs in "${EPOCHS[@]}"; do - for use_proj in "${USE_PROJECTOR[@]}"; do - for proj_hidden_dim in "${proj_hidden_dim[@]}"; do - for proj_output_dim in "${proj_output_dim[@]}"; do - for var_loss_weight in "${var_loss_weight[@]}"; do - for cov_loss_weight in "${cov_loss_weight[@]}"; do - combinations+=("($bs, $epochs, $use_proj, $proj_hidden_dim, $proj_output_dim, $var_loss_weight, $cov_loss_weight)") - done - done - done - done - done - done -done - -# Get the combination for this array task -combination="${combinations[$SLURM_ARRAY_TASK_ID]}" -bs=$(echo "$combination" | awk -F '[(), ]+' '{print $2}') -epochs=$(echo "$combination" | awk -F '[(), ]+' '{print $3}') -use_proj=$(echo "$combination" | awk -F '[(), ]+' '{print $4}') -proj_hidden_dim=$(echo "$combination" | awk -F '[(), ]+' '{print $5}') -proj_output_dim=$(echo "$combination" | awk -F '[(), ]+' '{print $6}') -var_loss_weight=$(echo "$combination" | awk -F '[(), ]+' '{print $7}') -cov_loss_weight=$(echo "$combination" | awk -F '[(), ]+' '{print $8}') - -echo "Running VICReg grid search experiment:" -echo " batch_size=$bs" -echo " epochs=$epochs" -echo " use_projector=$use_proj" -echo " proj_hidden_dim=$proj_hidden_dim" -echo " proj_output_dim=$proj_output_dim" - -/private/home/amirbar/projects/eb_jepa_release/.venv/bin/python examples/image_jepa/main.py \ - --batch_size=${bs} \ - --epochs=${epochs} \ - --use_projector=${use_proj} \ - --proj_hidden_dim=${proj_hidden_dim} \ - --proj_output_dim=${proj_output_dim} \ - --project_name="vicreg-gridsearch" \ - --loss_type="vicreg" \ - --model_type="resnet" \ - --var_loss_weight=${var_loss_weight} \ - --cov_loss_weight=${cov_loss_weight} diff --git a/examples/image_jepa/run_exp_lejepa.sh b/examples/image_jepa/run_exp_lejepa.sh deleted file mode 100644 index 5072d40..0000000 --- a/examples/image_jepa/run_exp_lejepa.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name=VICReg-GridSearch -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=12 -#SBATCH --time=12:00:00 -#SBATCH --partition=dev -#SBATCH --signal=B:CONT@60 -#SBATCH --requeue -#SBATCH --output=/checkpoint/amirbar/experiments/eb_jepa/logs/%A_%a.out -#SBATCH --error=/checkpoint/amirbar/experiments/eb_jepa/logs/%A_%a.err -#SBATCH --array=1-587 - -# Grid search parameters -LMBD=(10.0 1.0 100.0) # bcs loss weight -EPOCHS=(300) # number of epochs -BATCH_SIZES=(256 512) # batch size -USE_PROJECTOR=(0 1) # whether to use projector or not -proj_hidden_dim=(64 128 256 512 1024 2048 4096) -proj_output_dim=(64 128 256 512 1024 2048 4096) - -chmod a+x ~/.bashrc -PS1='$ ' -source ~/.bashrc -cd "/private/home/amirbar/projects/eb_jepa_release" - -# Generate all combinations -combinations=() -for bs in "${BATCH_SIZES[@]}"; do - for epochs in "${EPOCHS[@]}"; do - for lmbd in "${LMBD[@]}"; do - for use_proj in "${USE_PROJECTOR[@]}"; do - for proj_hidden_dim in "${proj_hidden_dim[@]}"; do - for proj_output_dim in "${proj_output_dim[@]}"; do - combinations+=("($bs, $epochs, $lmbd, $use_proj, $proj_hidden_dim, $proj_output_dim)") - done - done - done - done - done -done - -# Get the combination for this array task -combination="${combinations[$SLURM_ARRAY_TASK_ID]}" -bs=$(echo "$combination" | awk -F '[(), ]+' '{print $2}') -epochs=$(echo "$combination" | awk -F '[(), ]+' '{print $3}') -lmbd=$(echo "$combination" | awk -F '[(), ]+' '{print $4}') -use_proj=$(echo "$combination" | awk -F '[(), ]+' '{print $5}') -proj_hidden_dim=$(echo "$combination" | awk -F '[(), ]+' '{print $6}') -proj_output_dim=$(echo "$combination" | awk -F '[(), ]+' '{print $7}') - -echo "Running VICReg grid search experiment:" -echo " batch_size=$bs" -echo " epochs=$epochs" -echo " lmbd=$lmbd" -echo " use_projector=$use_proj" -echo " proj_hidden_dim=$proj_hidden_dim" -echo " proj_output_dim=$proj_output_dim" - -/private/home/amirbar/projects/eb_jepa_release/.venv/bin/python examples/image_jepa/main.py \ - --batch_size=${bs} \ - --epochs=${epochs} \ - --lmbd=${lmbd} \ - --use_projector=${use_proj} \ - --proj_hidden_dim=${proj_hidden_dim} \ - --proj_output_dim=${proj_output_dim} \ - --project_name="lejepa-gridsearch" \ - --loss_type="bcs" \ - --model_type="resnet" - \ No newline at end of file diff --git a/examples/image_jepa/run_exp_transformers.sh b/examples/image_jepa/run_exp_transformers.sh deleted file mode 100644 index c123c84..0000000 --- a/examples/image_jepa/run_exp_transformers.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name=VICReg-GridSearch -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=12 -#SBATCH --time=4:00:00 -#SBATCH --account fair_amaia_cw_video -#SBATCH --qos dev -#SBATCH --signal=B:CONT@60 -#SBATCH --requeue -#SBATCH --output=/checkpoint/amirbar/experiments/eb_jepa/logs/%A_%a.out -#SBATCH --error=/checkpoint/amirbar/experiments/eb_jepa/logs/%A_%a.err -#SBATCH --array=0-17 - -# Grid search parameters -BATCH_SIZES=(512) # batch size -EPOCHS=(50 100 1000) # number of epochs -USE_PROJECTOR=(1 0) # whether to use projector or not (1=true, 0=false) -MODEL_TYPE=(resnet vit_s vit_b) -PATCH_SIZE=(2) - -# Generate all combinations -combinations=() -for bs in "${BATCH_SIZES[@]}"; do - for epochs in "${EPOCHS[@]}"; do - for use_proj in "${USE_PROJECTOR[@]}"; do - combinations+=("($bs, $epochs, $use_proj)") - done - done -done - -# Get the combination for this array task -combination="${combinations[$SLURM_ARRAY_TASK_ID]}" -bs=$(echo "$combination" | awk -F '[(), ]+' '{print $2}') -epochs=$(echo "$combination" | awk -F '[(), ]+' '{print $3}') -use_proj=$(echo "$combination" | awk -F '[(), ]+' '{print $4}') - -echo "Running VICReg grid search experiment:" -echo " batch_size=$bs" -echo " epochs=$epochs" -echo " use_projector=$use_proj" - -# Create run name for this specific configuration -uv run python -m examples.image_jepa.main \ - --model_type=${MODEL_TYPE} \ - --patch_size=${PATCH_SIZE} \ - --batch_size=${bs} \ - --num_workers=8 \ - --epochs=${epochs} \ - --use_projector=${use_proj} \ - --project_name="vicreg-gridsearch" \ - --data_dir /checkpoint/amaia/video/davidfan/data/CIFAR10 \ - --use_amp \ No newline at end of file diff --git a/examples/launch_sbatch.py b/examples/launch_sbatch.py new file mode 100644 index 0000000..5f243ba --- /dev/null +++ b/examples/launch_sbatch.py @@ -0,0 +1,532 @@ +""" +Unified SLURM launcher for all EB-JEPA examples. + +Provides seed averaging, sweep name filtering, and wandb sweep features for all examples. + +USAGE: +------ +# Launch 3 seeds of a single configuration (default sweep name: sweep_YYYYMMDD_HHMM): +python -m examples.launch_sbatch --example ac_video_jepa + +# Launch 3 seeds with custom sweep name: +python -m examples.launch_sbatch --example ac_video_jepa --sweep my_experiment + +# Launch full hyperparameter sweep (ac_video_jepa only): +python -m examples.launch_sbatch --example ac_video_jepa --sweep my_experiment --full-sweep + +# With wandb sweep UI for hyperparameter analysis: +python -m examples.launch_sbatch --example ac_video_jepa --sweep my_experiment --use-wandb-sweep + +# Override config values: +python -m examples.launch_sbatch --example ac_video_jepa --optim.lr 0.0005 + +SEED AVERAGING IN WANDB UI: +--------------------------- +Runs with the same hyperparameters but different seeds share the same wandb run name. + +To view averaged metrics: +1. Go to wandb web UI -> Runs table +2. Click "Group by" -> select "Name" + -> This groups runs with identical hyperparameters (different seeds) together + +To filter runs from a specific sweep: +3. Click "Filter" -> "Group" -> select your sweep name (e.g., 'my_experiment') + -> This shows only runs from that sweep, grouped by name (see above) + +WANDB SWEEP ANALYSIS UI (requires --use-wandb-sweep): +----------------------------------------------------- +When using --use-wandb-sweep, wandb creates a sweep object that enables advanced +hyperparameter analysis. + +To access the sweep analysis: +1. Go to wandb web UI -> left pane -> click "Sweeps" +2. Click on your sweep name +3. Wandb automatically generates plots linking hyperparameters to the metric + (success_rate), including: + - Parallel coordinates plot + - Hyperparameter importance + - Parameter vs. metric scatter plots +""" + +import argparse +import importlib +import json +import os +import shutil +from itertools import product + +import submitit +import wandb + +from eb_jepa.training_utils import ( + get_checkpoints_dir, + get_default_dev_name, + get_default_sweep_name, + get_exp_name, + get_unified_experiment_dir, + load_config, +) + +# Default SLURM parameters +SLURM_DEFAULTS = { + "mem_per_gpu": "210G", + "cpus_per_task": 16, + "timeout_min": 24 * 60, + "partition": "learn", + "gpus_per_node": 1, + "qos": "lowest", + "account": "fair_amaia_cw_video", +} + + +# Example-specific configurations +EXAMPLE_CONFIGS = { + "image_jepa": { + "config": "examples/image_jepa/cfgs/default.yaml", + "module": "examples.image_jepa.main", + "metric": "val_acc", + }, + "video_jepa": { + "config": "examples/video_jepa/cfgs/default.yaml", + "module": "examples.video_jepa.main", + "metric": "AP_1", + }, + "ac_video_jepa": { + "config": "examples/ac_video_jepa/cfgs/train.yaml", + "module": "examples.ac_video_jepa.main", + "metric": "success_rate", + }, +} + +# ============================================================================= +# Utility functions +# ============================================================================= + + +def make_executor( + folder: str, + job_name: str, + array_parallelism: int | None = None, +) -> submitit.AutoExecutor: + """Create a submitit executor with standard SLURM parameters.""" + executor = submitit.AutoExecutor(folder=folder, slurm_max_num_timeout=20) + + params = { + "name": job_name, + "slurm_mem_per_gpu": SLURM_DEFAULTS["mem_per_gpu"], + "cpus_per_task": SLURM_DEFAULTS["cpus_per_task"], + "timeout_min": SLURM_DEFAULTS["timeout_min"], + "slurm_partition": SLURM_DEFAULTS["partition"], + "slurm_additional_parameters": { + "nodes": 1, + "ntasks-per-node": 1, + "gpus-per-node": SLURM_DEFAULTS["gpus_per_node"], + "qos": SLURM_DEFAULTS["qos"], + "account": SLURM_DEFAULTS["account"], + }, + } + + if array_parallelism is not None: + params["slurm_array_parallelism"] = array_parallelism + + executor.update_parameters(**params) + return executor + + +def normalize_sweep_name(name: str) -> str: + """Ensure sweep name has 'sweep_' prefix for consistency.""" + if name.startswith("sweep_"): + return name + return f"sweep_{name}" + + +def copy_code_folder(code_folder): + """Copy the code folder to the experiment directory, ignoring unnecessary files.""" + # Patterns to always ignore (matched by name only) + ignore_patterns = [ + "__pycache__", + ".vscode", + ".git", + "core", + "uv.lock", + "Makefile", + ] + # Paths to ignore (matched by name only, applies to any directory with this name) + ignore_paths = [ + "traces", + "docs", + ".pytest_cache", + "logs", + ".venv", + "eb_jepa.egg-info", + "wandb", + "assets", + ] + # Root-level directories to ignore (only ignored when at the source root) + # This allows us to skip ./datasets (storage-intensive data) while keeping + # ./eb_jepa/datasets (data code needed for experiments) + root_only_ignore = [ + "eb_jepa_ICLR", + "datasets", + "checkpoints", + ] + source_root = os.path.abspath(".") + + def ignore_func(path, names): + ignored = [] + for n in names: + if n in ignore_patterns or n in ignore_paths: + ignored.append(n) + # Only ignore root-level directories specified in root_only_ignore + elif n in root_only_ignore and os.path.abspath(path) == source_root: + ignored.append(n) + return ignored + + if not os.path.exists(code_folder): + shutil.copytree(".", code_folder, ignore=ignore_func) + + +def setup_launch_environment(base_dir, logs_subdir: str | None = "slurm_logs"): + """Setup directories and code folder for launching jobs.""" + base_dir = base_dir.absolute() if hasattr(base_dir, "absolute") else base_dir + logs_dir = base_dir / logs_subdir if logs_subdir else base_dir + code_folder = base_dir / "code" + + copy_code_folder(str(code_folder)) + logs_dir.mkdir(parents=True, exist_ok=True) + + print(f"Code folder: {code_folder}") + os.chdir(code_folder) + + return logs_dir, code_folder + + +def generate_param_combinations(param_grid: dict): + """Generate all parameter combinations from a grid.""" + param_names = list(param_grid.keys()) + param_values_list = list(param_grid.values()) + all_combinations = list(product(*param_values_list)) + return param_names, all_combinations + + +def print_submission_summary(jobs: list, logs_dir, extra_info: dict | None = None): + """Print a compact summary of batch job submission.""" + job_ids = [job.job_id for job in jobs] + batch_id = job_ids[0].split("_")[0] if "_" in job_ids[0] else job_ids[0] + print(f"\nβœ“ Submitted {len(jobs)} jobs (batch {batch_id}_[0-{len(jobs)-1}])") + print(f" Logs: {logs_dir}") + if extra_info: + for key, value in extra_info.items(): + print(f" {key}: {value}") + + +# ============================================================================= +# Launch functions +# ============================================================================= + + +def run_experiment(example_name: str, cfg, folder=None): + """Run the appropriate example based on example_name.""" + print(f"Current working directory: {os.getcwd()}") + print(f"EBJEPA_DSETS: {os.environ.get('EBJEPA_DSETS', 'not set')}") + module = importlib.import_module(EXAMPLE_CONFIGS[example_name]["module"]) + return module.run(cfg=cfg, folder=folder) + + +def launch_job(example_name: str, fname: str, **kwargs): + """Launch a single training job with the given config and overrides.""" + cfg = load_config(fname, kwargs) + sweep_name = kwargs.get("sweep_name", get_default_sweep_name()) + exp_name = get_exp_name(example_name, cfg) + + folder = get_unified_experiment_dir( + example_name=example_name, + sweep_name=sweep_name, + exp_name=exp_name, + seed=cfg.meta.seed, + ) + + logs_dir, _ = setup_launch_environment(folder, logs_subdir=None) + + executor = make_executor( + folder=str(logs_dir), + job_name=f"{example_name.upper()}", + ) + job = executor.submit(run_experiment, example_name, cfg, folder) + + print(f"\nβœ“ Submitted job {job.job_id}") + print(f" Experiment folder: {folder}") + + return job + + +def create_wandb_sweep_config(param_grid: dict, metric: str, method: str = "grid"): + """Create a wandb sweep configuration from a parameter grid.""" + sweep_config = { + "method": method, + "metric": {"goal": "maximize", "name": metric}, + "parameters": {}, + } + + for param_name, param_values in param_grid.items(): + if isinstance(param_values, list): + sweep_config["parameters"][param_name] = {"values": param_values} + elif isinstance(param_values, dict): + sweep_config["parameters"][param_name] = param_values + + return sweep_config + + +def launch_sweep( + example_name: str, + fname: str, + param_grid: dict, + array_parallelism: int = 256, + use_wandb: bool = False, + wandb_method: str = "grid", + **base_overrides, +): + """Launch a parameter sweep using submitit. Returns (sweep_id, jobs) if use_wandb else jobs.""" + param_names, all_combinations = generate_param_combinations(param_grid) + + if not all_combinations: + print("No parameter combinations to sweep") + return (None, []) if use_wandb else [] + + sweep_name = base_overrides.get("sweep_name", get_default_sweep_name()) + + # Create wandb sweep if requested + sweep_id = None + if use_wandb: + project_name = "eb_jepa" + metric = EXAMPLE_CONFIGS[example_name]["metric"] + sweep_config = create_wandb_sweep_config(param_grid, metric, wandb_method) + sweep_id = wandb.sweep(sweep_config, project=project_name) + print(f"Created wandb sweep with ID: {sweep_id}") + print( + f"View sweep at: https://wandb.ai/{wandb.api.default_entity}/{project_name}/sweeps/{sweep_id}" + ) + + # Setup environment (must happen before chdir) + common_dir = get_checkpoints_dir() / example_name / sweep_name + logs_subdir = "wandb_sweep_slurm_logs" if use_wandb else "sweep_slurm_logs" + logs_dir, _ = setup_launch_environment(common_dir, logs_subdir=logs_subdir) + + # Store checkpoints dir before chdir (for absolute paths in job configs) + original_checkpoints_dir = common_dir.parent.parent.absolute() + + executor = make_executor( + folder=str(logs_dir), + job_name=f"{example_name.upper()}_{'wandb_' if use_wandb else ''}sweep", + array_parallelism=array_parallelism, + ) + + print(f"\nPreparing {len(all_combinations)} tasks...") + jobs = [] + with executor.batch(): + for values in all_combinations: + param_overrides = dict(zip(param_names, values)) + final_overrides = {**base_overrides, **param_overrides} + + # Add wandb-specific overrides + if use_wandb: + final_overrides.update( + { + "logging.wandb_sweep": True, + "logging.wandb_sweep_id": sweep_id, + "logging.wandb_group": sweep_name, + } + ) + + cfg = load_config(fname, final_overrides, quiet=True) + exp_name = get_exp_name(example_name, cfg) + folder = get_unified_experiment_dir( + example_name=example_name, + sweep_name=sweep_name, + exp_name=exp_name, + seed=cfg.meta.seed, + base_dir=original_checkpoints_dir, + ) + + job = executor.submit(run_experiment, example_name, cfg, folder) + jobs.append(job) + + extra_info = {"Sweep ID": sweep_id} if use_wandb else None + print_submission_summary(jobs, logs_dir, extra_info) + + return (sweep_id, jobs) if use_wandb else jobs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Unified SLURM launcher for EB-JEPA examples" + ) + parser.add_argument( + "--example", + type=str, + required=True, + choices=["image_jepa", "video_jepa", "ac_video_jepa"], + help="Which example to run", + ) + parser.add_argument( + "--fname", + type=str, + default=None, + help="Path to config file (defaults to example's default config)", + ) + parser.add_argument( + "--sweep", + type=str, + default=None, + help="Name for the sweep (default: sweep_YYYYMMDD_HHMM)", + ) + parser.add_argument( + "--array-parallelism", + type=int, + default=256, + help="Number of jobs to run in parallel for the sweep", + ) + parser.add_argument( + "--use-wandb-sweep", + action="store_true", + help="Use wandb sweep for hyperparameter tracking", + ) + parser.add_argument( + "--sweep-method", + type=str, + default="grid", + choices=["grid", "random", "bayes"], + help="Wandb sweep method to use if use_wandb_sweep is true", + ) + parser.add_argument( + "--full-sweep", + action="store_true", + help="Enable full hyperparameter sweep (default: only sweep over 3 seeds)", + ) + parser.add_argument( + "--single", + action="store_true", + help="Launch a single job (uses dev_YYYYMMDD_HHMM folder)", + ) + + # Common overrides + parser.add_argument("--optim.lr", type=float) + parser.add_argument("--meta.seed", type=int) + + # ac_video_jepa specific + parser.add_argument("--model.regularizer.cov_coeff", type=float) + parser.add_argument("--model.regularizer.std_coeff", type=float) + parser.add_argument("--model.regularizer.sim_coeff_t", type=float) + parser.add_argument("--model.regularizer.idm_coeff", type=float) + + # Use parse_known_args to allow dynamic overrides for any config key + args, unknown = parser.parse_known_args() + + example_name = args.example + example_config = EXAMPLE_CONFIGS[example_name] + fname = args.fname or example_config["config"] + + # Load config to read sweep params from YAML (quiet mode to avoid duplicate logs) + base_cfg = load_config(fname, {}, quiet=True) + + # Read sweep param_grid from config file + # Fall back to default 3-seed sweep if not specified in config + config_param_grid = base_cfg.get("sweep", {}).get("param_grid", {}) + if hasattr(config_param_grid, "to_dict"): + config_param_grid = config_param_grid.to_dict() + elif hasattr(config_param_grid, "__dict__"): + # OmegaConf DictConfig - convert to plain dict + config_param_grid = dict(config_param_grid) + + default_seed_sweep = {"meta.seed": [1, 1000, 10000]} + + # Build overrides dict from known args + excluded_keys = { + "example", + "fname", + "sweep", + "array_parallelism", + "use_wandb_sweep", + "sweep_method", + "full_sweep", + "single", + } + overrides = { + k: v for k, v in vars(args).items() if v is not None and k not in excluded_keys + } + + # Parse unknown args as additional config overrides (e.g., --data.batch_size 64) + i = 0 + while i < len(unknown): + if unknown[i].startswith("--"): + key = unknown[i][2:] + if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"): + value = unknown[i + 1] + # Try to parse as JSON (handles numbers, bools, lists) + try: + value = json.loads(value) + except json.JSONDecodeError: + pass # Keep as string + overrides[key] = value + i += 2 + else: + # Flag without value (e.g., --some_flag) + overrides[key] = True + i += 1 + else: + i += 1 + + # Determine folder name based on mode + if args.single: + # Single job: use dev_ prefix + sweep_name = get_default_dev_name() + param_grid = None # No sweep, single job + elif args.sweep: + # Custom sweep name: normalize to have sweep_ prefix + sweep_name = normalize_sweep_name(args.sweep) + if args.full_sweep: + param_grid = config_param_grid if config_param_grid else default_seed_sweep + else: + param_grid = default_seed_sweep + else: + # Default: 3-seed sweep with auto-generated name + sweep_name = get_default_sweep_name() + if args.full_sweep: + param_grid = config_param_grid if config_param_grid else default_seed_sweep + else: + param_grid = default_seed_sweep + + overrides["sweep_name"] = sweep_name + overrides["logging.wandb_group"] = sweep_name + + print(f"Example: {example_name}") + print(f"Config: {fname}") + print(f"Sweep name: {sweep_name}") + if param_grid: + print(f"Param grid: {param_grid}") + else: + print("Mode: single job") + if overrides: + print(f"Overrides: {overrides}") + + if args.single: + # Launch single job + job = launch_job(example_name, fname, **overrides) + elif args.use_wandb_sweep: + sweep_id, jobs = launch_sweep( + example_name, + fname, + param_grid, + array_parallelism=args.array_parallelism, + use_wandb=True, + wandb_method=args.sweep_method, + **overrides, + ) + else: + jobs = launch_sweep( + example_name, + fname, + param_grid, + array_parallelism=args.array_parallelism, + **overrides, + ) diff --git a/examples/video_jepa/README.md b/examples/video_jepa/README.md index f2f3786..5f19026 100644 --- a/examples/video_jepa/README.md +++ b/examples/video_jepa/README.md @@ -1,143 +1,89 @@ # Self-Supervised Representation Learning from Video Sequences -This example demonstrates Joint Embedding Predictive Architecture (JEPA) for self-supervised representation learning on video sequences. The model learns to predict future video representations from past observations without requiring labeled data. - -## Overview - -JEPA learns representations by training an encoder to map observations to a latent space and a predictor to predict future representations. The key insight is that good representations should enable better prediction of future states. - +This example demonstrates Joint Embedding Predictive Architecture (JEPA) for self-supervised representation learning on moving MNIST video sequences. The videos show two hand-written digits (0-9) moving around the video frame and bouncing off the edges using standard collision physics. The JEPA is trained to predict future video representations from past observations without requiring labeled data, effectively predicting the future trajectory of the digits, in representation space. ![Video JEPA Architecture](assets/arch_figure.png) +### Overview +The video JEPA architecture consists of: +- **Encoder (`ResNet5`)**: A lightweight ResNet that maps video frames to latent representations. Uses residual blocks for stable gradient flow. +- **Predictor (`ResUNet`)**: A UNet-based architecture that predicts future representations given concatenated current and previous representations. Skip connections preserve spatial information. +- **Projector (`MLP`)**: A simple network that projects representations before applying a cost function. -## Architecture - -### Components - - -1. **Encoder (`ResNet5`)**: A lightweight ResNet that maps video frames to latent representations - - Input: Video frames `(B, C, T, H, W)` - - Output: Latent representations `(B, dstc, T, H, W)` - - Uses residual blocks for stable gradient flow - -2. **Predictor (`ResUNet`)**: A UNet-based architecture that predicts future representations - - Input: Concatenated current and previous representations - - Output: Predicted future representations - - Skip connections preserve spatial information - -3. **Projector**: MLP that projects representations before applying a cost function - - Enables comparing projecctions instead of encoder representations - -### Training Objectives - -- **Prediction Loss**: - Minimizes prediction error between predicted and actual future representations -- **VC Loss (Variance-Covariance)**: Regularizes representations to prevent collapse - -## Self-Supervised Learning Approach - -The model learns representations through: - -1. **Temporal Prediction**: Given current and previous representations, predict the next representation -2. **Multi-step Rollout**: Extend predictions to multiple future time steps -3. **Collapse Prevention**: Ensure representations do not collapse using Variance-Covariance Loss. - -This approach forces the encoder to learn features that capture: -- Object motion and dynamics -- Spatial relationships -- Temporal consistency +### Training +Training works as follows. First, video frames are encoded into representation space. Then the past `K` frame encodings are taken as context and provided to the predictor. The predictor predicts the `K+1` th frame representation. Both the encoder and predictor are trained together with the following objectives: -### How the Predictor Achieves Coherent Temporal Dynamics +- **Prediction Loss**: Minimizes prediction error between predicted and actual future representations. The prediction loss supports multi-step rollout prediction to predict `T` steps into the future autoregressively, ensuring temporal consistency of predictions. +- **VC Loss (Variance-Covariance)**: Regularizes representations to prevent collapse. The VC loss is parametrized by two weight factors `std_coeff` and `cov_coeff` for the variance and covariance terms respectively. -The predictor maintains coherent temporal dynamics through several key mechanisms: +### Evaluation +We evaluate the trained JEPA models using a simple digit detection task. Specifically, we train a frame-level decoder to predict the discretized location of a digit, given its encoder representations (BCE loss). We then use the JEPA to generate the future video representations autoregressively, and predict and evaluate the digit locations using the decoder. We also train a pixel-level decoder for the JEPA representations for visualization purposes (MSE loss). Note that both decoders are trained independent to the JEPA training (i.e., gradients are detached). -1. **Context-Aware Input**: The predictor receives concatenated representations from the current and previous time steps `[z_t, z_{t-1}]`, providing temporal context about recent motion and state changes. +We report average precision for digit locations at `T` timesteps into the future (`AP_T`) as well as the average over all prediction horizons (`mAP`). -2. **ResUNet Architecture**: The UNet structure with skip connections preserves spatial information while the residual blocks maintain gradient flow, allowing the model to learn both local and global temporal patterns. - -3. **Autoregressive Rollout**: During inference, the predictor operates in an autoregressive manner: - - Predicts the next representation: `z_{t+1} = f(z_t, z_{t-1})` - - Uses this prediction as input for the next step: `z_{t+2} = f(z_{t+1}, z_t)` - - This creates a chain of predictions that maintain temporal consistency - -4. **Representation Regularization**: The VC loss ensures representations don't collapse to trivial values, maintaining the rich temporal information needed for coherent predictions. - -5. **Motion Understanding**: By training on sequences with consistent physics (bouncing digits), the model learns to encode motion vectors and trajectory information that naturally extend into the future. - -This combination allows the predictor to generate temporally coherent sequences that respect the underlying dynamics of the moving objects, even when extending predictions multiple steps into the future. - -## Usage +## Setup +The example uses Moving MNIST, a synthetic dataset with multiple digits moving across the screen, motion follows simple physics (bouncing off boundaries). **Note**: The Moving MNIST dataset (~800MB) is automatically downloaded on first run from the University of Toronto servers. This requires internet access. If you're running on a cluster without internet access, you can manually download the dataset: ```bash -python main.py \ - --batch_size 64 \ - --dobs 1 \ - --henc 32 \ - --hpre 32 \ - --dstc 16 \ - --steps 1 \ - --cov_coeff 10.0 \ - --std_coeff 200.0 \ - --epochs 100 \ - --lr 5e-4 +wget https://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy -P datasets/ ``` -### Key Parameters - -- `dobs`: Input observation dimensions -- `henc`: Hidden dimensions in encoder -- `hpre`: Hidden dimensions in predictor -- `dstc`: Output representation dimensions -- `steps`: Number of prediction steps during training -- `cov_coeff`: Variance-covariance loss coefficient -- `std_coeff`: Standard deviation loss coefficient -## Dataset - -The example uses Moving MNIST, a synthetic dataset where: -- Multiple digits move across the screen -- Motion follows simple physics (bouncing off boundaries) -- Each video has a fixed number of frames -- Ground truth digit locations are available for evaluation - -**Note**: The Moving MNIST dataset (~800MB) is automatically downloaded on first run from the University of Toronto servers. This requires internet access. If you're running on a cluster without internet access, you can manually download the dataset: +## Training +Train a model with the default configuration (`cov_coeff=100, std_coeff=10, nsteps=4`): ```bash -# Download from a machine with internet access -wget https://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy -P datasets/ +python -m examples.video_jepa.main \ + --fname examples/video_jepa/cfgs/default.yaml +``` -# The file should be placed at datasets/mnist_test_seq.npy before running the example +You can override any config parameter using dot notation: +```bash +python -m examples.video_jepa.main \ + --fname examples/video_jepa/cfgs/default.yaml \ + model.steps=2 \ + data.batch_size=128 ``` -## Evaluation +This will train all components and decoders and visualize results in wandb. -The model is evaluated on: -- **Reconstruction Loss**: How well the decoder can reconstruct input from representations -- **Detection Loss**: Performance on digit location prediction task -- **Average Precision (AP)**: Quality of multi-step predictions +**Key parameters** +- `model.dobs`: Input observation dimensions +- `model.henc`: Hidden dimensions in encoder +- `model.hpre`: Hidden dimensions in predictor +- `model.dstc`: Output representation dimensions +- `model.steps`: Number of prediction steps during training +- `loss.cov_coeff`: Variance-covariance loss coefficient +- `loss.std_coeff`: Standard deviation loss coefficient ## Results -### Visualization +![Video JEPA Visualization](assets/viz.gif) + +Visualization from wandb showing input frames, full rollout obtained via auto-regressive prediction and predicted digit detections (blue heatmap overlays) on Moving MNIST data. The JEPA model learns representations that can predict dynamics well. -*Visualization showing input frames, 1-step predictions visaulization, and full rollout, obtained via auto-regressive prediction, on Moving MNIST data.* +### Detection accuracy +We plot the JEPA training losses and detection performance (`mAP`) to show how they evolve to learn stronger representations over time. +![Video JEPA Losses](assets/losses.png) -![Video JEPA Visualization](assets/viz.png) +### Multi-step Prediction -### Multi-step Prediction Analysis +The `model.steps` parameter controls how many future prediction steps are used to calculate the loss, which has a direct impact on how far into the future the JEPA model can reliably predict. -#### K-Step Predictions +![Video JEPA AP vs. t](assets/AP_vs_t.png) -*Analysis showing that recursively predicting more steps achieve significantly better Average Precision (AP) compared to 1-step predictions. This decreases exposure bias, the discrapancy between train and test, and demonstrates the model's improved temporal understanding with longer prediction horizons.* +As expected, predicting the very next timestep is easy (high mAP) while performance further into the future is naturally lower. Recursively predicting more steps achieve significantly better mAP compared to 1-step or 2-step predictions, however this saturates by 8 steps, which approaches the duration of the video (training is done with 10 frame horizon). -![Multi-step Prediction Analysis](assets/multistep_pred.png) +### JEPA loss ablation +We vary the `cov_coeff` and `std_coeff` to see the effect on downstream performance (`mAP`). We find the best configuration to be `cov_coeff=100`, `std_coeff=10`. -## Key Insights +| cov ↓ std β†’ | 1 | 10 | 100 | +|------------|---------:|---------:|---------:| +| 1 | 0.259 | 0.450 | 0.492 | +| 10 | 0.448 | 0.424 | 0.487 | +| 100 | 0.516 | **0.607** | 0.525 | -1. **Representation Quality**: Good representations enable accurate multi-step prediction -2. **Temporal Consistency**: The model learns to maintain consistency across time steps -3. **Motion Understanding**: Representations capture object dynamics and trajectories -4. **Generalization**: Learned features transfer to downstream tasks like detection ## References diff --git a/examples/video_jepa/assets/AP_vs_t.png b/examples/video_jepa/assets/AP_vs_t.png new file mode 100644 index 0000000..5f71ba7 Binary files /dev/null and b/examples/video_jepa/assets/AP_vs_t.png differ diff --git a/examples/video_jepa/assets/arch_figure.png b/examples/video_jepa/assets/arch_figure.png index 5158540..d6f4ce7 100644 Binary files a/examples/video_jepa/assets/arch_figure.png and b/examples/video_jepa/assets/arch_figure.png differ diff --git a/examples/video_jepa/assets/losses.png b/examples/video_jepa/assets/losses.png new file mode 100644 index 0000000..0ef7c85 Binary files /dev/null and b/examples/video_jepa/assets/losses.png differ diff --git a/examples/video_jepa/assets/multistep_pred.png b/examples/video_jepa/assets/multistep_pred.png deleted file mode 100644 index 66109c6..0000000 Binary files a/examples/video_jepa/assets/multistep_pred.png and /dev/null differ diff --git a/examples/video_jepa/assets/viz.gif b/examples/video_jepa/assets/viz.gif new file mode 100644 index 0000000..6eab550 Binary files /dev/null and b/examples/video_jepa/assets/viz.gif differ diff --git a/examples/video_jepa/cfgs/default.yaml b/examples/video_jepa/cfgs/default.yaml new file mode 100644 index 0000000..03437e2 --- /dev/null +++ b/examples/video_jepa/cfgs/default.yaml @@ -0,0 +1,54 @@ +# Video JEPA Training Configuration +# Train a self-supervised video prediction model on Moving MNIST + +meta: + seed: 2025 + device: auto # auto, cuda, or cpu + +data: + dataset: moving_mnist + batch_size: 64 + num_workers: 4 + +model: + # Encoder (ResNet5) + dobs: 1 # Input channels (grayscale) + henc: 32 # Hidden dimension in encoder + dstc: 16 # Output representation dimension + + # Predictor (ResUNet) + hpre: 32 # Hidden dimension in predictor + + # Training + steps: 4 # Number of prediction steps during training + +loss: + # Variance-Covariance regularization + cov_coeff: 100.0 # Covariance loss weight + std_coeff: 10.0 # Standard deviation loss weight + +optim: + epochs: 50 + lr: 1.0e-3 + +logging: + log_wandb: true + wandb_group: # Set to group runs (e.g., for seed averaging) + log_every: 1 # Log every N epochs + save_every: 10 # Save checkpoint every N epochs + tqdm_silent: false # Disable tqdm progress bars + +training: + use_amp: true # Use automatic mixed precision + dtype: float16 # float16 or bfloat16 + +# --- Parameters used when running with --full-sweep +# By default, only the parameters specified above are used +sweep: + # Sweep grid for video JEPA hyperparameter search + param_grid: + data.batch_size: [32, 64] + optim.lr: [0.001, 0.0001, 0.0005] + loss.std_coeff: [1, 10, 100, 200] + loss.cov_coeff: [1, 10, 100, 200] + meta.seed: [1, 1000, 10000] diff --git a/examples/video_jepa/eval.py b/examples/video_jepa/eval.py index 83820c2..ef2b130 100644 --- a/examples/video_jepa/eval.py +++ b/examples/video_jepa/eval.py @@ -3,79 +3,112 @@ import numpy as np import torch import torch.nn.functional as F -from einops import rearrange -from torchvision.utils import make_grid +import wandb +from einops import rearrange, repeat +from PIL import Image, ImageDraw, ImageFont from tqdm import tqdm -import wandb +def add_label_to_video(video, label): + """Add a text label overlay on each frame of a video. + + Args: + video: numpy array of shape (T, H, W, C) in uint8 + label: text string to add + + Returns: + numpy array of shape (T, H, W, C) + """ + font = ImageFont.load_default() + T, H, W, C = video.shape -def visualize( + labeled_frames = [] + for t in range(T): + frame = Image.fromarray(video[t]) + draw = ImageDraw.Draw(frame, "RGBA") + draw.rectangle([0, 0, W, 20], fill=(40, 40, 40, 200)) + draw.text((4, 4), label, fill=(255, 255, 255), font=font) + labeled_frames.append(np.array(frame)) + return np.stack(labeled_frames, axis=0) + + +def visualize_videos( batch, jepa, pixel_decoder, detection_head, + num_samples, ): + """Create visualization videos for wandb logging. + + Returns a list of videos, each with 3 vertically stacked rows: + 1. Ground truth video + 2. Predicted rollout reconstruction + 3. Digit detection overlay + """ x = batch["video"] x_jepa = jepa.encoder(x) T = x.shape[2] - preds = jepa.infern(x, actions=None, nsteps=T - 2) - - # Helper function to scale and convert pixel decoder outputs - def scale_and_convert_to_uint8(tensor): - # Scale from [0,1] to [0,255] and clamp values - scaled = torch.clamp(tensor * 255, 0, 255) - # Convert to uint8 - return scaled.to(torch.uint8) + preds, _ = jepa.unroll( + x, + actions=None, + nsteps=T - 2, + unroll_mode="parallel", + compute_loss=False, + return_all_steps=True, + ) # One step predictions one_step_pred = x_jepa[:, :, 1:].clone() one_step_pred[:, :, 1:] = preds[0] one_step_reconstruction = pixel_decoder.head(one_step_pred) - one_step_reconstruction = scale_and_convert_to_uint8(one_step_reconstruction) # Multi-step rollouts rollout = x_jepa[:, :, 1:].clone() for t in range(1, T - 1): rollout[:, :, t:] = preds[t - 1][:, :, t - 1 :] rollout_reconstruction = pixel_decoder.head(rollout) - rollout_reconstruction = scale_and_convert_to_uint8(rollout_reconstruction) # Location predictions overlaid over rollout as blue heatmap loc_prediction = detection_head.head(rollout) loc_prediction = F.interpolate( loc_prediction, (x.shape[-2], x.shape[-1]), mode="nearest" ) - loc_prediction = rearrange(loc_prediction, "b t h w -> b 1 t h w") - loc_prediction = loc_prediction.repeat(1, 3, 1, 1, 1) + loc_prediction = repeat(loc_prediction, "b t h w -> b c t h w", c=3).clone() loc_prediction[:, :2].fill_(0) - # Convert rollout_reconstruction back to float for overlay calculation - rollout_reconstruction_float = rollout_reconstruction.float() / 255.0 - overlay = 0.2 * rollout_reconstruction_float + 0.8 * loc_prediction - overlay = scale_and_convert_to_uint8(overlay) + # Overlay rollout reconstruction and location predictions + detection_overlay = 0.2 * rollout_reconstruction + 0.8 * loc_prediction + + # Ground truth (skip first frame to align with predictions) + gt = x[:, :, 1:] - # Stack panels horizontally - # Convert x to uint8 for consistency - x_uint8 = scale_and_convert_to_uint8(x[:, :, 1:]) + # Helper function to scale and convert pixel decoder outputs + # to uint8 RGB and return as numpy array for video logging + def scale_and_convert_to_uint8(tensor): + tensor = F.interpolate(tensor, (100, 100), mode="bilinear") + if tensor.shape[0] == 1: + tensor = tensor.repeat(3, 1, 1, 1) + tensor = torch.clamp(tensor * 255, 0, 255).to(torch.uint8) + tensor = rearrange(tensor, "c t h w -> t h w c").cpu().numpy() + return tensor - merged = torch.cat( - [ - x_uint8.repeat(1, 3, 1, 1, 1), - one_step_reconstruction.repeat(1, 3, 1, 1, 1), - rollout_reconstruction.repeat(1, 3, 1, 1, 1), - overlay, - torch.zeros_like(overlay), - ], - dim=3, - ) # (B, C, T, 3*H, W) + rows = [gt, rollout_reconstruction, detection_overlay] + labels = ["Ground truth", "Predicted rollout", "Digit detections"] - # Stack frames vertically - merged = rearrange(merged, "b c t h w -> b c h (t w)") - grid = make_grid([img for img in merged], nrow=1) - return grid + viz_videos = [] + for b in range(num_samples): + videos = [row[b] for row in rows] + videos = [scale_and_convert_to_uint8(video) for video in videos] + videos = [ + add_label_to_video(video, label) for video, label in zip(videos, labels) + ] + videos = [video.transpose(0, 3, 1, 2) for video in videos] + viz_videos.append(np.concatenate(videos, axis=2)) # (T, C, 3*H, W) + + return viz_videos # Run full loop over validation set and compute metrics @@ -104,17 +137,26 @@ def validation_loop(val_loader, jepa, detection_head, pixel_decoder, steps, devi metrics[k].append(v) T = x.shape[2] - preds = jepa.infern(x, actions=None, nsteps=T - 2) + preds, _ = jepa.unroll( + x, + actions=None, + nsteps=T - 2, + unroll_mode="parallel", + compute_loss=False, + return_all_steps=True, + ) scores = detection_head.head.score(preds, loc_map[:, 2:]) for s, score in enumerate(scores): metrics[f"AP_{s}"].append(float(score)) # Aggregate val results and visualize last batch metrics = {k: float(np.mean(v)) for k, v in metrics.items()} - viz = visualize(batch, jepa, pixel_decoder, detection_head) + videos = visualize_videos( + batch, jepa, pixel_decoder, detection_head, num_samples=16 + ) logs = { - "viz": wandb.Image(viz, caption="decoder_viz"), **metrics, + "viz": [wandb.Video(video, fps=4, format="mp4") for video in videos], } print(metrics) diff --git a/examples/video_jepa/main.py b/examples/video_jepa/main.py index a2c4fbf..a0526dc 100644 --- a/examples/video_jepa/main.py +++ b/examples/video_jepa/main.py @@ -1,11 +1,19 @@ +""" +Video JEPA Training Script + +Train a self-supervised video prediction model on Moving MNIST using +Joint Embedding Predictive Architecture (JEPA) with VC regularization. +""" + +from pathlib import Path + import fire -import torch import torch.nn as nn +from omegaconf import OmegaConf from torch.optim import Adam from torch.utils.data import DataLoader from tqdm import tqdm -import wandb from eb_jepa.architectures import ( DetHead, Projector, @@ -16,76 +24,178 @@ from eb_jepa.datasets.moving_mnist import MovingMNISTDet from eb_jepa.image_decoder import ImageDecoder from eb_jepa.jepa import JEPA, JEPAProbe +from eb_jepa.logging import get_logger from eb_jepa.losses import SquareLossSeq, VCLoss +from eb_jepa.training_utils import ( + get_default_dev_name, + get_exp_name, + get_unified_experiment_dir, + load_checkpoint, + load_config, + log_config, + log_data_info, + log_epoch, + log_model_info, + save_checkpoint, + setup_device, + setup_seed, + setup_wandb, +) from examples.video_jepa.eval import validation_loop +logger = get_logger(__name__) + def run( - batch_size: int = 64, - dobs: int = 1, - henc: int = 32, - hpre: int = 32, - dstc: int = 16, - steps: int = 4, - cov_coeff: float = 100.0, - std_coeff: float = 10.0, - epochs: int = 100, - lr: float = 1e-3, + fname: str = "examples/video_jepa/cfgs/default.yaml", + cfg=None, + folder=None, + **overrides, ): - """Train a Video JEPA model with VC loss. Evaluate the encoder and predictor on Moving MNIST detection, and visualize the representations.""" - device = "cuda" - torch.manual_seed(2025) + """ + Train a Video JEPA model on Moving MNIST. + + Args: + fname: Path to YAML config file + cfg: Pre-loaded config object (optional, overrides config file) + folder: Experiment folder path (optional, auto-generated if not provided) + **overrides: Config overrides in dot notation (e.g., model.lr=0.001) + """ + # Load config + if cfg is None: + cfg = load_config(fname, overrides if overrides else None) + + # Setup + device = setup_device(cfg.meta.device) + setup_seed(cfg.meta.seed) + + # Create experiment directory using unified structure (if not provided) + if folder is None: + if cfg.meta.get("model_folder"): + exp_dir = Path(cfg.meta.model_folder) + folder_name = exp_dir.name + exp_name = folder_name.rsplit("_seed", 1)[0] + else: + sweep_name = get_default_dev_name() + exp_name = get_exp_name("video_jepa", cfg) + exp_dir = get_unified_experiment_dir( + example_name="video_jepa", + sweep_name=sweep_name, + exp_name=exp_name, + seed=cfg.meta.seed, + ) + else: + exp_dir = Path(folder) + exp_dir.mkdir(parents=True, exist_ok=True) + # Extract exp_name from folder name by removing _seed{seed} suffix + folder_name = exp_dir.name # e.g., "resnet_std10.0_cov100.0_seed1" + exp_name = folder_name.rsplit("_seed", 1)[0] # e.g., "resnet_std10.0_cov100.0" + + wandb_run = setup_wandb( + project="eb_jepa", + config={"example": "video_jepa", **OmegaConf.to_container(cfg, resolve=True)}, + run_dir=exp_dir, + run_name=exp_name, + tags=["video_jepa", f"seed_{cfg.meta.seed}"], + group=cfg.logging.get("wandb_group"), + enabled=cfg.logging.log_wandb, + sweep_id=cfg.logging.get("wandb_sweep_id"), + ) # Load datasets train_set = MovingMNISTDet(split="train") val_set = MovingMNISTDet(split="val") - train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) - val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False) + train_loader = DataLoader( + train_set, + batch_size=cfg.data.batch_size, + shuffle=True, + num_workers=cfg.data.num_workers, + ) + val_loader = DataLoader( + val_set, + batch_size=cfg.data.batch_size, + shuffle=False, + num_workers=cfg.data.num_workers, + ) + log_data_info( + "MovingMNIST", + len(train_loader), + cfg.data.batch_size, + train_samples=len(train_set), + val_samples=len(val_set), + ) - # Initialize Video JEPA model, conditioned on past observations (not actions) - encoder = ResNet5(dobs, henc, dstc) - predictor_model = ResUNet(2 * dstc, hpre, dstc) + # Initialize Video JEPA model + logger.info("Initializing model...") + encoder = ResNet5(cfg.model.dobs, cfg.model.henc, cfg.model.dstc) + predictor_model = ResUNet(2 * cfg.model.dstc, cfg.model.hpre, cfg.model.dstc) predictor = StateOnlyPredictor(predictor_model, context_length=2) - projector = Projector(f"{dstc}-{dstc*4}-{dstc*4}") - regularizer = VCLoss(std_coeff, cov_coeff, proj=projector) + projector = Projector(f"{cfg.model.dstc}-{cfg.model.dstc*4}-{cfg.model.dstc*4}") + regularizer = VCLoss(cfg.loss.std_coeff, cfg.loss.cov_coeff, proj=projector) ploss = SquareLossSeq(projector) jepa = JEPA(encoder, encoder, predictor, regularizer, ploss).to(device) - # Initialize decoder and detection head, only used for evaluation - decoder = ImageDecoder(dstc, dobs) - dethead = DetHead(dstc, hpre, dobs) + # Initialize decoder and detection head (for evaluation only) + decoder = ImageDecoder(cfg.model.dstc, cfg.model.dobs) + dethead = DetHead(cfg.model.dstc, cfg.model.hpre, cfg.model.dobs) pixel_decoder = JEPAProbe(jepa, decoder, nn.MSELoss()).to(device) detection_head = JEPAProbe(jepa, dethead, nn.BCELoss()).to(device) + # Log model structure and parameters + encoder_params = sum(p.numel() for p in encoder.parameters()) + predictor_params = sum(p.numel() for p in predictor.parameters()) + log_model_info(jepa, {"encoder": encoder_params, "predictor": predictor_params}) + jepa.train() detection_head.train() pixel_decoder.train() + # Set learning rates for different components + # Lower learning rate for pixel decoder to prevent overfitting optimizer = Adam( [ - {"params": jepa.parameters(), "lr": lr}, - # only for visualization purposes, gradients are not propogated back to the encoder - {"params": pixel_decoder.head.parameters(), "lr": lr}, - {"params": detection_head.head.parameters(), "lr": lr}, + {"params": jepa.parameters(), "lr": cfg.optim.lr}, + {"params": pixel_decoder.head.parameters(), "lr": cfg.optim.lr / 10}, + {"params": detection_head.head.parameters(), "lr": cfg.optim.lr}, ] ) - wandb.init( - project="jepa-unlabeled-video", - config={"batch_size": batch_size, "dobs": dobs, "henc": henc, "hpre": hpre, "dstc": dstc, "steps": steps, "cov_coeff": cov_coeff, "std_coeff": std_coeff, "epochs": epochs, "lr": lr} - ) + # Log configuration + log_config(cfg) - for epoch in range(epochs): - pbar = tqdm(train_loader) - for _, batch in enumerate(pbar): + # Load checkpoint if requested + start_epoch = 0 + global_step = 0 + if cfg.meta.get("load_model"): + ckpt_path = exp_dir / cfg.meta.get("load_checkpoint", "latest.pth.tar") + ckpt_info = load_checkpoint(ckpt_path, jepa, optimizer, device=device) + start_epoch = ckpt_info.get("epoch", 0) + global_step = ckpt_info.get("step", 0) - batch = {k: v.to(device) for k, v in batch.items()} + # Training loop + logger.info(f"Starting training for {cfg.optim.epochs} epochs...") + + for epoch in range(start_epoch, cfg.optim.epochs): + pbar = tqdm( + train_loader, + desc=f"Epoch {epoch}", + disable=cfg.logging.get("tqdm_silent", False), + ) + for batch in pbar: + batch = {k: v.to(device) for k, v in batch.items()} x = batch["video"] loc_map = batch["digit_location"] optimizer.zero_grad() - jepa_loss, regl, _, regldict, pl = jepa.forwardn(x, actions=None, nsteps=steps) + _, (jepa_loss, regl, _, regldict, pl) = jepa.unroll( + x, + actions=None, + nsteps=cfg.model.steps, + unroll_mode="parallel", + compute_loss=True, + return_all_steps=False, + ) recon_loss = pixel_decoder(x, x) det_loss = detection_head(x, loc_map) total_loss = jepa_loss + recon_loss + det_loss @@ -93,26 +203,75 @@ def run( total_loss.backward() optimizer.step() - logs = { - "Epoch": epoch, - "Loss": float(jepa_loss.item()), - "VC Loss": float(regl.item()), - "Pred Loss": float(pl.item()), - "Recon Loss": float(recon_loss.item()), - "Det Loss": float(det_loss.item()), + # Update progress bar + pbar.set_postfix( + { + "loss": f"{jepa_loss.item():.4f}", + "vc": f"{regl.item():.4f}", + "pred": f"{pl.item():.4f}", + } + ) + + global_step += 1 + + # Validation and logging + if epoch % cfg.logging.log_every == 0: + val_logs = validation_loop( + val_loader, jepa, detection_head, pixel_decoder, cfg.model.steps, device + ) + + train_metrics = { + "epoch": epoch, + "train/loss": jepa_loss.item(), + "train/vc_loss": regl.item(), + "train/pred_loss": pl.item(), + "train/recon_loss": recon_loss.item(), + "train/det_loss": det_loss.item(), } for k, v in regldict.items(): - logs[k] = float(v) + train_metrics[f"train/{k}"] = float(v) + + all_metrics = {**train_metrics, **val_logs} + + if wandb_run: + import wandb + + wandb.log(all_metrics, step=global_step) - pbar.set_postfix(logs) + log_epoch( + epoch, + { + "loss": jepa_loss.item(), + "vc": regl.item(), + "pred": pl.item(), + "val_recon": val_logs.get("val/recon_loss", 0), + }, + total_epochs=cfg.optim.epochs, + ) - # Log train results every epoch - step = len(train_loader) * epoch - val_logs = validation_loop( - val_loader, jepa, detection_head, pixel_decoder, steps, device + # Save checkpoint + save_checkpoint( + exp_dir / "latest.pth.tar", + model=jepa, + optimizer=optimizer, + epoch=epoch, + step=global_step, ) - wandb.log({**logs, **val_logs}, step=step) - wandb.finish() + if epoch % cfg.logging.save_every == 0 and epoch > 0: + save_checkpoint( + exp_dir / f"epoch_{epoch}.pth.tar", + model=jepa, + optimizer=optimizer, + epoch=epoch, + step=global_step, + ) + + if wandb_run: + import wandb + + wandb.finish() + + logger.info("Training complete!") if __name__ == "__main__": diff --git a/examples/video_jepa/run_exp.sh b/examples/video_jepa/run_exp.sh deleted file mode 100644 index 1492eaa..0000000 --- a/examples/video_jepa/run_exp.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name=NWM -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=12 -#SBATCH --time=12:00:00 -#SBATCH --partition=dev -#SBATCH --signal=B:CONT@60 -#SBATCH --requeue -#SBATCH --output=/checkpoint/amirbar/experiments/eb_jepa/logs/%A_%a.out -#SBATCH --error=/checkpoint/amirbar/experiments/eb_jepa/logs/%A_%a.err -#SBATCH --array=0-17 - -# BATCH_SIZES=(8 16 32 64) -BATCH_SIZES=(512 1024) -LRS=(1e-3 1e-4 5e-4) -STD=(1 10 100 200) -COV=(1 10 100 200) -# COV=(100) - -chmod a+x ~/.bashrc -PS1='$ ' -source ~/.bashrc -cd "/private/home/amirbar/projects/eb_jepa_internal" -echo ${exp_list[$SLURM_ARRAY_TASK_ID]} - -triplets=() - -for item1 in "${BATCH_SIZES[@]}"; do - for item2 in "${LRS[@]}"; do - for item3 in "${STD[@]}"; do - for item4 in "${COV[@]}"; do - triplets+=("($item1, $item2, $item3, $item4)") - done - done - done -done - - -triplet="${triplets[$SLURM_ARRAY_TASK_ID]}" -bs=$(echo "$triplet" | awk -F '[(), ]+' '{print $2}') -lr=$(echo "$triplet" | awk -F '[(), ]+' '{print $3}') -std=$(echo "$triplet" | awk -F '[(), ]+' '{print $4}') -cov=$(echo "$triplet" | awk -F '[(), ]+' '{print $5}') - -/private/home/amirbar/projects/eb_jepa_internal/.venv/bin/python -m examples.image_jepa.main \ - --batch_size=${bs} \ - --lr=${lr} \ - --std_coeff=${std} \ - --cov_coeff=${cov} \ - --epochs=100 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d3a26b2..66e04c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,21 @@ [project] -name = "eb-jepa" +name = "eb_jepa" version = "0.1.1" description = "Energy-Based JEPA" -authors = [{name = "Yann LeCun", email = "ylecun@meta.com"}, {name = "Amir Bar", email = "amirbar@meta.com"}, {name= "Koustuv Sinha", email="koustuvs@meta.com"}] +authors = [ + {name = "Basile Terver", email = "basileterv@meta.com"}, + {name = "Randall Balestriero", email = "rbalestriero@meta.com"}, + {name = "Megi Dervishi", email = "megidervishi@meta.com"}, + {name = "David Fan", email = "davidfan@meta.com"}, + {name = "Quentin Garrido", email = "garridoq@meta.com"}, + {name = "Tushar Nagarajan", email = "tusharn@meta.com"}, + {name = "Koustuv Sinha", email = "koustuvs@meta.com"}, + {name = "Wancong Zhang", email = "wz1232@nyu.edu"}, + {name = "Mike Rabbat", email = "mikerabbat@meta.com"}, + # Equal contribution as senior authors + {name = "Yann LeCun", email = "yann@meta.com"}, + {name = "Amir Bar", email = "amirbar@meta.com"}, +] dependencies = [ "einops>=0.8.1", "fire>=0.7.0", @@ -15,7 +28,7 @@ dependencies = [ "torch==2.6.0", "torchcodec>=0.4.0", "torchvision", - "wandb>=0.21.1", + "wandb[media]>=0.21.1", "gymnasium>=1.1.1", "imageio", "seaborn", @@ -44,4 +57,5 @@ dev = [ "black>=25.1.0", "isort>=6.0.1", "pytest>=8.4.1", + "autoflake", ] diff --git a/tests/eb_jepa_test.py b/tests/eb_jepa_test.py deleted file mode 100644 index 9be1bb3..0000000 --- a/tests/eb_jepa_test.py +++ /dev/null @@ -1,278 +0,0 @@ -import pytest -import torch - -from eb_jepa.losses import ContrastiveLoss, contrastive_cost, vc_cost_chunked - - -@pytest.mark.skip(reason="Not implemented yet") -def test_forwardn(): - pass - - -@pytest.mark.skip(reason="Not implemented yet") -def test_infern(): - pass - - -def test_vc_cost_chunked(): - """Test the vc_cost_chunked function with various scenarios.""" - - # Test parameters - batch_size = 8 - feature_size = 4 - time_steps = 3 - height = 2 - width = 2 - - # Create test input tensor in BFTHW format - x = torch.randn(batch_size, feature_size, time_steps, height, width) - - # Create coefficient matrix for covariance regularization - ccoeff = torch.ones(feature_size, feature_size) * 0.1 - - # Calculate total flattened samples - total_samples = batch_size * time_steps * height * width # 8 * 3 * 2 * 2 = 96 - - # Test 1: Basic functionality with default parameters - result = vc_cost_chunked(x, ccoeff) - - # Check return type is tuple - assert isinstance(result, tuple), "Function should return a tuple" - assert len(result) == 2, "Function should return a tuple of length 2" - - mean_cost, individual_costs = result - - # Check that mean_cost is a scalar tensor - assert isinstance(mean_cost, torch.Tensor), "Mean cost should be a tensor" - assert mean_cost.dim() == 0, "Mean cost should be a scalar (0-dim tensor)" - - # Check that individual_costs is a 1D tensor - assert isinstance( - individual_costs, torch.Tensor - ), "Individual costs should be a tensor" - assert individual_costs.dim() == 1, "Individual costs should be 1D tensor" - - # When bdim=0 (default), it uses feature_size=4, so chunks = 96//4 = 24 - expected_chunks = total_samples // feature_size - assert ( - individual_costs.size(0) == expected_chunks - ), f"Expected {expected_chunks} chunks, got {individual_costs.size(0)}" - - # Check that mean equals the mean of individual costs - assert torch.allclose( - mean_cost, individual_costs.mean() - ), "Mean cost should equal mean of individual costs" - - # Test 2: Explicit chunk size - bdim = 2 - result2 = vc_cost_chunked(x, ccoeff, batch_dim=bdim) - mean_cost2, individual_costs2 = result2 - - expected_chunks2 = total_samples // bdim # 96 // 2 = 48 - assert ( - individual_costs2.size(0) == expected_chunks2 - ), f"Expected {expected_chunks2} chunks with bdim={bdim}" - - # Test 3: Different coefficients - mcoeff = 0.5 - c = 2.0 - result3 = vc_cost_chunked(x, ccoeff, mcoeff=mcoeff, c=c, batch_dim=bdim) - mean_cost3, individual_costs3 = result3 - - # Should have same number of chunks but different costs - assert individual_costs3.size(0) == expected_chunks2 - # Costs should be different due to different parameters - assert not torch.allclose( - mean_cost2, mean_cost3 - ), "Different parameters should produce different costs" - - # Test 4: Edge case - bdim that divides total_samples evenly - bdim4 = 8 - result4 = vc_cost_chunked(x, ccoeff, batch_dim=bdim4) - mean_cost4, individual_costs4 = result4 - - expected_chunks4 = total_samples // bdim4 # 96 // 8 = 12 - assert ( - individual_costs4.size(0) == expected_chunks4 - ), f"Should have {expected_chunks4} chunks when bdim={bdim4}" - - # Test 5: Check that all costs are non-negative (reasonable for a loss function) - assert torch.all( - individual_costs >= 0 - ), "All individual costs should be non-negative" - assert mean_cost >= 0, "Mean cost should be non-negative" - - # Test 6: Edge case - large bdim (fewer chunks) - bdim_large = 32 - result5 = vc_cost_chunked(x, ccoeff, batch_dim=bdim_large) - mean_cost5, individual_costs5 = result5 - - expected_chunks5 = total_samples // bdim_large # 96 // 32 = 3 - assert ( - individual_costs5.size(0) == expected_chunks5 - ), f"Should have {expected_chunks5} chunks when bdim={bdim_large}" - - -def test_contrastive_loss(): - """Test the ContrastiveLoss class and contrastive_cost function with various scenarios.""" - - # Test parameters - batch_size = 4 - feature_size = 8 - time_steps = 2 - height = 3 - width = 3 - - # Create test input tensor in BFTHW format - x = torch.randn(batch_size, feature_size, time_steps, height, width) - - # Test 1: Basic functionality with ContrastiveLoss class - contrastive_loss = ContrastiveLoss(temperature=0.1, negative_weight=1.0) - result = contrastive_loss(x) - - # Check return type is scalar tensor - assert isinstance(result, torch.Tensor), "ContrastiveLoss should return a tensor" - assert result.dim() == 0, "ContrastiveLoss should return a scalar (0-dim tensor)" - assert result >= 0, "Contrastive loss should be non-negative" - - # Test 2: Different temperature values - temp_low = ContrastiveLoss(temperature=0.01, negative_weight=1.0) - temp_high = ContrastiveLoss(temperature=1.0, negative_weight=1.0) - - result_low_temp = temp_low(x) - result_high_temp = temp_high(x) - - # Both should be non-negative - assert ( - result_low_temp >= 0 and result_high_temp >= 0 - ), "Both losses should be non-negative" - - # Test 3: Different negative weights - weight_small = ContrastiveLoss(temperature=0.1, negative_weight=0.5) - weight_large = ContrastiveLoss(temperature=0.1, negative_weight=2.0) - - result_small_weight = weight_small(x) - result_large_weight = weight_large(x) - - # Larger weight should give proportionally larger loss - assert torch.allclose( - result_large_weight, result_small_weight * 4.0 - ), "Loss should scale proportionally with negative_weight" - - # Test 4: Test contrastive_cost function directly - cost_result = contrastive_cost(x, temperature=0.1, negative_weight=1.0) - - assert isinstance( - cost_result, torch.Tensor - ), "contrastive_cost should return a tensor" - assert cost_result.dim() == 0, "contrastive_cost should return a scalar" - assert cost_result >= 0, "Contrastive cost should be non-negative" - - # Test 4b: Test subset sampling - total_samples = batch_size * time_steps * height * width # 4 * 2 * 3 * 3 = 72 - subset_size = 16 - num_subsets = 3 - - cost_subset = contrastive_cost( - x, - temperature=0.1, - negative_weight=1.0, - subset_size=subset_size, - num_subsets=num_subsets, - ) - - assert isinstance( - cost_subset, torch.Tensor - ), "Subset contrastive_cost should return a tensor" - assert cost_subset.dim() == 0, "Subset contrastive_cost should return a scalar" - assert cost_subset >= 0, "Subset contrastive cost should be non-negative" - - # Test 5: With projection layer - proj_layer = torch.nn.Linear(feature_size, feature_size // 2) - contrastive_with_proj = ContrastiveLoss( - temperature=0.1, negative_weight=1.0, proj=proj_layer - ) - - result_with_proj = contrastive_with_proj(x) - assert isinstance( - result_with_proj, torch.Tensor - ), "Loss with projection should return a tensor" - assert result_with_proj.dim() == 0, "Loss with projection should return a scalar" - assert result_with_proj >= 0, "Loss with projection should be non-negative" - - # Test 5b: ContrastiveLoss with subset sampling - contrastive_subset = ContrastiveLoss( - temperature=0.1, - negative_weight=1.0, - subset_size=subset_size, - num_subsets=num_subsets, - ) - result_subset = contrastive_subset(x) - - assert isinstance( - result_subset, torch.Tensor - ), "Subset ContrastiveLoss should return a tensor" - assert result_subset.dim() == 0, "Subset ContrastiveLoss should return a scalar" - assert result_subset >= 0, "Subset contrastive loss should be non-negative" - - # Test 6: Edge case - identical samples should give high loss - x_identical = torch.ones(batch_size, feature_size, time_steps, height, width) - identical_loss = contrastive_loss(x_identical) - - # Identical samples should result in high similarity and thus high contrastive loss - assert ( - identical_loss > 0 - ), "Identical samples should produce positive contrastive loss" - - # Test 7: Edge case - very different samples - x_diverse = torch.randn(batch_size, feature_size, time_steps, height, width) * 10 - # Make samples very different by scaling different samples differently - for i in range(batch_size): - x_diverse[i] = x_diverse[i] * (i + 1) - - diverse_loss = contrastive_loss(x_diverse) - assert diverse_loss >= 0, "Diverse samples should still produce non-negative loss" - - # Test 8: Gradient flow (ensure loss is differentiable) - x_grad = torch.randn( - batch_size, feature_size, time_steps, height, width, requires_grad=True - ) - loss_grad = contrastive_loss(x_grad) - loss_grad.backward() - - assert x_grad.grad is not None, "Gradients should flow through the contrastive loss" - assert not torch.isnan(x_grad.grad).any(), "Gradients should not contain NaN values" - - # Test 9: Consistency between class and function - class_result = contrastive_loss(x) - function_result = contrastive_cost(x, temperature=0.1, negative_weight=1.0) - - assert torch.allclose( - class_result, function_result - ), "ContrastiveLoss class and contrastive_cost function should give same result" - - # Test 10: Subset edge cases - # Test with subset_size larger than total samples - large_subset = ContrastiveLoss( - temperature=0.1, negative_weight=1.0, subset_size=1000, num_subsets=2 - ) - result_large = large_subset(x) - assert result_large >= 0, "Large subset size should still work" - - # Test with single subset - single_subset = ContrastiveLoss( - temperature=0.1, negative_weight=1.0, subset_size=8, num_subsets=1 - ) - result_single = single_subset(x) - assert result_single >= 0, "Single subset should work" - - # Test efficiency: subset sampling should be faster for large inputs - # (This is more of a conceptual test - in practice we'd need larger tensors to see the difference) - efficient_contrastive = ContrastiveLoss( - temperature=0.1, - negative_weight=1.0, - subset_size=min(16, total_samples // 2), - num_subsets=3, - ) - result_efficient = efficient_contrastive(x) - assert result_efficient >= 0, "Efficient subset sampling should work" diff --git a/tests/planning_test.py b/tests/planning_test.py index 2fcf67f..5a76ad8 100644 --- a/tests/planning_test.py +++ b/tests/planning_test.py @@ -108,7 +108,8 @@ def test_gc_agent(mock_cem_planner): # Create a mock model with parameters mock_model = Mock() mock_model.encode = Mock(return_value=torch.zeros(1, 8, 1, 8, 8)) - mock_model.unrolln = Mock(return_value=torch.zeros(10, 8, 6, 8, 8)) + # unroll returns a tuple (predicted_states, loss) - loss is None when compute_loss=False + mock_model.unroll = Mock(return_value=(torch.zeros(10, 8, 6, 8, 8), None)) # Add a parameter method that returns an iterator with a device param = torch.nn.Parameter(torch.zeros(1)) mock_model.parameters = Mock(return_value=iter([param])) @@ -193,7 +194,7 @@ def test_gc_agent(mock_cem_planner): states = agent.unroll(obs_init, actions) # Verify unroll behavior - assert mock_model.unrolln.called, "Model's unroll should be called" + assert mock_model.unroll.called, "Model's unroll should be called" assert isinstance(states, torch.Tensor), "Should return a tensor" @@ -241,8 +242,19 @@ def mock_env_creator(): mock_agent_instance = Mock() mock_agent_instance.act = Mock(return_value=torch.tensor([[0.1, 0.2]])) mock_agent_instance.device = torch.device("cpu") - mock_agent_instance.analyze_distances = Mock(return_value=([], [])) mock_agent_instance.decode_each_iteration = False + mock_agent_instance.num_act_stepped = 1 + # Set up proper tensor values for agent attributes used by analyze_distances + mock_agent_instance.goal_position = torch.tensor([10.0, 10.0]) + mock_agent_instance.goal_state = torch.zeros(2, 65, 65) + mock_agent_instance.normalizer = mock_normalizer = Mock() + mock_normalizer.normalize_state = Mock(side_effect=lambda x: x) + mock_agent_instance.model = Mock() + mock_agent_instance.model.encode = Mock(return_value=torch.zeros(1, 8, 1, 4, 4)) + mock_agent_instance.objective = Mock(return_value=torch.zeros(1)) + mock_agent_instance._prev_losses = None + mock_agent_instance._prev_elite_losses_mean = None + mock_agent_instance._prev_elite_losses_std = None mock_gc_agent.return_value = mock_agent_instance # Create plan config @@ -264,12 +276,12 @@ def mock_env_creator(): } # Run evaluation with fewer episodes for testing - os.makedirs("./logs/", exist_ok=True) + os.makedirs("./tests/logs/", exist_ok=True) results = main_eval( plan_cfg=plan_cfg, model=mock_model, env_creator=mock_env_creator, - eval_folder=Path("./logs/"), + eval_folder=Path("./tests/logs/"), num_episodes=num_episodes, ) @@ -301,10 +313,19 @@ def encode(self, x): B, C, T, H, W = x.shape return torch.zeros(B, 8, T, 4, 4, device=x.device) - def unrolln(self, obs, actions, steps, ctxt_window_time=1): + def unroll( + self, + obs, + actions, + nsteps, + unroll_mode="autoregressive", + ctxt_window_time=1, + compute_loss=False, + return_all_steps=False, + ): # Simpler unroll function that doesn't depend on complex tensor shapes B = obs.shape[0] - return torch.ones(B, 8, steps, 4, 4, device=obs.device) + return torch.ones(B, 8, nsteps, 4, 4, device=obs.device), None # Test full planning episode # Create model and move to appropriate device diff --git a/tests/test_jepa_output_formats.py b/tests/test_jepa_output_formats.py new file mode 100644 index 0000000..f589cbd --- /dev/null +++ b/tests/test_jepa_output_formats.py @@ -0,0 +1,832 @@ +""" +Test script to verify output formats of JEPA's unroll() function. + +This tests: +1. Output format of unroll() function in parallel and autoregressive modes +2. Correct behavior with both RNN and Conv predictors +3. return_all_steps functionality + +Usage patterns: +- unroll(parallel, return_all_steps=True): Multi-step inference (formerly infern()) +- unroll(parallel, compute_loss=True): Training with full GT trajectory +- unroll(autoregressive, compute_loss=False): Planning/MPC (formerly unrolln()) +""" + +import torch +import torch.nn as nn + +from eb_jepa.architectures import ( + ImpalaEncoder, + InverseDynamicsModel, + Projector, + ResNet5, + ResUNet, + RNNPredictor, + StateOnlyPredictor, +) +from eb_jepa.jepa import JEPA +from eb_jepa.losses import SquareLossSeq, VC_IDM_Sim_Regularizer, VCLoss + + +# ============================================================================ +# Helper function to set random seed for reproducibility +# ============================================================================ +def set_seed(seed=42): + """Set random seed for reproducibility in tests.""" + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def create_video_jepa_model(device="cpu"): + """Create a Video JEPA model matching the default config.""" + # Config values from examples/video_jepa/cfgs/default.yaml + dobs = 1 # Input channels (grayscale) + henc = 32 # Hidden dimension in encoder + dstc = 16 # Output representation dimension + hpre = 32 # Hidden dimension in predictor + + encoder = ResNet5(dobs, henc, dstc) + predictor_model = ResUNet(2 * dstc, hpre, dstc) + predictor = StateOnlyPredictor(predictor_model, context_length=2) + projector = Projector(f"{dstc}-{dstc*4}-{dstc*4}") + regularizer = VCLoss(std_coeff=10.0, cov_coeff=100.0, proj=projector) + ploss = SquareLossSeq(projector) + jepa = JEPA(encoder, encoder, predictor, regularizer, ploss).to(device) + + return jepa, dstc + + +def create_ac_video_jepa_model(device="cpu", img_size=65): + """ + Create an Action-Conditioned Video JEPA model matching the + architecture from examples/ac_video_jepa/main.py with default + config from examples/ac_video_jepa/cfgs/train.yaml. + """ + # Config values from examples/ac_video_jepa/cfgs/train.yaml + dobs = 2 # Input channels (RGB + position = 2 channels for two_rooms) + henc = 32 # Hidden dimension in encoder + dstc = 32 # Output representation dimension + nsteps = 8 # Number of prediction steps + action_dim = 2 # Action dimension for two_rooms + + # Regularizer config + cov_coeff = 8 + std_coeff = 16 + sim_coeff_t = 12 + idm_coeff = 1 + first_t_only = False + spatial_as_samples = False + use_proj = False + idm_after_proj = False + sim_t_after_proj = False + + # Create encoder (ImpalaEncoder as in main.py) + encoder = ImpalaEncoder( + width=1, + stack_sizes=(16, henc, dstc), + num_blocks=2, + dropout_rate=None, + layer_norm=False, + input_channels=dobs, + final_ln=True, + mlp_output_dim=512, + input_shape=(dobs, img_size, img_size), + ) + + # Test encoder to get output dimensions + test_input = torch.rand((1, dobs, 1, img_size, img_size)) + test_output = encoder(test_input) + _, f, _, h, w = test_output.shape + + # Create predictor (RNNPredictor as in main.py) + predictor = RNNPredictor( + hidden_size=encoder.mlp_output_dim, + action_dim=action_dim, + final_ln=nn.LayerNorm(encoder.mlp_output_dim) if encoder.final_ln else None, + ) + + # Action encoder is identity + aencoder = nn.Identity() + + # Projector (only if use_proj=True in config) + if use_proj: + projector = Projector( + f"{encoder.mlp_output_dim}-{encoder.mlp_output_dim*4}-{encoder.mlp_output_dim*4}" + ) + else: + projector = None + + # Create IDM (InverseDynamicsModel) + idm = InverseDynamicsModel( + state_dim=h * w * (projector.out_dim if idm_after_proj and projector else f), + hidden_dim=256, + action_dim=action_dim, + ).to(device) + + # Create regularizer (VC_IDM_Sim_Regularizer as in main.py) + regularizer = VC_IDM_Sim_Regularizer( + cov_coeff=cov_coeff, + std_coeff=std_coeff, + sim_coeff_t=sim_coeff_t, + idm_coeff=idm_coeff, + idm=idm, + first_t_only=first_t_only, + projector=projector, + spatial_as_samples=spatial_as_samples, + idm_after_proj=idm_after_proj, + sim_t_after_proj=sim_t_after_proj, + ) + + # Prediction loss + ploss = SquareLossSeq() + + # Create JEPA model + jepa = JEPA(encoder, aencoder, predictor, regularizer, ploss).to(device) + + config = { + "dobs": dobs, + "henc": henc, + "dstc": dstc, + "nsteps": nsteps, + "action_dim": action_dim, + "img_size": img_size, + "mlp_output_dim": encoder.mlp_output_dim, + "encoder_spatial_h": h, + "encoder_spatial_w": w, + } + + return jepa, config + + +# ============================================================================ +# Tests for unroll() function in parallel mode +# ============================================================================ + + +def test_unroll_parallel_mode_output_format(): + """ + Test unroll() output format in parallel mode. + + Usage pattern: + preds, losses = jepa.unroll(x, actions=None, nsteps=nsteps, + unroll_mode="parallel", compute_loss=False, + return_all_steps=True) + """ + print("=" * 60) + print("Testing unroll() parallel mode output format") + print("=" * 60) + + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + jepa, dstc = create_video_jepa_model(device) + jepa.eval() + + # Create fake video input matching Moving MNIST format + # Shape: [B, C, T, H, W] = [batch, channels, time, height, width] + B, C, T, H, W = 4, 1, 10, 64, 64 + x = torch.randn(B, C, T, H, W, device=device) + + print(f"\nInput shape: {x.shape}") + print(f" B (batch) = {B}") + print(f" C (channels) = {C}") + print(f" T (time steps) = {T}") + print(f" H (height) = {H}") + print(f" W (width) = {W}") + + # Call unroll with return_all_steps=True (like former infern()) + nsteps = T - 2 + print(f"\nCalling: jepa.unroll(x, actions=None, nsteps={nsteps}, ...") + print( + f" unroll_mode='parallel', compute_loss=False, return_all_steps=True)" + ) + + with torch.no_grad(): + preds, losses = jepa.unroll( + x, + actions=None, + nsteps=nsteps, + unroll_mode="parallel", + compute_loss=False, + return_all_steps=True, + ) + + print(f"\n--- unroll() Output Analysis ---") + print(f"Return type of preds: {type(preds)}") + print(f"Length of preds list: {len(preds)}") + print(f"Expected length (nsteps): {nsteps}") + print(f"losses: {losses} (expected: None when compute_loss=False)") + + # Analyze each prediction step + print(f"\nPer-step shapes:") + for i, pred in enumerate(preds): + print(f" preds[{i}] shape: {pred.shape}") + + # Verify the expected format + first_pred = preds[0] + context_length = jepa.predictor.context_length + print(f"\n--- Shape Breakdown for preds[0] ---") + print(f" Dimension 0 (batch): {first_pred.shape[0]} (expected: {B})") + print(f" Dimension 1 (embedding dim): {first_pred.shape[1]} (expected: {dstc})") + print( + f" Dimension 2 (time): {first_pred.shape[2]} (expected: T-context_length = {T}-{context_length} = {T-context_length})" + ) + print(f" Dimension 3 (height): {first_pred.shape[3]}") + print(f" Dimension 4 (width): {first_pred.shape[4]}") + + # Assertions + assert isinstance(preds, list), f"Expected list, got {type(preds)}" + assert len(preds) == nsteps, f"Expected {nsteps} steps, got {len(preds)}" + assert losses is None, f"Expected losses=None, got {losses}" + print("\n βœ“ All assertions passed!") + + print("=" * 60) + + return preds + + +def test_unroll_parallel_mode_with_loss(): + """ + Test unroll() output format in parallel mode with loss computation. + + Usage pattern: + _, losses = jepa.unroll(x, actions=None, nsteps=cfg.model.steps, + unroll_mode="parallel", compute_loss=True) + loss, rloss, rloss_unweight, rloss_dict, ploss = losses + """ + print("\n" + "=" * 60) + print("Testing unroll() parallel mode with loss computation") + print("=" * 60) + + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + jepa, dstc = create_video_jepa_model(device) + jepa.train() + + # Create fake video input matching Moving MNIST format + B, C, T, H, W = 4, 1, 10, 64, 64 + x = torch.randn(B, C, T, H, W, device=device) + + print(f"\nInput shape: {x.shape}") + + # Call unroll with compute_loss=True (for training) + nsteps = 4 + print(f"\nCalling: jepa.unroll(x, actions=None, nsteps={nsteps}, ...") + print(f" unroll_mode='parallel', compute_loss=True)") + + predicted_states, losses = jepa.unroll( + x, actions=None, nsteps=nsteps, unroll_mode="parallel", compute_loss=True + ) + loss, rloss, rloss_unweight, rloss_dict, ploss = losses + + print(f"\n--- unroll() Output Analysis ---") + print(f"Output is a tuple of (predicted_states, losses):") + print(f" predicted_states shape: {predicted_states.shape}") + print(f"\nlosses tuple contains 5 elements:") + print( + f" 1. total_loss (loss): {type(loss).__name__}, shape: {loss.shape}, value: {loss.item():.6f}" + ) + print( + f" 2. reg_loss (rloss): {type(rloss).__name__}, shape: {rloss.shape}, value: {rloss.item():.6f}" + ) + print( + f" 3. reg_loss_unweighted: {type(rloss_unweight).__name__}, shape: {rloss_unweight.shape}, value: {rloss_unweight.item():.6f}" + ) + print(f" 4. reg_loss_dict: {type(rloss_dict).__name__}") + for k, v in rloss_dict.items(): + if isinstance(v, torch.Tensor): + print(f" - '{k}': {v.item():.6f}") + else: + print(f" - '{k}': {v}") + print( + f" 5. pred_loss (ploss): {type(ploss).__name__}, shape: {ploss.shape}, value: {ploss.item():.6f}" + ) + + # Assertions + assert loss.shape == torch.Size( + [] + ), f"total_loss should be scalar, got {loss.shape}" + assert rloss.shape == torch.Size( + [] + ), f"reg_loss should be scalar, got {rloss.shape}" + assert ploss.shape == torch.Size( + [] + ), f"pred_loss should be scalar, got {ploss.shape}" + print("\n βœ“ All assertions passed!") + + print("=" * 60) + + return loss, rloss, rloss_unweight, rloss_dict, ploss + + +def test_infer_method(): + """Test the infer() method which uses unroll() internally.""" + print("\n" + "=" * 60) + print("Testing infer() method") + print("=" * 60) + + device = "cuda" if torch.cuda.is_available() else "cpu" + jepa, _ = create_video_jepa_model(device) + jepa.eval() + + B, C, T, H, W = 2, 1, 8, 64, 64 + x = torch.randn(B, C, T, H, W, device=device) + + with torch.no_grad(): + # infer() is defined as: unroll(..., nsteps=1, return_all_steps=True)[0] + infer_result = jepa.infer(x, actions=None) + + print(f"infer() output shape: {infer_result.shape}") + print(f" βœ“ infer() returns single tensor (first step from unroll)") + + print("=" * 60) + + +# ============================================================================ +# Tests for unroll() function in autoregressive mode +# ============================================================================ + + +def test_unroll_autoregressive_mode_shapes(): + """ + Test unroll() input and output tensor shapes in autoregressive mode. + + This tests the autoregressive mode as used in planning/MPC. + + Usage pattern: + predicted_states, _ = jepa.unroll(obs_init, actions, nsteps, + unroll_mode="autoregressive", + ctxt_window_time=1, compute_loss=False) + """ + print("\n" + "=" * 60) + print("Testing AC Video JEPA unroll() autoregressive mode shapes") + print("=" * 60) + + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + # Create AC Video JEPA model + img_size = 65 # Default for two_rooms + jepa, config = create_ac_video_jepa_model(device, img_size=img_size) + jepa.eval() + + print("\n--- Model Configuration (from train.yaml) ---") + for k, v in config.items(): + print(f" {k}: {v}") + + # Input dimensions matching two_rooms format + B = 4 # Batch size + C = config["dobs"] # Input channels + H = W = config["img_size"] + A = config["action_dim"] + D = config["mlp_output_dim"] # Encoder output dimension + + # Test case 1: Single initial frame (as used in planning) + print("\n" + "-" * 60) + print("Test Case 1: Single initial frame (planning pattern)") + print("-" * 60) + + T_context = 1 # Single context frame + nsteps = 10 # Number of prediction steps + T_actions = nsteps # Actions for all prediction steps + + # Input: single initial observation + obs_init = torch.randn(B, C, T_context, H, W, device=device) + # Input: action sequence for unrolling + actions = torch.randn(B, A, T_actions, device=device) + + print(f"\n--- Input Shapes ---") + print(f" obs_init: [{B}, {C}, {T_context}, {H}, {W}] (initial observation)") + print(f" actions: [{B}, {A}, {T_actions}] (action sequence)") + print(f" nsteps: {nsteps}") + + print(f"\nCalling: jepa.unroll(obs_init, actions, nsteps={nsteps}, ...") + print( + f" unroll_mode='autoregressive', ctxt_window_time=1, compute_loss=False)" + ) + + with torch.no_grad(): + predicted_states, losses = jepa.unroll( + obs_init, + actions, + nsteps=nsteps, + unroll_mode="autoregressive", + ctxt_window_time=1, + compute_loss=False, + ) + + print(f"\n--- unroll() Output Analysis ---") + print(f" predicted_states shape: {predicted_states.shape}") + print(f" losses: {losses} (expected: None when compute_loss=False)") + + # For RNN predictor (single_unroll=True), output is [B, D, 1 + nsteps, H', W'] + # The RNN predictor only uses the first frame (state[:, :, :1]) as initial state + expected_T_out = 1 + nsteps # First encoded frame + nsteps predictions + expected_shape = (B, D, expected_T_out, 1, 1) + + print(f"\n--- Shape Assertions ---") + assert ( + predicted_states.shape[0] == B + ), f"Batch dim mismatch: {predicted_states.shape[0]} vs {B}" + print(f" βœ“ Batch dimension: {B}") + + assert ( + predicted_states.shape[1] == D + ), f"Feature dim mismatch: {predicted_states.shape[1]} vs {D}" + print(f" βœ“ Feature dimension: {D}") + + assert ( + predicted_states.shape[2] == expected_T_out + ), f"Time dim mismatch: {predicted_states.shape[2]} vs {expected_T_out}" + print(f" βœ“ Time dimension: {expected_T_out} (1 + nsteps={nsteps})") + + assert losses is None, f"Expected losses=None, got {losses}" + print(f" βœ“ losses is None when compute_loss=False") + + # Test case 2: Verify encoder output is preserved in first timestep(s) + print("\n" + "-" * 60) + print("Test Case 2: Verify encoder output preserved at t=0") + print("-" * 60) + + with torch.no_grad(): + # Encode the initial observation + encoded_init = jepa.encoder(obs_init) # [B, D, T_context, 1, 1] + + print(f" Encoded initial obs shape: {encoded_init.shape}") + print( + f" First timestep of unroll output shape: {predicted_states[:, :, :T_context].shape}" + ) + + # The first timestep(s) should match the encoded initial observation + first_timesteps = predicted_states[:, :, :T_context] + assert torch.allclose( + first_timesteps, encoded_init, atol=1e-5 + ), "First timestep(s) of unroll should match encoded initial observation" + print(f" βœ“ First timestep matches encoded initial observation") + + # Test case 3: Verify error when nsteps > action sequence length + print("\n" + "-" * 60) + print("Test Case 3: Error handling for nsteps > action length") + print("-" * 60) + + short_actions = torch.randn(B, A, 5, device=device) # Only 5 actions + try: + with torch.no_grad(): + _ = jepa.unroll( + obs_init, + short_actions, + nsteps=10, + unroll_mode="autoregressive", + ctxt_window_time=1, + compute_loss=False, + ) # Request 10 steps + print(" βœ— Should have raised an error!") + assert False, "Expected ValueError for nsteps > action sequence length" + except ValueError as e: + print(f" βœ“ Correctly raised ValueError: {e}") + + print("\n" + "=" * 60) + print("AC Video JEPA unroll() autoregressive mode Shape Test Summary:") + print("=" * 60) + print(f" Input observations: [B, C, T_context, H, W]") + print(f" Input actions: [B, A, T_actions] where T_actions >= nsteps") + print(f" Output: [B, D, 1 + nsteps, H', W'] for RNN predictor") + print( + f" (only first frame of context is used as initial state)" + ) + print(f" Where:") + print(f" - B = batch size") + print(f" - D = encoder output dim ({D})") + print(f" - H', W' = spatial dims after encoder (1, 1 for ImpalaEncoder)") + print("=" * 60) + print(" βœ“ All shape assertions passed!") + print("=" * 60) + + return predicted_states + + +def test_unroll_autoregressive_with_loss(): + """ + Test unroll() autoregressive mode with loss computation for training. + + Usage pattern: + _, losses = jepa.unroll(x, actions, nsteps, + unroll_mode="autoregressive", ctxt_window_time=1, + compute_loss=True) + """ + print("\n" + "=" * 60) + print("Testing AC Video JEPA unroll() autoregressive mode with loss") + print("=" * 60) + + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + set_seed(42) + img_size = 65 + jepa, config = create_ac_video_jepa_model(device, img_size=img_size) + jepa.train() + + # Verify this is an RNN predictor + assert ( + jepa.single_unroll + ), "AC Video JEPA should have single_unroll=True (RNN predictor)" + print(" βœ“ Confirmed RNN predictor (single_unroll=True)") + + # Create test input + B = 2 + C = config["dobs"] + T = 12 + H = W = config["img_size"] + A = config["action_dim"] + + set_seed(42) + x = torch.randn(B, C, T, H, W, device=device) + actions = torch.randn(B, A, T, device=device) + + nsteps = 6 + print(f"\nInput shapes: observations={x.shape}, actions={actions.shape}") + print(f"nsteps: {nsteps}") + + # Call unroll with compute_loss=True + print(f"\nCalling: jepa.unroll(x, actions, nsteps={nsteps}, ...") + print( + f" unroll_mode='autoregressive', ctxt_window_time=1, compute_loss=True)" + ) + + predicted_states, losses = jepa.unroll( + x, + actions, + nsteps=nsteps, + unroll_mode="autoregressive", + ctxt_window_time=1, + compute_loss=True, + ) + loss, rloss, rloss_unweight, rloss_dict, ploss = losses + + print(f"\n--- unroll() Output Analysis ---") + print(f" predicted_states shape: {predicted_states.shape}") + print(f"\nlosses tuple contains 5 elements:") + print(f" 1. total_loss: shape={loss.shape}, dtype={loss.dtype}") + print(f" 2. reg_loss: shape={rloss.shape}, dtype={rloss.dtype}") + print( + f" 3. reg_loss_unweight: shape={rloss_unweight.shape}, dtype={rloss_unweight.dtype}" + ) + print(f" 4. reg_loss_dict: keys={list(rloss_dict.keys())}") + print(f" 5. pred_loss: shape={ploss.shape}, dtype={ploss.dtype}") + + # Assertions + assert loss.shape == torch.Size( + [] + ), f"total_loss should be scalar, got {loss.shape}" + assert rloss.shape == torch.Size( + [] + ), f"reg_loss should be scalar, got {rloss.shape}" + assert ploss.shape == torch.Size( + [] + ), f"pred_loss should be scalar, got {ploss.shape}" + + # reg_loss_dict should contain expected keys for VC_IDM_Sim_Regularizer + expected_keys = {"std_loss", "cov_loss", "sim_loss_t", "idm_loss"} + assert ( + set(rloss_dict.keys()) == expected_keys + ), f"Expected keys {expected_keys}, got {set(rloss_dict.keys())}" + print(f"\n βœ“ reg_loss_dict contains expected keys: {expected_keys}") + print(" βœ“ All assertions passed!") + + print("=" * 60) + + return loss, rloss, rloss_unweight, rloss_dict, ploss + + +def test_unroll_autoregressive_with_conv_predictor(): + """ + Test unroll() autoregressive mode with Conv predictor (non-RNN). + + This tests the sliding window behavior with the Video JEPA model, + which uses a ResUNet predictor that processes sliding windows. + """ + print("\n" + "=" * 60) + print("Testing unroll() autoregressive mode (Conv predictor)") + print("=" * 60) + + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + set_seed(42) + jepa, dstc = create_video_jepa_model(device) + jepa.eval() + + # Verify this is NOT an RNN predictor + assert ( + not jepa.single_unroll + ), "Video JEPA should have single_unroll=False (Conv predictor)" + print(" βœ“ Confirmed Conv predictor (single_unroll=False)") + + # Create test input + B, C, T_context, H, W = 2, 1, 3, 64, 64 + nsteps = 5 + ctxt_window_time = 2 + + set_seed(42) + obs = torch.randn(B, C, T_context, H, W, device=device) + + print(f"\nInput shape: {obs.shape}") + print(f"nsteps: {nsteps}") + print(f"ctxt_window_time: {ctxt_window_time}") + + print(f"\nCalling: jepa.unroll(obs, actions=None, nsteps={nsteps}, ...") + print( + f" unroll_mode='autoregressive', ctxt_window_time={ctxt_window_time})" + ) + + with torch.no_grad(): + unroll_result, unroll_losses = jepa.unroll( + obs, + actions=None, + nsteps=nsteps, + unroll_mode="autoregressive", + ctxt_window_time=ctxt_window_time, + compute_loss=False, + return_all_steps=False, + ) + + expected_T_out = ctxt_window_time + nsteps + print(f"\n Output shape: {unroll_result.shape}") + print(f" Expected time dimension: {expected_T_out} (ctxt_window_time + nsteps)") + + assert ( + unroll_result.shape[2] == expected_T_out + ), f"Time dim mismatch: got {unroll_result.shape[2]}, expected {expected_T_out}" + print(f" βœ“ Time dimension correct: {unroll_result.shape[2]}") + + print("\n" + "=" * 60) + print("unroll() autoregressive mode (Conv predictor) Test: PASSED") + print("=" * 60) + + return True + + +def test_unroll_return_all_steps_format(): + """ + Test that return_all_steps=True returns the correct format for both modes. + """ + print("\n" + "=" * 60) + print("Testing unroll() return_all_steps format") + print("=" * 60) + + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Test with Video JEPA (parallel mode) + print("\n--- Parallel mode (Video JEPA) ---") + set_seed(42) + jepa, dstc = create_video_jepa_model(device) + jepa.eval() + + B, C, T, H, W = 2, 1, 8, 64, 64 + x = torch.randn(B, C, T, H, W, device=device) + nsteps = 3 + + with torch.no_grad(): + all_steps, _ = jepa.unroll( + x, + actions=None, + nsteps=nsteps, + unroll_mode="parallel", + compute_loss=False, + return_all_steps=True, + ) + + assert isinstance(all_steps, list), f"Expected list, got {type(all_steps)}" + assert len(all_steps) == nsteps, f"Expected {nsteps} steps, got {len(all_steps)}" + print(f" βœ“ Parallel mode returns list of {len(all_steps)} tensors") + for i, step in enumerate(all_steps): + print(f" Step {i}: shape={step.shape}") + + # Test with AC Video JEPA (autoregressive mode) + print("\n--- Autoregressive mode (AC Video JEPA) ---") + set_seed(42) + jepa_ac, config = create_ac_video_jepa_model(device) + jepa_ac.eval() + + B = 2 + C = config["dobs"] + H = W = config["img_size"] + A = config["action_dim"] + + obs = torch.randn(B, C, 1, H, W, device=device) + actions = torch.randn(B, A, nsteps, device=device) + + with torch.no_grad(): + all_steps_ac, _ = jepa_ac.unroll( + obs, + actions, + nsteps=nsteps, + unroll_mode="autoregressive", + ctxt_window_time=1, + compute_loss=False, + return_all_steps=True, + ) + + assert isinstance(all_steps_ac, list), f"Expected list, got {type(all_steps_ac)}" + assert ( + len(all_steps_ac) == nsteps + ), f"Expected {nsteps} steps, got {len(all_steps_ac)}" + print(f" βœ“ Autoregressive mode returns list of {len(all_steps_ac)} tensors") + for i, step in enumerate(all_steps_ac): + print(f" Step {i}: shape={step.shape}") + + # Verify autoregressive steps grow in time dimension + for i in range(1, len(all_steps_ac)): + assert ( + all_steps_ac[i].shape[2] == all_steps_ac[i - 1].shape[2] + 1 + ), f"Autoregressive steps should grow by 1: step {i-1}={all_steps_ac[i-1].shape[2]}, step {i}={all_steps_ac[i].shape[2]}" + print(" βœ“ Autoregressive steps correctly grow in time dimension") + + print("\n" + "=" * 60) + print("unroll() return_all_steps format Test: PASSED") + print("=" * 60) + + return True + + +def run_all_tests(): + """Run all tests for unroll() function.""" + print("\n" + "#" * 60) + print("# UNROLL() FUNCTION TEST SUITE") + print("#" * 60) + + results = {} + + # Parallel mode tests + try: + test_unroll_parallel_mode_output_format() + results["unroll parallel mode output"] = "PASSED" + except AssertionError as e: + results["unroll parallel mode output"] = f"FAILED: {e}" + + try: + test_unroll_parallel_mode_with_loss() + results["unroll parallel mode with loss"] = "PASSED" + except AssertionError as e: + results["unroll parallel mode with loss"] = f"FAILED: {e}" + + try: + test_infer_method() + results["infer method"] = "PASSED" + except AssertionError as e: + results["infer method"] = f"FAILED: {e}" + + # Autoregressive mode tests + try: + test_unroll_autoregressive_mode_shapes() + results["unroll autoregressive mode shapes"] = "PASSED" + except AssertionError as e: + results["unroll autoregressive mode shapes"] = f"FAILED: {e}" + + try: + test_unroll_autoregressive_with_loss() + results["unroll autoregressive with loss"] = "PASSED" + except AssertionError as e: + results["unroll autoregressive with loss"] = f"FAILED: {e}" + + try: + test_unroll_autoregressive_with_conv_predictor() + results["unroll autoregressive (Conv)"] = "PASSED" + except AssertionError as e: + results["unroll autoregressive (Conv)"] = f"FAILED: {e}" + + try: + test_unroll_return_all_steps_format() + results["return_all_steps format"] = "PASSED" + except AssertionError as e: + results["return_all_steps format"] = f"FAILED: {e}" + + # Summary + print("\n" + "#" * 60) + print("# TEST SUMMARY") + print("#" * 60) + all_passed = True + for test_name, result in results.items(): + status = "βœ“" if result == "PASSED" else "βœ—" + print(f" {status} {test_name}: {result}") + if result != "PASSED": + all_passed = False + + return all_passed + + +if __name__ == "__main__": + print("\n" + "#" * 60) + print("# JEPA Output Format Test Suite") + print("#" * 60 + "\n") + + all_passed = run_all_tests() + + print("\n" + "#" * 60) + if all_passed: + print("# All tests completed successfully!") + else: + print("# Some tests FAILED - see details above") + print("#" * 60) diff --git a/tests/test_loss_equivalences.py b/tests/test_loss_equivalences.py new file mode 100644 index 0000000..1f00992 --- /dev/null +++ b/tests/test_loss_equivalences.py @@ -0,0 +1,320 @@ +""" +Unit tests to verify mathematical equivalences between loss implementations. + +Tests the following claims: +1. VICRegLoss std computation (without centering) equals HingeStdLoss (with centering) +2. VICRegLoss cov computation equals CovarianceLoss +3. VICRegLoss decomposes into HingeStdLoss + CovarianceLoss + MSE invariance +4. Centering before torch.var() is redundant +""" + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from eb_jepa.losses import ( + CovarianceLoss, + HingeStdLoss, + VCLoss, + VICRegLoss, +) + + +class TestStdLossEquivalence: + """Test equivalences for standard deviation / variance loss.""" + + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + torch.manual_seed(42) + return torch.randn(64, 128) # batch_size=64, features=128 + + def test_vicreg_std_vs_hinge_std(self, sample_data): + """ + Test if VICRegLoss std computation (without explicit centering) + equals HingeStdLoss (with centering). + + This tests the claim that centering before var() is redundant. + """ + x = sample_data + + # VICRegLoss style: no explicit centering, uses 1e-4 epsilon + z_std_vicreg = torch.sqrt(x.var(dim=0) + 1e-4) + std_loss_vicreg = torch.mean(F.relu(1 - z_std_vicreg)) + + # HingeStdLoss style: explicit centering, uses 0.0001 epsilon + x_centered = x - x.mean(dim=0, keepdim=True) + z_std_hinge = torch.sqrt(x_centered.var(dim=0) + 0.0001) + std_loss_hinge = torch.mean(F.relu(1 - z_std_hinge)) + + # Should be equal because: + # 1. torch.var() computes variance around the mean anyway + # 2. 1e-4 == 0.0001 + assert torch.allclose(std_loss_vicreg, std_loss_hinge, atol=1e-7), ( + f"VICReg std: {std_loss_vicreg.item():.8f} vs " + f"HingeStd: {std_loss_hinge.item():.8f}" + ) + + def test_centering_before_var_is_redundant(self, sample_data): + """ + Directly test that centering before torch.var() is redundant. + """ + x = sample_data + + # Without centering + var_no_center = x.var(dim=0) + + # With centering + x_centered = x - x.mean(dim=0, keepdim=True) + var_with_center = x_centered.var(dim=0) + + # Should be exactly equal + assert torch.allclose(var_no_center, var_with_center, atol=1e-7), ( + f"Var without centering and var with centering should be equal. " + f"Max diff: {(var_no_center - var_with_center).abs().max().item():.8e}" + ) + + +class TestCovLossEquivalence: + """Test equivalences for covariance loss.""" + + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + torch.manual_seed(42) + return torch.randn(64, 128) # batch_size=64, features=128 + + def test_vicreg_cov_vs_covariance_loss(self, sample_data): + """ + Verify that VICRegLoss cov computation equals CovarianceLoss. + """ + x = sample_data + batch_size = x.size(0) + + # VICRegLoss style computation + z_centered = x - x.mean(dim=0) + z_cov = torch.mm(z_centered.T, z_centered) / (batch_size - 1) + cov_loss_vicreg = (z_cov.pow(2).sum() - z_cov.diagonal().pow(2).sum()) / ( + z_cov.size(0) ** 2 - z_cov.size(0) + ) + + # CovarianceLoss style + cov_loss_fn = CovarianceLoss() + cov_loss_class = cov_loss_fn(x) + + assert torch.allclose(cov_loss_vicreg, cov_loss_class, atol=1e-6), ( + f"VICReg cov: {cov_loss_vicreg.item():.8f} vs " + f"CovarianceLoss: {cov_loss_class.item():.8f}" + ) + + +class TestFullLossEquivalence: + """Test full loss function equivalences.""" + + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + torch.manual_seed(42) + return torch.randn(64, 128) + + def test_vicreg_decomposition(self, sample_data): + """ + Test that VICRegLoss can be decomposed into HingeStdLoss + CovarianceLoss + when applied to both views and summed. + """ + torch.manual_seed(42) + z1 = torch.randn(64, 128) + z2 = torch.randn(64, 128) + + std_coeff = 25.0 + cov_coeff = 1.0 + + # Using VICRegLoss + vicreg = VICRegLoss(std_coeff=std_coeff, cov_coeff=cov_coeff) + result = vicreg(z1, z2) + + # Manual decomposition using primitives + std_loss_fn = HingeStdLoss(std_margin=1.0) + cov_loss_fn = CovarianceLoss() + + sim_loss = F.mse_loss(z1, z2) + std_loss = std_loss_fn(z1) + std_loss_fn(z2) + cov_loss = cov_loss_fn(z1) + cov_loss_fn(z2) + + expected_total = sim_loss + std_coeff * std_loss + cov_coeff * cov_loss + + # Compare + assert torch.allclose( + result["invariance_loss"], sim_loss, atol=1e-6 + ), f"Sim loss: {result['invariance_loss'].item():.8f} vs {sim_loss.item():.8f}" + assert torch.allclose( + result["var_loss"], std_loss, atol=1e-6 + ), f"Var loss: {result['var_loss'].item():.8f} vs {std_loss.item():.8f}" + assert torch.allclose( + result["cov_loss"], cov_loss, atol=1e-6 + ), f"Cov loss: {result['cov_loss'].item():.8f} vs {cov_loss.item():.8f}" + assert torch.allclose( + result["loss"], expected_total, atol=1e-6 + ), f"Total loss: {result['loss'].item():.8f} vs {expected_total.item():.8f}" + + +class TestVCLoss: + """Test VCLoss functionality.""" + + def test_vc_loss_output_structure(self): + """Test that VCLoss produces correct output structure.""" + torch.manual_seed(42) + x_5d = torch.randn(8, 16, 2, 4, 4) # B=8, F=16, T=2, H=4, W=4 + + std_coeff = 10.0 + cov_coeff = 5.0 + + vc_loss = VCLoss(std_coeff=std_coeff, cov_coeff=cov_coeff) + loss, unweighted, loss_dict = vc_loss(x_5d) + + # Verify outputs are valid + assert not torch.isnan(loss), "Loss should not be NaN" + assert not torch.isnan(unweighted), "Unweighted loss should not be NaN" + assert loss.item() >= 0, "Loss should be non-negative" + assert "std_loss" in loss_dict + assert "cov_loss" in loss_dict + + def test_vc_loss_with_projector(self): + """Test that VCLoss works correctly with a projector.""" + torch.manual_seed(42) + x_5d = torch.randn(8, 16, 2, 4, 4) # B=8, F=16, T=2, H=4, W=4 + + # Create a simple projector + projector = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 32), + ) + + std_coeff = 10.0 + cov_coeff = 5.0 + + # Using VCLoss with projector + vc_loss = VCLoss(std_coeff=std_coeff, cov_coeff=cov_coeff, proj=projector) + loss, unweighted, loss_dict = vc_loss(x_5d) + + # Verify outputs are valid + assert not torch.isnan(loss), "Loss should not be NaN" + assert not torch.isnan(unweighted), "Unweighted loss should not be NaN" + assert loss.item() >= 0, "Loss should be non-negative" + assert "std_loss" in loss_dict + assert "cov_loss" in loss_dict + + def test_vc_loss_coefficient_weighting(self): + """Test that VCLoss correctly applies coefficient weighting.""" + torch.manual_seed(42) + x_5d = torch.randn(8, 16, 2, 4, 4) + + std_coeff = 10.0 + cov_coeff = 5.0 + + vc_loss = VCLoss(std_coeff=std_coeff, cov_coeff=cov_coeff) + loss, unweighted, loss_dict = vc_loss(x_5d) + + # Verify that weighted loss equals coefficient-weighted components + expected_loss = ( + std_coeff * loss_dict["std_loss"] + cov_coeff * loss_dict["cov_loss"] + ) + assert torch.allclose(loss, torch.tensor(expected_loss), atol=1e-5), ( + f"Weighted loss {loss.item():.6f} should equal " + f"{std_coeff}*{loss_dict['std_loss']:.6f} + {cov_coeff}*{loss_dict['cov_loss']:.6f} = {expected_loss:.6f}" + ) + + def test_vc_loss_consistency_across_seeds(self): + """Test that VCLoss is deterministic given same seed.""" + std_coeff = 10.0 + cov_coeff = 5.0 + + for seed in [1, 42, 100, 1000]: + # First run + torch.manual_seed(seed) + x_5d = torch.randn(8, 16, 2, 4, 4) + vc_loss = VCLoss(std_coeff=std_coeff, cov_coeff=cov_coeff) + loss1, _, _ = vc_loss(x_5d) + + # Second run with same seed + torch.manual_seed(seed) + x_5d = torch.randn(8, 16, 2, 4, 4) + vc_loss = VCLoss(std_coeff=std_coeff, cov_coeff=cov_coeff) + loss2, _, _ = vc_loss(x_5d) + + assert torch.allclose( + loss1, loss2, atol=1e-7 + ), f"Seed {seed}: loss should be deterministic, got {loss1.item():.6f} vs {loss2.item():.6f}" + + +class TestVICRegLossRegression: + """ + Regression tests to ensure VICRegLoss using HingeStdLoss + CovarianceLoss + produces identical results to the original inline implementation. + """ + + def test_refactored_vicreg_vs_original_implementation(self): + """ + Regression test: Refactored VICRegLoss using HingeStdLoss + CovarianceLoss + produces identical results to the original inline implementation. + """ + torch.manual_seed(42) + z1 = torch.randn(64, 128) + z2 = torch.randn(64, 128) + + std_coeff = 25.0 + cov_coeff = 1.0 + + # Using refactored VICRegLoss + vicreg = VICRegLoss(std_coeff=std_coeff, cov_coeff=cov_coeff) + result = vicreg(z1, z2) + + # Original VICRegLoss implementation (inlined for comparison) + batch_size = z1.size(0) + + # Original invariance loss + sim_loss_orig = F.mse_loss(z1, z2) + + # Original variance loss + z1_std = torch.sqrt(z1.var(dim=0) + 1e-4) + z2_std = torch.sqrt(z2.var(dim=0) + 1e-4) + var_loss_orig = torch.mean(F.relu(1 - z1_std)) + torch.mean(F.relu(1 - z2_std)) + + # Original covariance loss + z1_centered = z1 - z1.mean(dim=0) + z2_centered = z2 - z2.mean(dim=0) + z1_cov = torch.mm(z1_centered.T, z1_centered) / (batch_size - 1) + z2_cov = torch.mm(z2_centered.T, z2_centered) / (batch_size - 1) + cov_loss_orig = (z1_cov.pow(2).sum() - z1_cov.diagonal().pow(2).sum()) / ( + z1_cov.size(0) ** 2 - z1_cov.size(0) + ) + (z2_cov.pow(2).sum() - z2_cov.diagonal().pow(2).sum()) / ( + z2_cov.size(0) ** 2 - z2_cov.size(0) + ) + + total_loss_orig = ( + sim_loss_orig + std_coeff * var_loss_orig + cov_coeff * cov_loss_orig + ) + + # Verify all components match + assert torch.allclose(result["invariance_loss"], sim_loss_orig, atol=1e-6), ( + f"Refactored sim: {result['invariance_loss'].item():.8f} vs " + f"Original sim: {sim_loss_orig.item():.8f}" + ) + assert torch.allclose(result["var_loss"], var_loss_orig, atol=1e-6), ( + f"Refactored var: {result['var_loss'].item():.8f} vs " + f"Original var: {var_loss_orig.item():.8f}" + ) + assert torch.allclose(result["cov_loss"], cov_loss_orig, atol=1e-6), ( + f"Refactored cov: {result['cov_loss'].item():.8f} vs " + f"Original cov: {cov_loss_orig.item():.8f}" + ) + assert torch.allclose(result["loss"], total_loss_orig, atol=1e-6), ( + f"Refactored total: {result['loss'].item():.8f} vs " + f"Original total: {total_loss_orig.item():.8f}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/uv.lock b/uv.lock deleted file mode 100644 index c8690a3..0000000 --- a/uv.lock +++ /dev/null @@ -1,1358 +0,0 @@ -version = 1 -revision = 3 -requires-python = "==3.12.*" -resolution-markers = [ - "sys_platform == 'linux'", - "sys_platform != 'emscripten' and sys_platform != 'linux'", - "sys_platform == 'emscripten'", -] - -[[package]] -name = "annotated-types" -version = "0.7.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, -] - -[[package]] -name = "antlr4-python3-runtime" -version = "4.9.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034, upload-time = "2021-11-06T17:52:23.524Z" } - -[[package]] -name = "black" -version = "25.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "mypy-extensions" }, - { name = "packaging" }, - { name = "pathspec" }, - { name = "platformdirs" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449, upload-time = "2025-01-29T04:15:40.373Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/83/71/3fe4741df7adf015ad8dfa082dd36c94ca86bb21f25608eb247b4afb15b2/black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b", size = 1650988, upload-time = "2025-01-29T05:37:16.707Z" }, - { url = "https://files.pythonhosted.org/packages/13/f3/89aac8a83d73937ccd39bbe8fc6ac8860c11cfa0af5b1c96d081facac844/black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc", size = 1453985, upload-time = "2025-01-29T05:37:18.273Z" }, - { url = "https://files.pythonhosted.org/packages/6f/22/b99efca33f1f3a1d2552c714b1e1b5ae92efac6c43e790ad539a163d1754/black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f", size = 1783816, upload-time = "2025-01-29T04:18:33.823Z" }, - { url = "https://files.pythonhosted.org/packages/18/7e/a27c3ad3822b6f2e0e00d63d58ff6299a99a5b3aee69fa77cd4b0076b261/black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba", size = 1440860, upload-time = "2025-01-29T04:19:12.944Z" }, - { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, -] - -[[package]] -name = "certifi" -version = "2025.8.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/dc/67/960ebe6bf230a96cda2e0abcf73af550ec4f090005363542f0765df162e0/certifi-2025.8.3.tar.gz", hash = "sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407", size = 162386, upload-time = "2025-08-03T03:07:47.08Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/48/1549795ba7742c948d2ad169c1c8cdbae65bc450d6cd753d124b17c8cd32/certifi-2025.8.3-py3-none-any.whl", hash = "sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5", size = 161216, upload-time = "2025-08-03T03:07:45.777Z" }, -] - -[[package]] -name = "cffi" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pycparser", marker = "implementation_name != 'PyPy'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ea/47/4f61023ea636104d4f16ab488e268b93008c3d0bb76893b1b31db1f96802/cffi-2.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d", size = 185271, upload-time = "2025-09-08T23:22:44.795Z" }, - { url = "https://files.pythonhosted.org/packages/df/a2/781b623f57358e360d62cdd7a8c681f074a71d445418a776eef0aadb4ab4/cffi-2.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c", size = 181048, upload-time = "2025-09-08T23:22:45.938Z" }, - { url = "https://files.pythonhosted.org/packages/ff/df/a4f0fbd47331ceeba3d37c2e51e9dfc9722498becbeec2bd8bc856c9538a/cffi-2.0.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe", size = 212529, upload-time = "2025-09-08T23:22:47.349Z" }, - { url = "https://files.pythonhosted.org/packages/d5/72/12b5f8d3865bf0f87cf1404d8c374e7487dcf097a1c91c436e72e6badd83/cffi-2.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062", size = 220097, upload-time = "2025-09-08T23:22:48.677Z" }, - { url = "https://files.pythonhosted.org/packages/c2/95/7a135d52a50dfa7c882ab0ac17e8dc11cec9d55d2c18dda414c051c5e69e/cffi-2.0.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e", size = 207983, upload-time = "2025-09-08T23:22:50.06Z" }, - { url = "https://files.pythonhosted.org/packages/3a/c8/15cb9ada8895957ea171c62dc78ff3e99159ee7adb13c0123c001a2546c1/cffi-2.0.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037", size = 206519, upload-time = "2025-09-08T23:22:51.364Z" }, - { url = "https://files.pythonhosted.org/packages/78/2d/7fa73dfa841b5ac06c7b8855cfc18622132e365f5b81d02230333ff26e9e/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba", size = 219572, upload-time = "2025-09-08T23:22:52.902Z" }, - { url = "https://files.pythonhosted.org/packages/07/e0/267e57e387b4ca276b90f0434ff88b2c2241ad72b16d31836adddfd6031b/cffi-2.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94", size = 222963, upload-time = "2025-09-08T23:22:54.518Z" }, - { url = "https://files.pythonhosted.org/packages/b6/75/1f2747525e06f53efbd878f4d03bac5b859cbc11c633d0fb81432d98a795/cffi-2.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187", size = 221361, upload-time = "2025-09-08T23:22:55.867Z" }, - { url = "https://files.pythonhosted.org/packages/7b/2b/2b6435f76bfeb6bbf055596976da087377ede68df465419d192acf00c437/cffi-2.0.0-cp312-cp312-win32.whl", hash = "sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18", size = 172932, upload-time = "2025-09-08T23:22:57.188Z" }, - { url = "https://files.pythonhosted.org/packages/f8/ed/13bd4418627013bec4ed6e54283b1959cf6db888048c7cf4b4c3b5b36002/cffi-2.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5", size = 183557, upload-time = "2025-09-08T23:22:58.351Z" }, - { url = "https://files.pythonhosted.org/packages/95/31/9f7f93ad2f8eff1dbc1c3656d7ca5bfd8fb52c9d786b4dcf19b2d02217fa/cffi-2.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6", size = 177762, upload-time = "2025-09-08T23:22:59.668Z" }, -] - -[[package]] -name = "charset-normalizer" -version = "3.4.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/83/2d/5fd176ceb9b2fc619e63405525573493ca23441330fcdaee6bef9460e924/charset_normalizer-3.4.3.tar.gz", hash = "sha256:6fce4b8500244f6fcb71465d4a4930d132ba9ab8e71a7859e6a5d59851068d14", size = 122371, upload-time = "2025-08-09T07:57:28.46Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/5e/14c94999e418d9b87682734589404a25854d5f5d0408df68bc15b6ff54bb/charset_normalizer-3.4.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e28e334d3ff134e88989d90ba04b47d84382a828c061d0d1027b1b12a62b39b1", size = 205655, upload-time = "2025-08-09T07:56:08.475Z" }, - { url = "https://files.pythonhosted.org/packages/7d/a8/c6ec5d389672521f644505a257f50544c074cf5fc292d5390331cd6fc9c3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0cacf8f7297b0c4fcb74227692ca46b4a5852f8f4f24b3c766dd94a1075c4884", size = 146223, upload-time = "2025-08-09T07:56:09.708Z" }, - { url = "https://files.pythonhosted.org/packages/fc/eb/a2ffb08547f4e1e5415fb69eb7db25932c52a52bed371429648db4d84fb1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c6fd51128a41297f5409deab284fecbe5305ebd7e5a1f959bee1c054622b7018", size = 159366, upload-time = "2025-08-09T07:56:11.326Z" }, - { url = "https://files.pythonhosted.org/packages/82/10/0fd19f20c624b278dddaf83b8464dcddc2456cb4b02bb902a6da126b87a1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3cfb2aad70f2c6debfbcb717f23b7eb55febc0bb23dcffc0f076009da10c6392", size = 157104, upload-time = "2025-08-09T07:56:13.014Z" }, - { url = "https://files.pythonhosted.org/packages/16/ab/0233c3231af734f5dfcf0844aa9582d5a1466c985bbed6cedab85af9bfe3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1606f4a55c0fd363d754049cdf400175ee96c992b1f8018b993941f221221c5f", size = 151830, upload-time = "2025-08-09T07:56:14.428Z" }, - { url = "https://files.pythonhosted.org/packages/ae/02/e29e22b4e02839a0e4a06557b1999d0a47db3567e82989b5bb21f3fbbd9f/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:027b776c26d38b7f15b26a5da1044f376455fb3766df8fc38563b4efbc515154", size = 148854, upload-time = "2025-08-09T07:56:16.051Z" }, - { url = "https://files.pythonhosted.org/packages/05/6b/e2539a0a4be302b481e8cafb5af8792da8093b486885a1ae4d15d452bcec/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:42e5088973e56e31e4fa58eb6bd709e42fc03799c11c42929592889a2e54c491", size = 160670, upload-time = "2025-08-09T07:56:17.314Z" }, - { url = "https://files.pythonhosted.org/packages/31/e7/883ee5676a2ef217a40ce0bffcc3d0dfbf9e64cbcfbdf822c52981c3304b/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cc34f233c9e71701040d772aa7490318673aa7164a0efe3172b2981218c26d93", size = 158501, upload-time = "2025-08-09T07:56:18.641Z" }, - { url = "https://files.pythonhosted.org/packages/c1/35/6525b21aa0db614cf8b5792d232021dca3df7f90a1944db934efa5d20bb1/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:320e8e66157cc4e247d9ddca8e21f427efc7a04bbd0ac8a9faf56583fa543f9f", size = 153173, upload-time = "2025-08-09T07:56:20.289Z" }, - { url = "https://files.pythonhosted.org/packages/50/ee/f4704bad8201de513fdc8aac1cabc87e38c5818c93857140e06e772b5892/charset_normalizer-3.4.3-cp312-cp312-win32.whl", hash = "sha256:fb6fecfd65564f208cbf0fba07f107fb661bcd1a7c389edbced3f7a493f70e37", size = 99822, upload-time = "2025-08-09T07:56:21.551Z" }, - { url = "https://files.pythonhosted.org/packages/39/f5/3b3836ca6064d0992c58c7561c6b6eee1b3892e9665d650c803bd5614522/charset_normalizer-3.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:86df271bf921c2ee3818f0522e9a5b8092ca2ad8b065ece5d7d9d0e9f4849bcc", size = 107543, upload-time = "2025-08-09T07:56:23.115Z" }, - { url = "https://files.pythonhosted.org/packages/8a/1f/f041989e93b001bc4e44bb1669ccdcf54d3f00e628229a85b08d330615c5/charset_normalizer-3.4.3-py3-none-any.whl", hash = "sha256:ce571ab16d890d23b5c278547ba694193a45011ff86a9162a71307ed9f86759a", size = 53175, upload-time = "2025-08-09T07:57:26.864Z" }, -] - -[[package]] -name = "click" -version = "8.2.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, -] - -[[package]] -name = "cloudpickle" -version = "3.1.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/39/069100b84d7418bc358d81669d5748efb14b9cceacd2f9c75f550424132f/cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64", size = 22113, upload-time = "2025-01-14T17:02:05.085Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/e8/64c37fadfc2816a7701fa8a6ed8d87327c7d54eacfbfb6edab14a2f2be75/cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e", size = 20992, upload-time = "2025-01-14T17:02:02.417Z" }, -] - -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, -] - -[[package]] -name = "contourpy" -version = "1.3.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/58/01/1253e6698a07380cd31a736d248a3f2a50a7c88779a1813da27503cadc2a/contourpy-1.3.3.tar.gz", hash = "sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880", size = 13466174, upload-time = "2025-07-26T12:03:12.549Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/be/45/adfee365d9ea3d853550b2e735f9d66366701c65db7855cd07621732ccfc/contourpy-1.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b08a32ea2f8e42cf1d4be3169a98dd4be32bafe4f22b6c4cb4ba810fa9e5d2cb", size = 293419, upload-time = "2025-07-26T12:01:21.16Z" }, - { url = "https://files.pythonhosted.org/packages/53/3e/405b59cfa13021a56bba395a6b3aca8cec012b45bf177b0eaf7a202cde2c/contourpy-1.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:556dba8fb6f5d8742f2923fe9457dbdd51e1049c4a43fd3986a0b14a1d815fc6", size = 273979, upload-time = "2025-07-26T12:01:22.448Z" }, - { url = "https://files.pythonhosted.org/packages/d4/1c/a12359b9b2ca3a845e8f7f9ac08bdf776114eb931392fcad91743e2ea17b/contourpy-1.3.3-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:92d9abc807cf7d0e047b95ca5d957cf4792fcd04e920ca70d48add15c1a90ea7", size = 332653, upload-time = "2025-07-26T12:01:24.155Z" }, - { url = "https://files.pythonhosted.org/packages/63/12/897aeebfb475b7748ea67b61e045accdfcf0d971f8a588b67108ed7f5512/contourpy-1.3.3-cp312-cp312-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b2e8faa0ed68cb29af51edd8e24798bb661eac3bd9f65420c1887b6ca89987c8", size = 379536, upload-time = "2025-07-26T12:01:25.91Z" }, - { url = "https://files.pythonhosted.org/packages/43/8a/a8c584b82deb248930ce069e71576fc09bd7174bbd35183b7943fb1064fd/contourpy-1.3.3-cp312-cp312-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:626d60935cf668e70a5ce6ff184fd713e9683fb458898e4249b63be9e28286ea", size = 384397, upload-time = "2025-07-26T12:01:27.152Z" }, - { url = "https://files.pythonhosted.org/packages/cc/8f/ec6289987824b29529d0dfda0d74a07cec60e54b9c92f3c9da4c0ac732de/contourpy-1.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d00e655fcef08aba35ec9610536bfe90267d7ab5ba944f7032549c55a146da1", size = 362601, upload-time = "2025-07-26T12:01:28.808Z" }, - { url = "https://files.pythonhosted.org/packages/05/0a/a3fe3be3ee2dceb3e615ebb4df97ae6f3828aa915d3e10549ce016302bd1/contourpy-1.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:451e71b5a7d597379ef572de31eeb909a87246974d960049a9848c3bc6c41bf7", size = 1331288, upload-time = "2025-07-26T12:01:31.198Z" }, - { url = "https://files.pythonhosted.org/packages/33/1d/acad9bd4e97f13f3e2b18a3977fe1b4a37ecf3d38d815333980c6c72e963/contourpy-1.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:459c1f020cd59fcfe6650180678a9993932d80d44ccde1fa1868977438f0b411", size = 1403386, upload-time = "2025-07-26T12:01:33.947Z" }, - { url = "https://files.pythonhosted.org/packages/cf/8f/5847f44a7fddf859704217a99a23a4f6417b10e5ab1256a179264561540e/contourpy-1.3.3-cp312-cp312-win32.whl", hash = "sha256:023b44101dfe49d7d53932be418477dba359649246075c996866106da069af69", size = 185018, upload-time = "2025-07-26T12:01:35.64Z" }, - { url = "https://files.pythonhosted.org/packages/19/e8/6026ed58a64563186a9ee3f29f41261fd1828f527dd93d33b60feca63352/contourpy-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:8153b8bfc11e1e4d75bcb0bff1db232f9e10b274e0929de9d608027e0d34ff8b", size = 226567, upload-time = "2025-07-26T12:01:36.804Z" }, - { url = "https://files.pythonhosted.org/packages/d1/e2/f05240d2c39a1ed228d8328a78b6f44cd695f7ef47beb3e684cf93604f86/contourpy-1.3.3-cp312-cp312-win_arm64.whl", hash = "sha256:07ce5ed73ecdc4a03ffe3e1b3e3c1166db35ae7584be76f65dbbe28a7791b0cc", size = 193655, upload-time = "2025-07-26T12:01:37.999Z" }, -] - -[[package]] -name = "cycler" -version = "0.12.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a9/95/a3dbbb5028f35eafb79008e7522a75244477d2838f38cbb722248dabc2a8/cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c", size = 7615, upload-time = "2023-10-07T05:32:18.335Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321, upload-time = "2023-10-07T05:32:16.783Z" }, -] - -[[package]] -name = "decord" -version = "0.6.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, - { url = "https://files.pythonhosted.org/packages/6c/be/e15b5b866da452e62635a7b27513f31cb581fa2ea9cc9b768b535d62a955/decord-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02665d7c4f1193a330205a791bc128f7e108eb6ae5b67144437a02f700943bad", size = 24733380, upload-time = "2021-06-14T21:30:57.766Z" }, -] - -[[package]] -name = "eb-jepa" -version = "0.1.1" -source = { editable = "." } -dependencies = [ - { name = "decord" }, - { name = "einops" }, - { name = "fire" }, - { name = "gymnasium" }, - { name = "huggingface-hub" }, - { name = "imageio" }, - { name = "matplotlib" }, - { name = "omegaconf" }, - { name = "opencv-python" }, - { name = "pudb" }, - { name = "pymunk" }, - { name = "ruamel-yaml" }, - { name = "scikit-learn" }, - { name = "seaborn" }, - { name = "submitit" }, - { name = "tiktoken" }, - { name = "torch" }, - { name = "torchcodec" }, - { name = "torchvision" }, - { name = "tqdm" }, - { name = "wandb" }, -] - -[package.dev-dependencies] -dev = [ - { name = "black" }, - { name = "isort" }, - { name = "pytest" }, -] - -[package.metadata] -requires-dist = [ - { name = "decord" }, - { name = "einops", specifier = ">=0.8.1" }, - { name = "fire", specifier = ">=0.7.0" }, - { name = "gymnasium", specifier = ">=1.1.1" }, - { name = "huggingface-hub", specifier = ">=0.33.2" }, - { name = "imageio" }, - { name = "matplotlib", specifier = ">=3.10.3" }, - { name = "omegaconf" }, - { name = "opencv-python", specifier = ">=4.12.0.88" }, - { name = "pudb", specifier = ">=2025.1" }, - { name = "pymunk" }, - { name = "ruamel-yaml" }, - { name = "scikit-learn", specifier = ">=1.7.0" }, - { name = "seaborn" }, - { name = "submitit" }, - { name = "tiktoken", specifier = ">=0.9.0" }, - { name = "torch", specifier = "==2.6.0" }, - { name = "torchcodec", specifier = ">=0.4.0" }, - { name = "torchvision" }, - { name = "tqdm" }, - { name = "wandb", specifier = ">=0.21.1" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "black", specifier = ">=25.1.0" }, - { name = "isort", specifier = ">=6.0.1" }, - { name = "pytest", specifier = ">=8.4.1" }, -] - -[[package]] -name = "einops" -version = "0.8.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e5/81/df4fbe24dff8ba3934af99044188e20a98ed441ad17a274539b74e82e126/einops-0.8.1.tar.gz", hash = "sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84", size = 54805, upload-time = "2025-02-09T03:17:00.434Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/62/9773de14fe6c45c23649e98b83231fffd7b9892b6cf863251dc2afa73643/einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737", size = 64359, upload-time = "2025-02-09T03:17:01.998Z" }, -] - -[[package]] -name = "farama-notifications" -version = "0.0.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2e/2c/8384832b7a6b1fd6ba95bbdcae26e7137bb3eedc955c42fd5cdcc086cfbf/Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18", size = 2131, upload-time = "2023-02-27T18:28:41.047Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/05/2c/ffc08c54c05cdce6fbed2aeebc46348dbe180c6d2c541c7af7ba0aa5f5f8/Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae", size = 2511, upload-time = "2023-02-27T18:28:39.447Z" }, -] - -[[package]] -name = "filelock" -version = "3.19.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/40/bb/0ab3e58d22305b6f5440629d20683af28959bf793d98d11950e305c1c326/filelock-3.19.1.tar.gz", hash = "sha256:66eda1888b0171c998b35be2bcc0f6d75c388a7ce20c3f3f37aa8e96c2dddf58", size = 17687, upload-time = "2025-08-14T16:56:03.016Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/42/14/42b2651a2f46b022ccd948bca9f2d5af0fd8929c4eec235b8d6d844fbe67/filelock-3.19.1-py3-none-any.whl", hash = "sha256:d38e30481def20772f5baf097c122c3babc4fcdb7e14e57049eb9d88c6dc017d", size = 15988, upload-time = "2025-08-14T16:56:01.633Z" }, -] - -[[package]] -name = "fire" -version = "0.7.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "termcolor" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c0/00/f8d10588d2019d6d6452653def1ee807353b21983db48550318424b5ff18/fire-0.7.1.tar.gz", hash = "sha256:3b208f05c736de98fb343310d090dcc4d8c78b2a89ea4f32b837c586270a9cbf", size = 88720, upload-time = "2025-08-16T20:20:24.175Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/4c/93d0f85318da65923e4b91c1c2ff03d8a458cbefebe3bc612a6693c7906d/fire-0.7.1-py3-none-any.whl", hash = "sha256:e43fd8a5033a9001e7e2973bab96070694b9f12f2e0ecf96d4683971b5ab1882", size = 115945, upload-time = "2025-08-16T20:20:22.87Z" }, -] - -[[package]] -name = "fonttools" -version = "4.59.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0d/a5/fba25f9fbdab96e26dedcaeeba125e5f05a09043bf888e0305326e55685b/fonttools-4.59.2.tar.gz", hash = "sha256:e72c0749b06113f50bcb80332364c6be83a9582d6e3db3fe0b280f996dc2ef22", size = 3540889, upload-time = "2025-08-27T16:40:30.97Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ba/3d/1f45db2df51e7bfa55492e8f23f383d372200be3a0ded4bf56a92753dd1f/fonttools-4.59.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:82906d002c349cad647a7634b004825a7335f8159d0d035ae89253b4abf6f3ea", size = 2769711, upload-time = "2025-08-27T16:39:04.423Z" }, - { url = "https://files.pythonhosted.org/packages/29/df/cd236ab32a8abfd11558f296e064424258db5edefd1279ffdbcfd4fd8b76/fonttools-4.59.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a10c1bd7644dc58f8862d8ba0cf9fb7fef0af01ea184ba6ce3f50ab7dfe74d5a", size = 2340225, upload-time = "2025-08-27T16:39:06.143Z" }, - { url = "https://files.pythonhosted.org/packages/98/12/b6f9f964fe6d4b4dd4406bcbd3328821c3de1f909ffc3ffa558fe72af48c/fonttools-4.59.2-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:738f31f23e0339785fd67652a94bc69ea49e413dfdb14dcb8c8ff383d249464e", size = 4912766, upload-time = "2025-08-27T16:39:08.138Z" }, - { url = "https://files.pythonhosted.org/packages/73/78/82bde2f2d2c306ef3909b927363170b83df96171f74e0ccb47ad344563cd/fonttools-4.59.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ec99f9bdfee9cdb4a9172f9e8fd578cce5feb231f598909e0aecf5418da4f25", size = 4955178, upload-time = "2025-08-27T16:39:10.094Z" }, - { url = "https://files.pythonhosted.org/packages/92/77/7de766afe2d31dda8ee46d7e479f35c7d48747e558961489a2d6e3a02bd4/fonttools-4.59.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0476ea74161322e08c7a982f83558a2b81b491509984523a1a540baf8611cc31", size = 4897898, upload-time = "2025-08-27T16:39:12.087Z" }, - { url = "https://files.pythonhosted.org/packages/c5/77/ce0e0b905d62a06415fda9f2b2e109a24a5db54a59502b769e9e297d2242/fonttools-4.59.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:95922a922daa1f77cc72611747c156cfb38030ead72436a2c551d30ecef519b9", size = 5049144, upload-time = "2025-08-27T16:39:13.84Z" }, - { url = "https://files.pythonhosted.org/packages/d9/ea/870d93aefd23fff2e07cbeebdc332527868422a433c64062c09d4d5e7fe6/fonttools-4.59.2-cp312-cp312-win32.whl", hash = "sha256:39ad9612c6a622726a6a130e8ab15794558591f999673f1ee7d2f3d30f6a3e1c", size = 2206473, upload-time = "2025-08-27T16:39:15.854Z" }, - { url = "https://files.pythonhosted.org/packages/61/c4/e44bad000c4a4bb2e9ca11491d266e857df98ab6d7428441b173f0fe2517/fonttools-4.59.2-cp312-cp312-win_amd64.whl", hash = "sha256:980fd7388e461b19a881d35013fec32c713ffea1fc37aef2f77d11f332dfd7da", size = 2254706, upload-time = "2025-08-27T16:39:17.893Z" }, - { url = "https://files.pythonhosted.org/packages/65/a4/d2f7be3c86708912c02571db0b550121caab8cd88a3c0aacb9cfa15ea66e/fonttools-4.59.2-py3-none-any.whl", hash = "sha256:8bd0f759020e87bb5d323e6283914d9bf4ae35a7307dafb2cbd1e379e720ad37", size = 1132315, upload-time = "2025-08-27T16:40:28.984Z" }, -] - -[[package]] -name = "fsspec" -version = "2025.7.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8b/02/0835e6ab9cfc03916fe3f78c0956cfcdb6ff2669ffa6651065d5ebf7fc98/fsspec-2025.7.0.tar.gz", hash = "sha256:786120687ffa54b8283d942929540d8bc5ccfa820deb555a2b5d0ed2b737bf58", size = 304432, upload-time = "2025-07-15T16:05:21.19Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/e0/014d5d9d7a4564cf1c40b5039bc882db69fd881111e03ab3657ac0b218e2/fsspec-2025.7.0-py3-none-any.whl", hash = "sha256:8b012e39f63c7d5f10474de957f3ab793b47b45ae7d39f2fb735f8bbe25c0e21", size = 199597, upload-time = "2025-07-15T16:05:19.529Z" }, -] - -[[package]] -name = "gitdb" -version = "4.0.12" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "smmap" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, -] - -[[package]] -name = "gitpython" -version = "3.1.45" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "gitdb" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9a/c8/dd58967d119baab745caec2f9d853297cec1989ec1d63f677d3880632b88/gitpython-3.1.45.tar.gz", hash = "sha256:85b0ee964ceddf211c41b9f27a49086010a190fd8132a24e21f362a4b36a791c", size = 215076, upload-time = "2025-07-24T03:45:54.871Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77", size = 208168, upload-time = "2025-07-24T03:45:52.517Z" }, -] - -[[package]] -name = "gymnasium" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cloudpickle" }, - { name = "farama-notifications" }, - { name = "numpy" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fd/17/c2a0e15c2cd5a8e788389b280996db927b923410de676ec5c7b2695e9261/gymnasium-1.2.0.tar.gz", hash = "sha256:344e87561012558f603880baf264ebc97f8a5c997a957b0c9f910281145534b0", size = 821142, upload-time = "2025-06-27T08:21:20.262Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/e2/a111dbb8625af467ea4760a1373d6ef27aac3137931219902406ccc05423/gymnasium-1.2.0-py3-none-any.whl", hash = "sha256:fc4a1e4121a9464c29b4d7dc6ade3fbeaa36dea448682f5f71a6d2c17489ea76", size = 944301, upload-time = "2025-06-27T08:21:18.83Z" }, -] - -[[package]] -name = "hf-xet" -version = "1.1.9" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/23/0f/5b60fc28ee7f8cc17a5114a584fd6b86e11c3e0a6e142a7f97a161e9640a/hf_xet-1.1.9.tar.gz", hash = "sha256:c99073ce404462e909f1d5839b2d14a3827b8fe75ed8aed551ba6609c026c803", size = 484242, upload-time = "2025-08-27T23:05:19.441Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/de/12/56e1abb9a44cdef59a411fe8a8673313195711b5ecce27880eb9c8fa90bd/hf_xet-1.1.9-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:a3b6215f88638dd7a6ff82cb4e738dcbf3d863bf667997c093a3c990337d1160", size = 2762553, upload-time = "2025-08-27T23:05:15.153Z" }, - { url = "https://files.pythonhosted.org/packages/3a/e6/2d0d16890c5f21b862f5df3146519c182e7f0ae49b4b4bf2bd8a40d0b05e/hf_xet-1.1.9-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:9b486de7a64a66f9a172f4b3e0dfe79c9f0a93257c501296a2521a13495a698a", size = 2623216, upload-time = "2025-08-27T23:05:13.778Z" }, - { url = "https://files.pythonhosted.org/packages/81/42/7e6955cf0621e87491a1fb8cad755d5c2517803cea174229b0ec00ff0166/hf_xet-1.1.9-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4c5a840c2c4e6ec875ed13703a60e3523bc7f48031dfd750923b2a4d1a5fc3c", size = 3186789, upload-time = "2025-08-27T23:05:12.368Z" }, - { url = "https://files.pythonhosted.org/packages/df/8b/759233bce05457f5f7ec062d63bbfd2d0c740b816279eaaa54be92aa452a/hf_xet-1.1.9-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:96a6139c9e44dad1c52c52520db0fffe948f6bce487cfb9d69c125f254bb3790", size = 3088747, upload-time = "2025-08-27T23:05:10.439Z" }, - { url = "https://files.pythonhosted.org/packages/6c/3c/28cc4db153a7601a996985bcb564f7b8f5b9e1a706c7537aad4b4809f358/hf_xet-1.1.9-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ad1022e9a998e784c97b2173965d07fe33ee26e4594770b7785a8cc8f922cd95", size = 3251429, upload-time = "2025-08-27T23:05:16.471Z" }, - { url = "https://files.pythonhosted.org/packages/84/17/7caf27a1d101bfcb05be85850d4aa0a265b2e1acc2d4d52a48026ef1d299/hf_xet-1.1.9-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:86754c2d6d5afb11b0a435e6e18911a4199262fe77553f8c50d75e21242193ea", size = 3354643, upload-time = "2025-08-27T23:05:17.828Z" }, - { url = "https://files.pythonhosted.org/packages/cd/50/0c39c9eed3411deadcc98749a6699d871b822473f55fe472fad7c01ec588/hf_xet-1.1.9-cp37-abi3-win_amd64.whl", hash = "sha256:5aad3933de6b725d61d51034e04174ed1dce7a57c63d530df0014dea15a40127", size = 2804797, upload-time = "2025-08-27T23:05:20.77Z" }, -] - -[[package]] -name = "huggingface-hub" -version = "0.34.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "filelock" }, - { name = "fsspec" }, - { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, - { name = "packaging" }, - { name = "pyyaml" }, - { name = "requests" }, - { name = "tqdm" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/45/c9/bdbe19339f76d12985bc03572f330a01a93c04dffecaaea3061bdd7fb892/huggingface_hub-0.34.4.tar.gz", hash = "sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c", size = 459768, upload-time = "2025-08-08T09:14:52.365Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/39/7b/bb06b061991107cd8783f300adff3e7b7f284e330fd82f507f2a1417b11d/huggingface_hub-0.34.4-py3-none-any.whl", hash = "sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a", size = 561452, upload-time = "2025-08-08T09:14:50.159Z" }, -] - -[[package]] -name = "idna" -version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, -] - -[[package]] -name = "imageio" -version = "2.37.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, - { name = "pillow" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0c/47/57e897fb7094afb2d26e8b2e4af9a45c7cf1a405acdeeca001fdf2c98501/imageio-2.37.0.tar.gz", hash = "sha256:71b57b3669666272c818497aebba2b4c5f20d5b37c81720e5e1a56d59c492996", size = 389963, upload-time = "2025-01-20T02:42:37.089Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/bd/b394387b598ed84d8d0fa90611a90bee0adc2021820ad5729f7ced74a8e2/imageio-2.37.0-py3-none-any.whl", hash = "sha256:11efa15b87bc7871b61590326b2d635439acc321cf7f8ce996f812543ce10eed", size = 315796, upload-time = "2025-01-20T02:42:34.931Z" }, -] - -[[package]] -name = "iniconfig" -version = "2.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, -] - -[[package]] -name = "isort" -version = "6.0.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b8/21/1e2a441f74a653a144224d7d21afe8f4169e6c7c20bb13aec3a2dc3815e0/isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450", size = 821955, upload-time = "2025-02-26T21:13:16.955Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/11/114d0a5f4dabbdcedc1125dee0888514c3c3b16d3e9facad87ed96fad97c/isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615", size = 94186, upload-time = "2025-02-26T21:13:14.911Z" }, -] - -[[package]] -name = "jedi" -version = "0.19.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "parso" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" }, -] - -[[package]] -name = "jinja2" -version = "3.1.6" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "markupsafe" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, -] - -[[package]] -name = "joblib" -version = "1.5.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077, upload-time = "2025-08-27T12:15:46.575Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, -] - -[[package]] -name = "kiwisolver" -version = "1.4.9" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5c/3c/85844f1b0feb11ee581ac23fe5fce65cd049a200c1446708cc1b7f922875/kiwisolver-1.4.9.tar.gz", hash = "sha256:c3b22c26c6fd6811b0ae8363b95ca8ce4ea3c202d3d0975b2914310ceb1bcc4d", size = 97564, upload-time = "2025-08-10T21:27:49.279Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/86/c9/13573a747838aeb1c76e3267620daa054f4152444d1f3d1a2324b78255b5/kiwisolver-1.4.9-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ac5a486ac389dddcc5bef4f365b6ae3ffff2c433324fb38dd35e3fab7c957999", size = 123686, upload-time = "2025-08-10T21:26:10.034Z" }, - { url = "https://files.pythonhosted.org/packages/51/ea/2ecf727927f103ffd1739271ca19c424d0e65ea473fbaeea1c014aea93f6/kiwisolver-1.4.9-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f2ba92255faa7309d06fe44c3a4a97efe1c8d640c2a79a5ef728b685762a6fd2", size = 66460, upload-time = "2025-08-10T21:26:11.083Z" }, - { url = "https://files.pythonhosted.org/packages/5b/5a/51f5464373ce2aeb5194508298a508b6f21d3867f499556263c64c621914/kiwisolver-1.4.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a2899935e724dd1074cb568ce7ac0dce28b2cd6ab539c8e001a8578eb106d14", size = 64952, upload-time = "2025-08-10T21:26:12.058Z" }, - { url = "https://files.pythonhosted.org/packages/70/90/6d240beb0f24b74371762873e9b7f499f1e02166a2d9c5801f4dbf8fa12e/kiwisolver-1.4.9-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f6008a4919fdbc0b0097089f67a1eb55d950ed7e90ce2cc3e640abadd2757a04", size = 1474756, upload-time = "2025-08-10T21:26:13.096Z" }, - { url = "https://files.pythonhosted.org/packages/12/42/f36816eaf465220f683fb711efdd1bbf7a7005a2473d0e4ed421389bd26c/kiwisolver-1.4.9-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:67bb8b474b4181770f926f7b7d2f8c0248cbcb78b660fdd41a47054b28d2a752", size = 1276404, upload-time = "2025-08-10T21:26:14.457Z" }, - { url = "https://files.pythonhosted.org/packages/2e/64/bc2de94800adc830c476dce44e9b40fd0809cddeef1fde9fcf0f73da301f/kiwisolver-1.4.9-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2327a4a30d3ee07d2fbe2e7933e8a37c591663b96ce42a00bc67461a87d7df77", size = 1294410, upload-time = "2025-08-10T21:26:15.73Z" }, - { url = "https://files.pythonhosted.org/packages/5f/42/2dc82330a70aa8e55b6d395b11018045e58d0bb00834502bf11509f79091/kiwisolver-1.4.9-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:7a08b491ec91b1d5053ac177afe5290adacf1f0f6307d771ccac5de30592d198", size = 1343631, upload-time = "2025-08-10T21:26:17.045Z" }, - { url = "https://files.pythonhosted.org/packages/22/fd/f4c67a6ed1aab149ec5a8a401c323cee7a1cbe364381bb6c9c0d564e0e20/kiwisolver-1.4.9-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d8fc5c867c22b828001b6a38d2eaeb88160bf5783c6cb4a5e440efc981ce286d", size = 2224963, upload-time = "2025-08-10T21:26:18.737Z" }, - { url = "https://files.pythonhosted.org/packages/45/aa/76720bd4cb3713314677d9ec94dcc21ced3f1baf4830adde5bb9b2430a5f/kiwisolver-1.4.9-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:3b3115b2581ea35bb6d1f24a4c90af37e5d9b49dcff267eeed14c3893c5b86ab", size = 2321295, upload-time = "2025-08-10T21:26:20.11Z" }, - { url = "https://files.pythonhosted.org/packages/80/19/d3ec0d9ab711242f56ae0dc2fc5d70e298bb4a1f9dfab44c027668c673a1/kiwisolver-1.4.9-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:858e4c22fb075920b96a291928cb7dea5644e94c0ee4fcd5af7e865655e4ccf2", size = 2487987, upload-time = "2025-08-10T21:26:21.49Z" }, - { url = "https://files.pythonhosted.org/packages/39/e9/61e4813b2c97e86b6fdbd4dd824bf72d28bcd8d4849b8084a357bc0dd64d/kiwisolver-1.4.9-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ed0fecd28cc62c54b262e3736f8bb2512d8dcfdc2bcf08be5f47f96bf405b145", size = 2291817, upload-time = "2025-08-10T21:26:22.812Z" }, - { url = "https://files.pythonhosted.org/packages/a0/41/85d82b0291db7504da3c2defe35c9a8a5c9803a730f297bd823d11d5fb77/kiwisolver-1.4.9-cp312-cp312-win_amd64.whl", hash = "sha256:f68208a520c3d86ea51acf688a3e3002615a7f0238002cccc17affecc86a8a54", size = 73895, upload-time = "2025-08-10T21:26:24.37Z" }, - { url = "https://files.pythonhosted.org/packages/e2/92/5f3068cf15ee5cb624a0c7596e67e2a0bb2adee33f71c379054a491d07da/kiwisolver-1.4.9-cp312-cp312-win_arm64.whl", hash = "sha256:2c1a4f57df73965f3f14df20b80ee29e6a7930a57d2d9e8491a25f676e197c60", size = 64992, upload-time = "2025-08-10T21:26:25.732Z" }, -] - -[[package]] -name = "markupsafe" -version = "3.0.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537, upload-time = "2024-10-18T15:21:54.129Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/22/09/d1f21434c97fc42f09d290cbb6350d44eb12f09cc62c9476effdb33a18aa/MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf", size = 14274, upload-time = "2024-10-18T15:21:13.777Z" }, - { url = "https://files.pythonhosted.org/packages/6b/b0/18f76bba336fa5aecf79d45dcd6c806c280ec44538b3c13671d49099fdd0/MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225", size = 12348, upload-time = "2024-10-18T15:21:14.822Z" }, - { url = "https://files.pythonhosted.org/packages/e0/25/dd5c0f6ac1311e9b40f4af06c78efde0f3b5cbf02502f8ef9501294c425b/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028", size = 24149, upload-time = "2024-10-18T15:21:15.642Z" }, - { url = "https://files.pythonhosted.org/packages/f3/f0/89e7aadfb3749d0f52234a0c8c7867877876e0a20b60e2188e9850794c17/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8", size = 23118, upload-time = "2024-10-18T15:21:17.133Z" }, - { url = "https://files.pythonhosted.org/packages/d5/da/f2eeb64c723f5e3777bc081da884b414671982008c47dcc1873d81f625b6/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c", size = 22993, upload-time = "2024-10-18T15:21:18.064Z" }, - { url = "https://files.pythonhosted.org/packages/da/0e/1f32af846df486dce7c227fe0f2398dc7e2e51d4a370508281f3c1c5cddc/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557", size = 24178, upload-time = "2024-10-18T15:21:18.859Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f6/bb3ca0532de8086cbff5f06d137064c8410d10779c4c127e0e47d17c0b71/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22", size = 23319, upload-time = "2024-10-18T15:21:19.671Z" }, - { url = "https://files.pythonhosted.org/packages/a2/82/8be4c96ffee03c5b4a034e60a31294daf481e12c7c43ab8e34a1453ee48b/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48", size = 23352, upload-time = "2024-10-18T15:21:20.971Z" }, - { url = "https://files.pythonhosted.org/packages/51/ae/97827349d3fcffee7e184bdf7f41cd6b88d9919c80f0263ba7acd1bbcb18/MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30", size = 15097, upload-time = "2024-10-18T15:21:22.646Z" }, - { url = "https://files.pythonhosted.org/packages/c1/80/a61f99dc3a936413c3ee4e1eecac96c0da5ed07ad56fd975f1a9da5bc630/MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87", size = 15601, upload-time = "2024-10-18T15:21:23.499Z" }, -] - -[[package]] -name = "matplotlib" -version = "3.10.6" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "contourpy" }, - { name = "cycler" }, - { name = "fonttools" }, - { name = "kiwisolver" }, - { name = "numpy" }, - { name = "packaging" }, - { name = "pillow" }, - { name = "pyparsing" }, - { name = "python-dateutil" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a0/59/c3e6453a9676ffba145309a73c462bb407f4400de7de3f2b41af70720a3c/matplotlib-3.10.6.tar.gz", hash = "sha256:ec01b645840dd1996df21ee37f208cd8ba57644779fa20464010638013d3203c", size = 34804264, upload-time = "2025-08-30T00:14:25.137Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ea/1a/7042f7430055d567cc3257ac409fcf608599ab27459457f13772c2d9778b/matplotlib-3.10.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:31ca662df6a80bd426f871105fdd69db7543e28e73a9f2afe80de7e531eb2347", size = 8272404, upload-time = "2025-08-30T00:12:59.112Z" }, - { url = "https://files.pythonhosted.org/packages/a9/5d/1d5f33f5b43f4f9e69e6a5fe1fb9090936ae7bc8e2ff6158e7a76542633b/matplotlib-3.10.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1678bb61d897bb4ac4757b5ecfb02bfb3fddf7f808000fb81e09c510712fda75", size = 8128262, upload-time = "2025-08-30T00:13:01.141Z" }, - { url = "https://files.pythonhosted.org/packages/67/c3/135fdbbbf84e0979712df58e5e22b4f257b3f5e52a3c4aacf1b8abec0d09/matplotlib-3.10.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:56cd2d20842f58c03d2d6e6c1f1cf5548ad6f66b91e1e48f814e4fb5abd1cb95", size = 8697008, upload-time = "2025-08-30T00:13:03.24Z" }, - { url = "https://files.pythonhosted.org/packages/9c/be/c443ea428fb2488a3ea7608714b1bd85a82738c45da21b447dc49e2f8e5d/matplotlib-3.10.6-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:662df55604a2f9a45435566d6e2660e41efe83cd94f4288dfbf1e6d1eae4b0bb", size = 9530166, upload-time = "2025-08-30T00:13:05.951Z" }, - { url = "https://files.pythonhosted.org/packages/a9/35/48441422b044d74034aea2a3e0d1a49023f12150ebc58f16600132b9bbaf/matplotlib-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:08f141d55148cd1fc870c3387d70ca4df16dee10e909b3b038782bd4bda6ea07", size = 9593105, upload-time = "2025-08-30T00:13:08.356Z" }, - { url = "https://files.pythonhosted.org/packages/45/c3/994ef20eb4154ab84cc08d033834555319e4af970165e6c8894050af0b3c/matplotlib-3.10.6-cp312-cp312-win_amd64.whl", hash = "sha256:590f5925c2d650b5c9d813c5b3b5fc53f2929c3f8ef463e4ecfa7e052044fb2b", size = 8122784, upload-time = "2025-08-30T00:13:10.367Z" }, - { url = "https://files.pythonhosted.org/packages/57/b8/5c85d9ae0e40f04e71bedb053aada5d6bab1f9b5399a0937afb5d6b02d98/matplotlib-3.10.6-cp312-cp312-win_arm64.whl", hash = "sha256:f44c8d264a71609c79a78d50349e724f5d5fc3684ead7c2a473665ee63d868aa", size = 7992823, upload-time = "2025-08-30T00:13:12.24Z" }, -] - -[[package]] -name = "mpmath" -version = "1.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, -] - -[[package]] -name = "mypy-extensions" -version = "1.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, -] - -[[package]] -name = "networkx" -version = "3.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" }, -] - -[[package]] -name = "numpy" -version = "2.2.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/76/21/7d2a95e4bba9dc13d043ee156a356c0a8f0c6309dff6b21b4d71a073b8a8/numpy-2.2.6.tar.gz", hash = "sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd", size = 20276440, upload-time = "2025-05-17T22:38:04.611Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/82/5d/c00588b6cf18e1da539b45d3598d3557084990dcc4331960c15ee776ee41/numpy-2.2.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:41c5a21f4a04fa86436124d388f6ed60a9343a6f767fced1a8a71c3fbca038ff", size = 20875348, upload-time = "2025-05-17T21:34:39.648Z" }, - { url = "https://files.pythonhosted.org/packages/66/ee/560deadcdde6c2f90200450d5938f63a34b37e27ebff162810f716f6a230/numpy-2.2.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de749064336d37e340f640b05f24e9e3dd678c57318c7289d222a8a2f543e90c", size = 14119362, upload-time = "2025-05-17T21:35:01.241Z" }, - { url = "https://files.pythonhosted.org/packages/3c/65/4baa99f1c53b30adf0acd9a5519078871ddde8d2339dc5a7fde80d9d87da/numpy-2.2.6-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:894b3a42502226a1cac872f840030665f33326fc3dac8e57c607905773cdcde3", size = 5084103, upload-time = "2025-05-17T21:35:10.622Z" }, - { url = "https://files.pythonhosted.org/packages/cc/89/e5a34c071a0570cc40c9a54eb472d113eea6d002e9ae12bb3a8407fb912e/numpy-2.2.6-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:71594f7c51a18e728451bb50cc60a3ce4e6538822731b2933209a1f3614e9282", size = 6625382, upload-time = "2025-05-17T21:35:21.414Z" }, - { url = "https://files.pythonhosted.org/packages/f8/35/8c80729f1ff76b3921d5c9487c7ac3de9b2a103b1cd05e905b3090513510/numpy-2.2.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2618db89be1b4e05f7a1a847a9c1c0abd63e63a1607d892dd54668dd92faf87", size = 14018462, upload-time = "2025-05-17T21:35:42.174Z" }, - { url = "https://files.pythonhosted.org/packages/8c/3d/1e1db36cfd41f895d266b103df00ca5b3cbe965184df824dec5c08c6b803/numpy-2.2.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd83c01228a688733f1ded5201c678f0c53ecc1006ffbc404db9f7a899ac6249", size = 16527618, upload-time = "2025-05-17T21:36:06.711Z" }, - { url = "https://files.pythonhosted.org/packages/61/c6/03ed30992602c85aa3cd95b9070a514f8b3c33e31124694438d88809ae36/numpy-2.2.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:37c0ca431f82cd5fa716eca9506aefcabc247fb27ba69c5062a6d3ade8cf8f49", size = 15505511, upload-time = "2025-05-17T21:36:29.965Z" }, - { url = "https://files.pythonhosted.org/packages/b7/25/5761d832a81df431e260719ec45de696414266613c9ee268394dd5ad8236/numpy-2.2.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fe27749d33bb772c80dcd84ae7e8df2adc920ae8297400dabec45f0dedb3f6de", size = 18313783, upload-time = "2025-05-17T21:36:56.883Z" }, - { url = "https://files.pythonhosted.org/packages/57/0a/72d5a3527c5ebffcd47bde9162c39fae1f90138c961e5296491ce778e682/numpy-2.2.6-cp312-cp312-win32.whl", hash = "sha256:4eeaae00d789f66c7a25ac5f34b71a7035bb474e679f410e5e1a94deb24cf2d4", size = 6246506, upload-time = "2025-05-17T21:37:07.368Z" }, - { url = "https://files.pythonhosted.org/packages/36/fa/8c9210162ca1b88529ab76b41ba02d433fd54fecaf6feb70ef9f124683f1/numpy-2.2.6-cp312-cp312-win_amd64.whl", hash = "sha256:c1f9540be57940698ed329904db803cf7a402f3fc200bfe599334c9bd84a40b2", size = 12614190, upload-time = "2025-05-17T21:37:26.213Z" }, -] - -[[package]] -name = "nvidia-cublas-cu12" -version = "12.4.5.8" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805, upload-time = "2024-04-03T20:57:06.025Z" }, -] - -[[package]] -name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957, upload-time = "2024-04-03T20:55:01.564Z" }, -] - -[[package]] -name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306, upload-time = "2024-04-03T20:56:01.463Z" }, -] - -[[package]] -name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737, upload-time = "2024-04-03T20:54:51.355Z" }, -] - -[[package]] -name = "nvidia-cudnn-cu12" -version = "9.1.0.70" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741, upload-time = "2024-04-22T15:24:15.253Z" }, -] - -[[package]] -name = "nvidia-cufft-cu12" -version = "11.2.1.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117, upload-time = "2024-04-03T20:57:40.402Z" }, -] - -[[package]] -name = "nvidia-curand-cu12" -version = "10.3.5.147" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206, upload-time = "2024-04-03T20:58:08.722Z" }, -] - -[[package]] -name = "nvidia-cusolver-cu12" -version = "11.6.1.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057, upload-time = "2024-04-03T20:58:28.735Z" }, -] - -[[package]] -name = "nvidia-cusparse-cu12" -version = "12.3.1.170" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763, upload-time = "2024-04-03T20:58:59.995Z" }, -] - -[[package]] -name = "nvidia-cusparselt-cu12" -version = "0.6.2" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751, upload-time = "2024-07-23T02:35:53.074Z" }, -] - -[[package]] -name = "nvidia-nccl-cu12" -version = "2.21.5" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414, upload-time = "2024-04-03T15:32:57.427Z" }, -] - -[[package]] -name = "nvidia-nvjitlink-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810, upload-time = "2024-04-03T20:59:46.957Z" }, -] - -[[package]] -name = "nvidia-nvtx-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144, upload-time = "2024-04-03T20:56:12.406Z" }, -] - -[[package]] -name = "omegaconf" -version = "2.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "antlr4-python3-runtime" }, - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120, upload-time = "2022-12-08T20:59:22.753Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500, upload-time = "2022-12-08T20:59:19.686Z" }, -] - -[[package]] -name = "opencv-python" -version = "4.12.0.88" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ac/71/25c98e634b6bdeca4727c7f6d6927b056080668c5008ad3c8fc9e7f8f6ec/opencv-python-4.12.0.88.tar.gz", hash = "sha256:8b738389cede219405f6f3880b851efa3415ccd674752219377353f017d2994d", size = 95373294, upload-time = "2025-07-07T09:20:52.389Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/85/68/3da40142e7c21e9b1d4e7ddd6c58738feb013203e6e4b803d62cdd9eb96b/opencv_python-4.12.0.88-cp37-abi3-macosx_13_0_arm64.whl", hash = "sha256:f9a1f08883257b95a5764bf517a32d75aec325319c8ed0f89739a57fae9e92a5", size = 37877727, upload-time = "2025-07-07T09:13:31.47Z" }, - { url = "https://files.pythonhosted.org/packages/33/7c/042abe49f58d6ee7e1028eefc3334d98ca69b030e3b567fe245a2b28ea6f/opencv_python-4.12.0.88-cp37-abi3-macosx_13_0_x86_64.whl", hash = "sha256:812eb116ad2b4de43ee116fcd8991c3a687f099ada0b04e68f64899c09448e81", size = 57326471, upload-time = "2025-07-07T09:13:41.26Z" }, - { url = "https://files.pythonhosted.org/packages/62/3a/440bd64736cf8116f01f3b7f9f2e111afb2e02beb2ccc08a6458114a6b5d/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:51fd981c7df6af3e8f70b1556696b05224c4e6b6777bdd2a46b3d4fb09de1a92", size = 45887139, upload-time = "2025-07-07T09:13:50.761Z" }, - { url = "https://files.pythonhosted.org/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:092c16da4c5a163a818f120c22c5e4a2f96e0db4f24e659c701f1fe629a690f9", size = 67041680, upload-time = "2025-07-07T09:14:01.995Z" }, - { url = "https://files.pythonhosted.org/packages/02/96/213fea371d3cb2f1d537612a105792aa0a6659fb2665b22cad709a75bd94/opencv_python-4.12.0.88-cp37-abi3-win32.whl", hash = "sha256:ff554d3f725b39878ac6a2e1fa232ec509c36130927afc18a1719ebf4fbf4357", size = 30284131, upload-time = "2025-07-07T09:14:08.819Z" }, - { url = "https://files.pythonhosted.org/packages/fa/80/eb88edc2e2b11cd2dd2e56f1c80b5784d11d6e6b7f04a1145df64df40065/opencv_python-4.12.0.88-cp37-abi3-win_amd64.whl", hash = "sha256:d98edb20aa932fd8ebd276a72627dad9dc097695b3d435a4257557bbb49a79d2", size = 39000307, upload-time = "2025-07-07T09:14:16.641Z" }, -] - -[[package]] -name = "packaging" -version = "25.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, -] - -[[package]] -name = "pandas" -version = "2.3.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, - { name = "python-dateutil" }, - { name = "pytz" }, - { name = "tzdata" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/79/8e/0e90233ac205ad182bd6b422532695d2b9414944a280488105d598c70023/pandas-2.3.2.tar.gz", hash = "sha256:ab7b58f8f82706890924ccdfb5f48002b83d2b5a3845976a9fb705d36c34dcdb", size = 4488684, upload-time = "2025-08-21T10:28:29.257Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/db/614c20fb7a85a14828edd23f1c02db58a30abf3ce76f38806155d160313c/pandas-2.3.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fbb977f802156e7a3f829e9d1d5398f6192375a3e2d1a9ee0803e35fe70a2b9", size = 11587652, upload-time = "2025-08-21T10:27:15.888Z" }, - { url = "https://files.pythonhosted.org/packages/99/b0/756e52f6582cade5e746f19bad0517ff27ba9c73404607c0306585c201b3/pandas-2.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b9b52693123dd234b7c985c68b709b0b009f4521000d0525f2b95c22f15944b", size = 10717686, upload-time = "2025-08-21T10:27:18.486Z" }, - { url = "https://files.pythonhosted.org/packages/37/4c/dd5ccc1e357abfeee8353123282de17997f90ff67855f86154e5a13b81e5/pandas-2.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bd281310d4f412733f319a5bc552f86d62cddc5f51d2e392c8787335c994175", size = 11278722, upload-time = "2025-08-21T10:27:21.149Z" }, - { url = "https://files.pythonhosted.org/packages/d3/a4/f7edcfa47e0a88cda0be8b068a5bae710bf264f867edfdf7b71584ace362/pandas-2.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96d31a6b4354e3b9b8a2c848af75d31da390657e3ac6f30c05c82068b9ed79b9", size = 11987803, upload-time = "2025-08-21T10:27:23.767Z" }, - { url = "https://files.pythonhosted.org/packages/f6/61/1bce4129f93ab66f1c68b7ed1c12bac6a70b1b56c5dab359c6bbcd480b52/pandas-2.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:df4df0b9d02bb873a106971bb85d448378ef14b86ba96f035f50bbd3688456b4", size = 12766345, upload-time = "2025-08-21T10:27:26.6Z" }, - { url = "https://files.pythonhosted.org/packages/8e/46/80d53de70fee835531da3a1dae827a1e76e77a43ad22a8cd0f8142b61587/pandas-2.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:213a5adf93d020b74327cb2c1b842884dbdd37f895f42dcc2f09d451d949f811", size = 13439314, upload-time = "2025-08-21T10:27:29.213Z" }, - { url = "https://files.pythonhosted.org/packages/28/30/8114832daff7489f179971dbc1d854109b7f4365a546e3ea75b6516cea95/pandas-2.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:8c13b81a9347eb8c7548f53fd9a4f08d4dfe996836543f805c987bafa03317ae", size = 10983326, upload-time = "2025-08-21T10:27:31.901Z" }, -] - -[[package]] -name = "parso" -version = "0.8.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d4/de/53e0bcf53d13e005bd8c92e7855142494f41171b34c2536b86187474184d/parso-0.8.5.tar.gz", hash = "sha256:034d7354a9a018bdce352f48b2a8a450f05e9d6ee85db84764e9b6bd96dafe5a", size = 401205, upload-time = "2025-08-23T15:15:28.028Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl", hash = "sha256:646204b5ee239c396d040b90f9e272e9a8017c630092bf59980beb62fd033887", size = 106668, upload-time = "2025-08-23T15:15:25.663Z" }, -] - -[[package]] -name = "pathspec" -version = "0.12.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, -] - -[[package]] -name = "pillow" -version = "11.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/d0d6dea55cd152ce3d6767bb38a8fc10e33796ba4ba210cbab9354b6d238/pillow-11.3.0.tar.gz", hash = "sha256:3828ee7586cd0b2091b6209e5ad53e20d0649bbe87164a459d0676e035e8f523", size = 47113069, upload-time = "2025-07-01T09:16:30.666Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/fe/1bc9b3ee13f68487a99ac9529968035cca2f0a51ec36892060edcc51d06a/pillow-11.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdae223722da47b024b867c1ea0be64e0df702c5e0a60e27daad39bf960dd1e4", size = 5278800, upload-time = "2025-07-01T09:14:17.648Z" }, - { url = "https://files.pythonhosted.org/packages/2c/32/7e2ac19b5713657384cec55f89065fb306b06af008cfd87e572035b27119/pillow-11.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:921bd305b10e82b4d1f5e802b6850677f965d8394203d182f078873851dada69", size = 4686296, upload-time = "2025-07-01T09:14:19.828Z" }, - { url = "https://files.pythonhosted.org/packages/8e/1e/b9e12bbe6e4c2220effebc09ea0923a07a6da1e1f1bfbc8d7d29a01ce32b/pillow-11.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:eb76541cba2f958032d79d143b98a3a6b3ea87f0959bbe256c0b5e416599fd5d", size = 5871726, upload-time = "2025-07-03T13:10:04.448Z" }, - { url = "https://files.pythonhosted.org/packages/8d/33/e9200d2bd7ba00dc3ddb78df1198a6e80d7669cce6c2bdbeb2530a74ec58/pillow-11.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:67172f2944ebba3d4a7b54f2e95c786a3a50c21b88456329314caaa28cda70f6", size = 7644652, upload-time = "2025-07-03T13:10:10.391Z" }, - { url = "https://files.pythonhosted.org/packages/41/f1/6f2427a26fc683e00d985bc391bdd76d8dd4e92fac33d841127eb8fb2313/pillow-11.3.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97f07ed9f56a3b9b5f49d3661dc9607484e85c67e27f3e8be2c7d28ca032fec7", size = 5977787, upload-time = "2025-07-01T09:14:21.63Z" }, - { url = "https://files.pythonhosted.org/packages/e4/c9/06dd4a38974e24f932ff5f98ea3c546ce3f8c995d3f0985f8e5ba48bba19/pillow-11.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:676b2815362456b5b3216b4fd5bd89d362100dc6f4945154ff172e206a22c024", size = 6645236, upload-time = "2025-07-01T09:14:23.321Z" }, - { url = "https://files.pythonhosted.org/packages/40/e7/848f69fb79843b3d91241bad658e9c14f39a32f71a301bcd1d139416d1be/pillow-11.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3e184b2f26ff146363dd07bde8b711833d7b0202e27d13540bfe2e35a323a809", size = 6086950, upload-time = "2025-07-01T09:14:25.237Z" }, - { url = "https://files.pythonhosted.org/packages/0b/1a/7cff92e695a2a29ac1958c2a0fe4c0b2393b60aac13b04a4fe2735cad52d/pillow-11.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6be31e3fc9a621e071bc17bb7de63b85cbe0bfae91bb0363c893cbe67247780d", size = 6723358, upload-time = "2025-07-01T09:14:27.053Z" }, - { url = "https://files.pythonhosted.org/packages/26/7d/73699ad77895f69edff76b0f332acc3d497f22f5d75e5360f78cbcaff248/pillow-11.3.0-cp312-cp312-win32.whl", hash = "sha256:7b161756381f0918e05e7cb8a371fff367e807770f8fe92ecb20d905d0e1c149", size = 6275079, upload-time = "2025-07-01T09:14:30.104Z" }, - { url = "https://files.pythonhosted.org/packages/8c/ce/e7dfc873bdd9828f3b6e5c2bbb74e47a98ec23cc5c74fc4e54462f0d9204/pillow-11.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:a6444696fce635783440b7f7a9fc24b3ad10a9ea3f0ab66c5905be1c19ccf17d", size = 6986324, upload-time = "2025-07-01T09:14:31.899Z" }, - { url = "https://files.pythonhosted.org/packages/16/8f/b13447d1bf0b1f7467ce7d86f6e6edf66c0ad7cf44cf5c87a37f9bed9936/pillow-11.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:2aceea54f957dd4448264f9bf40875da0415c83eb85f55069d89c0ed436e3542", size = 2423067, upload-time = "2025-07-01T09:14:33.709Z" }, -] - -[[package]] -name = "platformdirs" -version = "4.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/23/e8/21db9c9987b0e728855bd57bff6984f67952bea55d6f75e055c46b5383e8/platformdirs-4.4.0.tar.gz", hash = "sha256:ca753cf4d81dc309bc67b0ea38fd15dc97bc30ce419a7f58d13eb3bf14c4febf", size = 21634, upload-time = "2025-08-26T14:32:04.268Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/4b/2028861e724d3bd36227adfa20d3fd24c3fc6d52032f4a93c133be5d17ce/platformdirs-4.4.0-py3-none-any.whl", hash = "sha256:abd01743f24e5287cd7a5db3752faf1a2d65353f38ec26d98e25a6db65958c85", size = 18654, upload-time = "2025-08-26T14:32:02.735Z" }, -] - -[[package]] -name = "pluggy" -version = "1.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, -] - -[[package]] -name = "protobuf" -version = "6.32.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c0/df/fb4a8eeea482eca989b51cffd274aac2ee24e825f0bf3cbce5281fa1567b/protobuf-6.32.0.tar.gz", hash = "sha256:a81439049127067fc49ec1d36e25c6ee1d1a2b7be930675f919258d03c04e7d2", size = 440614, upload-time = "2025-08-14T21:21:25.015Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/33/18/df8c87da2e47f4f1dcc5153a81cd6bca4e429803f4069a299e236e4dd510/protobuf-6.32.0-cp310-abi3-win32.whl", hash = "sha256:84f9e3c1ff6fb0308dbacb0950d8aa90694b0d0ee68e75719cb044b7078fe741", size = 424409, upload-time = "2025-08-14T21:21:12.366Z" }, - { url = "https://files.pythonhosted.org/packages/e1/59/0a820b7310f8139bd8d5a9388e6a38e1786d179d6f33998448609296c229/protobuf-6.32.0-cp310-abi3-win_amd64.whl", hash = "sha256:a8bdbb2f009cfc22a36d031f22a625a38b615b5e19e558a7b756b3279723e68e", size = 435735, upload-time = "2025-08-14T21:21:15.046Z" }, - { url = "https://files.pythonhosted.org/packages/cc/5b/0d421533c59c789e9c9894683efac582c06246bf24bb26b753b149bd88e4/protobuf-6.32.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d52691e5bee6c860fff9a1c86ad26a13afbeb4b168cd4445c922b7e2cf85aaf0", size = 426449, upload-time = "2025-08-14T21:21:16.687Z" }, - { url = "https://files.pythonhosted.org/packages/ec/7b/607764ebe6c7a23dcee06e054fd1de3d5841b7648a90fd6def9a3bb58c5e/protobuf-6.32.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:501fe6372fd1c8ea2a30b4d9be8f87955a64d6be9c88a973996cef5ef6f0abf1", size = 322869, upload-time = "2025-08-14T21:21:18.282Z" }, - { url = "https://files.pythonhosted.org/packages/40/01/2e730bd1c25392fc32e3268e02446f0d77cb51a2c3a8486b1798e34d5805/protobuf-6.32.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:75a2aab2bd1aeb1f5dc7c5f33bcb11d82ea8c055c9becbb41c26a8c43fd7092c", size = 322009, upload-time = "2025-08-14T21:21:19.893Z" }, - { url = "https://files.pythonhosted.org/packages/9c/f2/80ffc4677aac1bc3519b26bc7f7f5de7fce0ee2f7e36e59e27d8beb32dd1/protobuf-6.32.0-py3-none-any.whl", hash = "sha256:ba377e5b67b908c8f3072a57b63e2c6a4cbd18aea4ed98d2584350dbf46f2783", size = 169287, upload-time = "2025-08-14T21:21:23.515Z" }, -] - -[[package]] -name = "pudb" -version = "2025.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jedi" }, - { name = "packaging" }, - { name = "pygments" }, - { name = "urwid" }, - { name = "urwid-readline" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/93/40/8d17b16a1c2a36d8cc1befb5eda13310ad20fcb347e58bef5e104696e14a/pudb-2025.1.tar.gz", hash = "sha256:a528b29c69ce8b182a337872c5f046071f6d68d3415c6d7bf53bd27c264f58d0", size = 220623, upload-time = "2025-05-06T20:43:18.306Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/01/069766294390d3e10c77dfb553171466d67ffb51bf72a437650c0a5db86a/pudb-2025.1-py3-none-any.whl", hash = "sha256:f642d42e6054c992b43c463742650aa879fe290d7d7ffdeb21f7d00dc4587a21", size = 89208, upload-time = "2025-05-06T20:43:17.101Z" }, -] - -[[package]] -name = "pycparser" -version = "2.23" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/cf/d2d3b9f5699fb1e4615c8e32ff220203e43b248e1dfcc6736ad9057731ca/pycparser-2.23.tar.gz", hash = "sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2", size = 173734, upload-time = "2025-09-09T13:23:47.91Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/e3/59cd50310fc9b59512193629e1984c1f95e5c8ae6e5d8c69532ccc65a7fe/pycparser-2.23-py3-none-any.whl", hash = "sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934", size = 118140, upload-time = "2025-09-09T13:23:46.651Z" }, -] - -[[package]] -name = "pydantic" -version = "2.11.7" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "annotated-types" }, - { name = "pydantic-core" }, - { name = "typing-extensions" }, - { name = "typing-inspection" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/00/dd/4325abf92c39ba8623b5af936ddb36ffcfe0beae70405d456ab1fb2f5b8c/pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db", size = 788350, upload-time = "2025-06-14T08:33:17.137Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/c0/ec2b1c8712ca690e5d61979dee872603e92b8a32f94cc1b72d53beab008a/pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b", size = 444782, upload-time = "2025-06-14T08:33:14.905Z" }, -] - -[[package]] -name = "pydantic-core" -version = "2.33.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000, upload-time = "2025-04-23T18:31:25.863Z" }, - { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996, upload-time = "2025-04-23T18:31:27.341Z" }, - { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957, upload-time = "2025-04-23T18:31:28.956Z" }, - { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199, upload-time = "2025-04-23T18:31:31.025Z" }, - { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296, upload-time = "2025-04-23T18:31:32.514Z" }, - { url = "https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109, upload-time = "2025-04-23T18:31:33.958Z" }, - { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028, upload-time = "2025-04-23T18:31:39.095Z" }, - { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044, upload-time = "2025-04-23T18:31:41.034Z" }, - { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881, upload-time = "2025-04-23T18:31:42.757Z" }, - { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034, upload-time = "2025-04-23T18:31:44.304Z" }, - { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187, upload-time = "2025-04-23T18:31:45.891Z" }, - { url = "https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628, upload-time = "2025-04-23T18:31:47.819Z" }, - { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866, upload-time = "2025-04-23T18:31:49.635Z" }, - { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894, upload-time = "2025-04-23T18:31:51.609Z" }, -] - -[[package]] -name = "pygments" -version = "2.19.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, -] - -[[package]] -name = "pymunk" -version = "7.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cffi" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/57/25/ee7ff873cec8aee64751b284c7cba2bd3e7364f67b9bd7d41b16a8ab358f/pymunk-7.1.0.tar.gz", hash = "sha256:f3045897261325bb3e89284402dd43b908d0a5871dc2010597e36b4309f0229b", size = 3355569, upload-time = "2025-06-30T19:39:08.81Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/71/a6/f9137df251de8500eb916c80416872e3ddd344f34a23c4c6e8ac27ac4abc/pymunk-7.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:226d2a45942284bf472580be487a84af4d32fbfaf11ec1a5f36d04a7ce8b0dbd", size = 367321, upload-time = "2025-06-29T20:06:01.608Z" }, - { url = "https://files.pythonhosted.org/packages/0c/98/098b87056e32ba8d2a412ba6f6cd8ad5b74dbee357ff509c7c9e8c437181/pymunk-7.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dfacf19a1b6815a571a87308777a73af05cbcfede9357c5143410b027d023bc1", size = 350136, upload-time = "2025-06-29T20:06:03.499Z" }, - { url = "https://files.pythonhosted.org/packages/c8/d3/7061ab9ec86582cf2c91d269ca32fc57ae59ab585dfd06a7aa490d3d4c98/pymunk-7.1.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc259973248fda635173c09c5fc96901fb916ec8d2977b5ee4aaa8e70293f94b", size = 993324, upload-time = "2025-06-29T20:06:06.097Z" }, - { url = "https://files.pythonhosted.org/packages/89/0a/a4a2e6d09f0cbbdca06c251865d8ec380f3c498c6b21df1c4113bf0596e2/pymunk-7.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6639d2e3242650e52227690931dbe43ab1aa29f24fb456aafb6b70b3027ea744", size = 1017864, upload-time = "2025-06-29T20:06:09.298Z" }, - { url = "https://files.pythonhosted.org/packages/69/c8/7a2769fd106fd6f2c63b7f5ef341b09b8bf51e72b6f02f0d52ce94ad88c5/pymunk-7.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8d47bc3c82340e32e2ac28f3f12261303a81fd1e3ddf4597ab723c15ada39285", size = 989245, upload-time = "2025-06-29T20:06:12.04Z" }, - { url = "https://files.pythonhosted.org/packages/8f/48/9ae814e4e158cb7ba01afc4f294e8b27919feb6919caa1ff7f2e8afd3abf/pymunk-7.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:58a47d716aa18695b5e8611d43061e323793913d25b019afa00c99d035a18671", size = 1016204, upload-time = "2025-06-29T20:06:14.989Z" }, - { url = "https://files.pythonhosted.org/packages/77/2e/bf0976764fac254efed9142880e8e19d3400a1f1ef2a313c30de87512ba7/pymunk-7.1.0-cp312-cp312-win32.whl", hash = "sha256:13359a0ae1bc069318b36c8e3eb69d72bac72153e6f168410d01115cc1f252ee", size = 317747, upload-time = "2025-06-29T20:06:16.916Z" }, - { url = "https://files.pythonhosted.org/packages/e1/0b/bee0cea05387ce65fad9714e034da579135ece6edd176e734078be1354cc/pymunk-7.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:b00b50c9963e2743bff37f9319a44e42a5ee9f7dd6755d9bc9b6decef13e0f63", size = 368267, upload-time = "2025-06-29T20:06:18.364Z" }, -] - -[[package]] -name = "pyparsing" -version = "3.2.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bb/22/f1129e69d94ffff626bdb5c835506b3a5b4f3d070f17ea295e12c2c6f60f/pyparsing-3.2.3.tar.gz", hash = "sha256:b9c13f1ab8b3b542f72e28f634bad4de758ab3ce4546e4301970ad6fa77c38be", size = 1088608, upload-time = "2025-03-25T05:01:28.114Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120, upload-time = "2025-03-25T05:01:24.908Z" }, -] - -[[package]] -name = "pytest" -version = "8.4.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "iniconfig" }, - { name = "packaging" }, - { name = "pluggy" }, - { name = "pygments" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, -] - -[[package]] -name = "python-dateutil" -version = "2.9.0.post0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, -] - -[[package]] -name = "pytz" -version = "2025.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, -] - -[[package]] -name = "pyyaml" -version = "6.0.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631, upload-time = "2024-08-06T20:33:50.674Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/86/0c/c581167fc46d6d6d7ddcfb8c843a4de25bdd27e4466938109ca68492292c/PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", size = 183873, upload-time = "2024-08-06T20:32:25.131Z" }, - { url = "https://files.pythonhosted.org/packages/a8/0c/38374f5bb272c051e2a69281d71cba6fdb983413e6758b84482905e29a5d/PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", size = 173302, upload-time = "2024-08-06T20:32:26.511Z" }, - { url = "https://files.pythonhosted.org/packages/c3/93/9916574aa8c00aa06bbac729972eb1071d002b8e158bd0e83a3b9a20a1f7/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", size = 739154, upload-time = "2024-08-06T20:32:28.363Z" }, - { url = "https://files.pythonhosted.org/packages/95/0f/b8938f1cbd09739c6da569d172531567dbcc9789e0029aa070856f123984/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", size = 766223, upload-time = "2024-08-06T20:32:30.058Z" }, - { url = "https://files.pythonhosted.org/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", size = 767542, upload-time = "2024-08-06T20:32:31.881Z" }, - { url = "https://files.pythonhosted.org/packages/d4/00/dd137d5bcc7efea1836d6264f049359861cf548469d18da90cd8216cf05f/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", size = 731164, upload-time = "2024-08-06T20:32:37.083Z" }, - { url = "https://files.pythonhosted.org/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611, upload-time = "2024-08-06T20:32:38.898Z" }, - { url = "https://files.pythonhosted.org/packages/df/d1/f5a275fdb252768b7a11ec63585bc38d0e87c9e05668a139fea92b80634c/PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", size = 140591, upload-time = "2024-08-06T20:32:40.241Z" }, - { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338, upload-time = "2024-08-06T20:32:41.93Z" }, -] - -[[package]] -name = "regex" -version = "2025.8.29" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e4/10/2d333227cf5198eb3252f2d50c8ade5cd2015f11c22403f0c9e3d529e81a/regex-2025.8.29.tar.gz", hash = "sha256:731ddb27a0900fa227dfba976b4efccec8c1c6fba147829bb52e71d49e91a5d7", size = 400817, upload-time = "2025-08-29T22:43:36.985Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/a0/8c37d276a80ffda94f7e019e50cc88f898015512c7f104e49f1a0a6d3c59/regex-2025.8.29-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:dd61f18dc4446bc3a2904559a61f32e98091cef7fb796e06fa35b9bfefe4c0c5", size = 485565, upload-time = "2025-08-29T22:41:41.069Z" }, - { url = "https://files.pythonhosted.org/packages/5d/34/baf5963bec36ac250fa242f0f0e7670f013de5004db6caa31c872981df42/regex-2025.8.29-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f21b416be10a8348a7313ba8c610569a1ab4bf8ec70731750540842a4551cd3d", size = 290073, upload-time = "2025-08-29T22:41:42.686Z" }, - { url = "https://files.pythonhosted.org/packages/24/29/c5c18143cd60b736d7ff8acece126118fe5649f45a7a8db18e308f5f813d/regex-2025.8.29-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:008947a7fa92f4cb3b28201c9aa7becc0a44c31a7c2fcb934356e1877baccc09", size = 286144, upload-time = "2025-08-29T22:41:44.364Z" }, - { url = "https://files.pythonhosted.org/packages/86/7c/0d90b687d2a33fe28b201f85ddfde6b378bf41677aedbe23eb7dc79385aa/regex-2025.8.29-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e78ab1b3e68b890d7ebd69218cfbfe4a09dc00b8a47be8648510b81b932d55ff", size = 797417, upload-time = "2025-08-29T22:41:47.224Z" }, - { url = "https://files.pythonhosted.org/packages/fb/67/c391c899e5ef274c4dd4ede029ffb853ddf5ba77aa251be02cfe3810574c/regex-2025.8.29-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a848368797515bc141d3fad5fd2d81bf9e8a6a22d9ac1a4be4690dd22e997854", size = 862630, upload-time = "2025-08-29T22:41:48.891Z" }, - { url = "https://files.pythonhosted.org/packages/08/20/ae749a68da3496a133836c8724649bd2e004fc176c7c6647d9cb269cc975/regex-2025.8.29-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8eaf3ea6631f804efcf0f5bd0e4ab62ba984fd9b70e3aef44b05cc6b951cc728", size = 910837, upload-time = "2025-08-29T22:41:50.592Z" }, - { url = "https://files.pythonhosted.org/packages/e2/80/bc4244ec79fba4185fd3a29d79f77f79b3b0dc12ee426687501b0b077e2a/regex-2025.8.29-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4561aeb36b0bf3bb44826e4b61a80c6ace0d8839bf4914d78f061f9ba61444b4", size = 801968, upload-time = "2025-08-29T22:41:54.239Z" }, - { url = "https://files.pythonhosted.org/packages/ef/bd/a2d75042bb1d3c9997e22bc0051cb9791a405589d6293c874f7c2ba487e7/regex-2025.8.29-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:93e077d1fbd24033fa427eab43d80ad47e449d25700cda78e8cac821a30090bf", size = 786626, upload-time = "2025-08-29T22:41:56.158Z" }, - { url = "https://files.pythonhosted.org/packages/24/ab/19cec75bf7d335cc7595d4857591455de118f6bfb563e6731c31f4fe33c3/regex-2025.8.29-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d92379e53d782bdb773988687300e3bccb91ad38157b754b04b1857aaeea16a3", size = 856532, upload-time = "2025-08-29T22:41:58.057Z" }, - { url = "https://files.pythonhosted.org/packages/b6/3d/517cd0b0f4b8330164d03ef0eafdd61ee839f82b891fcd8c571d5c727117/regex-2025.8.29-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:d41726de2040c2a487bbac70fdd6e3ff2f1aa47dc91f0a29f6955a6dfa0f06b6", size = 848977, upload-time = "2025-08-29T22:42:00.346Z" }, - { url = "https://files.pythonhosted.org/packages/ae/fc/b57e2644d87d038d7302f359f4042bf7092bd8259a3ae999adf236e6fbc0/regex-2025.8.29-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1915dfda52bd4d466f3a66b66988db1f647ee1d9c605858640ceeb779cffd908", size = 788112, upload-time = "2025-08-29T22:42:02.008Z" }, - { url = "https://files.pythonhosted.org/packages/a9/2f/70737feddbd33ec9f3f0cb8b38e7fc89304eccc80fd693d79a6f336e2282/regex-2025.8.29-cp312-cp312-win32.whl", hash = "sha256:e2ef0087ad6949918836f215480a9331f6c59ad54912a9a412f08ab1c9ccbc98", size = 264487, upload-time = "2025-08-29T22:42:04.401Z" }, - { url = "https://files.pythonhosted.org/packages/2f/f5/8832d05ecc5a7f80043e7521ea55adfa2d9b9ac0e646474153e7e13722c2/regex-2025.8.29-cp312-cp312-win_amd64.whl", hash = "sha256:c15d361fe9800bf38ef69c2e0c4b8b961ae4ce2f076fcf4f28e1fc9ea127f55a", size = 275455, upload-time = "2025-08-29T22:42:06.312Z" }, - { url = "https://files.pythonhosted.org/packages/a5/f9/f10ae0c4e5e22db75dda155d83056e2b70c4e87b04ad9838723ff5057e90/regex-2025.8.29-cp312-cp312-win_arm64.whl", hash = "sha256:305577fab545e64fb84d9a24269aa3132dbe05e1d7fa74b3614e93ec598fe6e6", size = 268558, upload-time = "2025-08-29T22:42:08.062Z" }, -] - -[[package]] -name = "requests" -version = "2.32.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "charset-normalizer" }, - { name = "idna" }, - { name = "urllib3" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, -] - -[[package]] -name = "ruamel-yaml" -version = "0.18.15" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "ruamel-yaml-clib", marker = "platform_python_implementation == 'CPython'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/3e/db/f3950f5e5031b618aae9f423a39bf81a55c148aecd15a34527898e752cf4/ruamel.yaml-0.18.15.tar.gz", hash = "sha256:dbfca74b018c4c3fba0b9cc9ee33e53c371194a9000e694995e620490fd40700", size = 146865, upload-time = "2025-08-19T11:15:10.694Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/e5/f2a0621f1781b76a38194acae72f01e37b1941470407345b6e8653ad7640/ruamel.yaml-0.18.15-py3-none-any.whl", hash = "sha256:148f6488d698b7a5eded5ea793a025308b25eca97208181b6a026037f391f701", size = 119702, upload-time = "2025-08-19T11:15:07.696Z" }, -] - -[[package]] -name = "ruamel-yaml-clib" -version = "0.2.14" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/e9/39ec4d4b3f91188fad1842748f67d4e749c77c37e353c4e545052ee8e893/ruamel.yaml.clib-0.2.14.tar.gz", hash = "sha256:803f5044b13602d58ea378576dd75aa759f52116a0232608e8fdada4da33752e", size = 225394, upload-time = "2025-09-22T19:51:23.753Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b4/42/ccfb34a25289afbbc42017e4d3d4288e61d35b2e00cfc6b92974a6a1f94b/ruamel.yaml.clib-0.2.14-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:6aeadc170090ff1889f0d2c3057557f9cd71f975f17535c26a5d37af98f19c27", size = 271775, upload-time = "2025-09-23T14:24:12.771Z" }, - { url = "https://files.pythonhosted.org/packages/82/73/e628a92e80197ff6a79ab81ec3fa00d4cc082d58ab78d3337b7ba7043301/ruamel.yaml.clib-0.2.14-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5e56ac47260c0eed992789fa0b8efe43404a9adb608608631a948cee4fc2b052", size = 138842, upload-time = "2025-09-22T19:50:49.156Z" }, - { url = "https://files.pythonhosted.org/packages/2b/c5/346c7094344a60419764b4b1334d9e0285031c961176ff88ffb652405b0c/ruamel.yaml.clib-0.2.14-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a911aa73588d9a8b08d662b9484bc0567949529824a55d3885b77e8dd62a127a", size = 647404, upload-time = "2025-09-22T19:50:52.921Z" }, - { url = "https://files.pythonhosted.org/packages/df/99/65080c863eb06d4498de3d6c86f3e90595e02e159fd8529f1565f56cfe2c/ruamel.yaml.clib-0.2.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a05ba88adf3d7189a974b2de7a9d56731548d35dc0a822ec3dc669caa7019b29", size = 753141, upload-time = "2025-09-22T19:50:50.294Z" }, - { url = "https://files.pythonhosted.org/packages/3d/e3/0de85f3e3333f8e29e4b10244374a202a87665d1131798946ee22cf05c7c/ruamel.yaml.clib-0.2.14-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb04c5650de6668b853623eceadcdb1a9f2fee381f5d7b6bc842ee7c239eeec4", size = 703477, upload-time = "2025-09-22T19:50:51.508Z" }, - { url = "https://files.pythonhosted.org/packages/d9/25/0d2f09d8833c7fd77ab8efeff213093c16856479a9d293180a0d89f6bed9/ruamel.yaml.clib-0.2.14-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:df3ec9959241d07bc261f4983d25a1205ff37703faf42b474f15d54d88b4f8c9", size = 741157, upload-time = "2025-09-23T18:42:50.408Z" }, - { url = "https://files.pythonhosted.org/packages/d3/8c/959f10c2e2153cbdab834c46e6954b6dd9e3b109c8f8c0a3cf1618310985/ruamel.yaml.clib-0.2.14-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:fbc08c02e9b147a11dfcaa1ac8a83168b699863493e183f7c0c8b12850b7d259", size = 745859, upload-time = "2025-09-22T19:50:54.497Z" }, - { url = "https://files.pythonhosted.org/packages/ed/6b/e580a7c18b485e1a5f30a32cda96b20364b0ba649d9d2baaf72f8bd21f83/ruamel.yaml.clib-0.2.14-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c099cafc1834d3c5dac305865d04235f7c21c167c8dd31ebc3d6bbc357e2f023", size = 770200, upload-time = "2025-09-22T19:50:55.718Z" }, - { url = "https://files.pythonhosted.org/packages/ef/44/3455eebc761dc8e8fdced90f2b0a3fa61e32ba38b50de4130e2d57db0f21/ruamel.yaml.clib-0.2.14-cp312-cp312-win32.whl", hash = "sha256:b5b0f7e294700b615a3bcf6d28b26e6da94e8eba63b079f4ec92e9ba6c0d6b54", size = 98829, upload-time = "2025-09-22T19:50:58.895Z" }, - { url = "https://files.pythonhosted.org/packages/76/ab/5121f7f3b651db93de546f8c982c241397aad0a4765d793aca1dac5eadee/ruamel.yaml.clib-0.2.14-cp312-cp312-win_amd64.whl", hash = "sha256:a37f40a859b503304dd740686359fcf541d6fb3ff7fc10f539af7f7150917c68", size = 115570, upload-time = "2025-09-22T19:50:57.981Z" }, -] - -[[package]] -name = "scikit-learn" -version = "1.7.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "joblib" }, - { name = "numpy" }, - { name = "scipy" }, - { name = "threadpoolctl" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/41/84/5f4af978fff619706b8961accac84780a6d298d82a8873446f72edb4ead0/scikit_learn-1.7.1.tar.gz", hash = "sha256:24b3f1e976a4665aa74ee0fcaac2b8fccc6ae77c8e07ab25da3ba6d3292b9802", size = 7190445, upload-time = "2025-07-18T08:01:54.5Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/16/57f176585b35ed865f51b04117947fe20f130f78940c6477b6d66279c9c2/scikit_learn-1.7.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3cee419b49b5bbae8796ecd690f97aa412ef1674410c23fc3257c6b8b85b8087", size = 9260431, upload-time = "2025-07-18T08:01:22.77Z" }, - { url = "https://files.pythonhosted.org/packages/67/4e/899317092f5efcab0e9bc929e3391341cec8fb0e816c4789686770024580/scikit_learn-1.7.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2fd8b8d35817b0d9ebf0b576f7d5ffbbabdb55536b0655a8aaae629d7ffd2e1f", size = 8637191, upload-time = "2025-07-18T08:01:24.731Z" }, - { url = "https://files.pythonhosted.org/packages/f3/1b/998312db6d361ded1dd56b457ada371a8d8d77ca2195a7d18fd8a1736f21/scikit_learn-1.7.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:588410fa19a96a69763202f1d6b7b91d5d7a5d73be36e189bc6396bfb355bd87", size = 9486346, upload-time = "2025-07-18T08:01:26.713Z" }, - { url = "https://files.pythonhosted.org/packages/ad/09/a2aa0b4e644e5c4ede7006748f24e72863ba2ae71897fecfd832afea01b4/scikit_learn-1.7.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e3142f0abe1ad1d1c31a2ae987621e41f6b578144a911ff4ac94781a583adad7", size = 9290988, upload-time = "2025-07-18T08:01:28.938Z" }, - { url = "https://files.pythonhosted.org/packages/15/fa/c61a787e35f05f17fc10523f567677ec4eeee5f95aa4798dbbbcd9625617/scikit_learn-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:3ddd9092c1bd469acab337d87930067c87eac6bd544f8d5027430983f1e1ae88", size = 8735568, upload-time = "2025-07-18T08:01:30.936Z" }, -] - -[[package]] -name = "scipy" -version = "1.16.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f5/4a/b927028464795439faec8eaf0b03b011005c487bb2d07409f28bf30879c4/scipy-1.16.1.tar.gz", hash = "sha256:44c76f9e8b6e8e488a586190ab38016e4ed2f8a038af7cd3defa903c0a2238b3", size = 30580861, upload-time = "2025-07-27T16:33:30.834Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f8/d9/ec4864f5896232133f51382b54a08de91a9d1af7a76dfa372894026dfee2/scipy-1.16.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:81b433bbeaf35728dad619afc002db9b189e45eebe2cd676effe1fb93fef2b9c", size = 36575194, upload-time = "2025-07-27T16:27:41.321Z" }, - { url = "https://files.pythonhosted.org/packages/5c/6d/40e81ecfb688e9d25d34a847dca361982a6addf8e31f0957b1a54fbfa994/scipy-1.16.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:886cc81fdb4c6903a3bb0464047c25a6d1016fef77bb97949817d0c0d79f9e04", size = 28594590, upload-time = "2025-07-27T16:27:49.204Z" }, - { url = "https://files.pythonhosted.org/packages/0e/37/9f65178edfcc629377ce9a64fc09baebea18c80a9e57ae09a52edf84880b/scipy-1.16.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:15240c3aac087a522b4eaedb09f0ad061753c5eebf1ea430859e5bf8640d5919", size = 20866458, upload-time = "2025-07-27T16:27:54.98Z" }, - { url = "https://files.pythonhosted.org/packages/2c/7b/749a66766871ea4cb1d1ea10f27004db63023074c22abed51f22f09770e0/scipy-1.16.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:65f81a25805f3659b48126b5053d9e823d3215e4a63730b5e1671852a1705921", size = 23539318, upload-time = "2025-07-27T16:28:01.604Z" }, - { url = "https://files.pythonhosted.org/packages/c4/db/8d4afec60eb833a666434d4541a3151eedbf2494ea6d4d468cbe877f00cd/scipy-1.16.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6c62eea7f607f122069b9bad3f99489ddca1a5173bef8a0c75555d7488b6f725", size = 33292899, upload-time = "2025-07-27T16:28:09.147Z" }, - { url = "https://files.pythonhosted.org/packages/51/1e/79023ca3bbb13a015d7d2757ecca3b81293c663694c35d6541b4dca53e98/scipy-1.16.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f965bbf3235b01c776115ab18f092a95aa74c271a52577bcb0563e85738fd618", size = 35162637, upload-time = "2025-07-27T16:28:17.535Z" }, - { url = "https://files.pythonhosted.org/packages/b6/49/0648665f9c29fdaca4c679182eb972935b3b4f5ace41d323c32352f29816/scipy-1.16.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f006e323874ffd0b0b816d8c6a8e7f9a73d55ab3b8c3f72b752b226d0e3ac83d", size = 35490507, upload-time = "2025-07-27T16:28:25.705Z" }, - { url = "https://files.pythonhosted.org/packages/62/8f/66cbb9d6bbb18d8c658f774904f42a92078707a7c71e5347e8bf2f52bb89/scipy-1.16.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e8fd15fc5085ab4cca74cb91fe0a4263b1f32e4420761ddae531ad60934c2119", size = 37923998, upload-time = "2025-07-27T16:28:34.339Z" }, - { url = "https://files.pythonhosted.org/packages/14/c3/61f273ae550fbf1667675701112e380881905e28448c080b23b5a181df7c/scipy-1.16.1-cp312-cp312-win_amd64.whl", hash = "sha256:f7b8013c6c066609577d910d1a2a077021727af07b6fab0ee22c2f901f22352a", size = 38508060, upload-time = "2025-07-27T16:28:43.242Z" }, -] - -[[package]] -name = "seaborn" -version = "0.13.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "matplotlib" }, - { name = "numpy" }, - { name = "pandas" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/86/59/a451d7420a77ab0b98f7affa3a1d78a313d2f7281a57afb1a34bae8ab412/seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7", size = 1457696, upload-time = "2024-01-25T13:21:52.551Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" }, -] - -[[package]] -name = "sentry-sdk" -version = "2.35.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "urllib3" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/72/75/6223b9ffa0bf5a79ece08055469be73c18034e46ed082742a0899cc58351/sentry_sdk-2.35.1.tar.gz", hash = "sha256:241b41e059632fe1f7c54ae6e1b93af9456aebdfc297be9cf7ecfd6da5167e8e", size = 343145, upload-time = "2025-08-26T08:23:32.429Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/62/1f/5feb6c42cc30126e9574eabc28139f8c626b483a47c537f648d133628df0/sentry_sdk-2.35.1-py2.py3-none-any.whl", hash = "sha256:13b6d6cfdae65d61fe1396a061cf9113b20f0ec1bcb257f3826b88f01bb55720", size = 363887, upload-time = "2025-08-26T08:23:30.335Z" }, -] - -[[package]] -name = "setuptools" -version = "80.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, -] - -[[package]] -name = "six" -version = "1.17.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, -] - -[[package]] -name = "smmap" -version = "5.0.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329, upload-time = "2025-01-02T07:14:40.909Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, -] - -[[package]] -name = "submitit" -version = "1.5.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cloudpickle" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1f/2c/a824e3e03cbc4a48892014c5826fee7350994f4b6ae45c120e7713b2b7a1/submitit-1.5.3.tar.gz", hash = "sha256:d1cbc5d8859b519b1e47adc4aaa6001dcefe8a835f3032b151cb3de7d2841068", size = 81019, upload-time = "2025-05-21T09:06:42.331Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/db/80/90e0a0f4008f6572de58b042b1db9daced15d348a3586dda5efc9faba65e/submitit-1.5.3-py3-none-any.whl", hash = "sha256:ccc35100da12fe916541489deccccb6b9fa93dae8c01ade53e7f643552dc1795", size = 75463, upload-time = "2025-05-21T09:06:40.76Z" }, -] - -[[package]] -name = "sympy" -version = "1.13.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mpmath" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ca/99/5a5b6f19ff9f083671ddf7b9632028436167cd3d33e11015754e41b249a4/sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f", size = 7533040, upload-time = "2024-07-19T09:26:51.238Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177, upload-time = "2024-07-19T09:26:48.863Z" }, -] - -[[package]] -name = "termcolor" -version = "3.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/6c/3d75c196ac07ac8749600b60b03f4f6094d54e132c4d94ebac6ee0e0add0/termcolor-3.1.0.tar.gz", hash = "sha256:6a6dd7fbee581909eeec6a756cff1d7f7c376063b14e4a298dc4980309e55970", size = 14324, upload-time = "2025-04-30T11:37:53.791Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4f/bd/de8d508070629b6d84a30d01d57e4a65c69aa7f5abe7560b8fad3b50ea59/termcolor-3.1.0-py3-none-any.whl", hash = "sha256:591dd26b5c2ce03b9e43f391264626557873ce1d379019786f99b0c2bee140aa", size = 7684, upload-time = "2025-04-30T11:37:52.382Z" }, -] - -[[package]] -name = "threadpoolctl" -version = "3.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, -] - -[[package]] -name = "tiktoken" -version = "0.11.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "regex" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a7/86/ad0155a37c4f310935d5ac0b1ccf9bdb635dcb906e0a9a26b616dd55825a/tiktoken-0.11.0.tar.gz", hash = "sha256:3c518641aee1c52247c2b97e74d8d07d780092af79d5911a6ab5e79359d9b06a", size = 37648, upload-time = "2025-08-08T23:58:08.495Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/9e/eceddeffc169fc75fe0fd4f38471309f11cb1906f9b8aa39be4f5817df65/tiktoken-0.11.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fd9e6b23e860973cf9526544e220b223c60badf5b62e80a33509d6d40e6c8f5d", size = 1055199, upload-time = "2025-08-08T23:57:45.076Z" }, - { url = "https://files.pythonhosted.org/packages/4f/cf/5f02bfefffdc6b54e5094d2897bc80efd43050e5b09b576fd85936ee54bf/tiktoken-0.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6a76d53cee2da71ee2731c9caa747398762bda19d7f92665e882fef229cb0b5b", size = 996655, upload-time = "2025-08-08T23:57:46.304Z" }, - { url = "https://files.pythonhosted.org/packages/65/8e/c769b45ef379bc360c9978c4f6914c79fd432400a6733a8afc7ed7b0726a/tiktoken-0.11.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ef72aab3ea240646e642413cb363b73869fed4e604dcfd69eec63dc54d603e8", size = 1128867, upload-time = "2025-08-08T23:57:47.438Z" }, - { url = "https://files.pythonhosted.org/packages/d5/2d/4d77f6feb9292bfdd23d5813e442b3bba883f42d0ac78ef5fdc56873f756/tiktoken-0.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f929255c705efec7a28bf515e29dc74220b2f07544a8c81b8d69e8efc4578bd", size = 1183308, upload-time = "2025-08-08T23:57:48.566Z" }, - { url = "https://files.pythonhosted.org/packages/7a/65/7ff0a65d3bb0fc5a1fb6cc71b03e0f6e71a68c5eea230d1ff1ba3fd6df49/tiktoken-0.11.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:61f1d15822e4404953d499fd1dcc62817a12ae9fb1e4898033ec8fe3915fdf8e", size = 1244301, upload-time = "2025-08-08T23:57:49.642Z" }, - { url = "https://files.pythonhosted.org/packages/f5/6e/5b71578799b72e5bdcef206a214c3ce860d999d579a3b56e74a6c8989ee2/tiktoken-0.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:45927a71ab6643dfd3ef57d515a5db3d199137adf551f66453be098502838b0f", size = 884282, upload-time = "2025-08-08T23:57:50.759Z" }, -] - -[[package]] -name = "torch" -version = "2.6.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "filelock" }, - { name = "fsspec" }, - { name = "jinja2" }, - { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "setuptools" }, - { name = "sympy" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "typing-extensions" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9", size = 766624563, upload-time = "2025-01-29T16:23:19.084Z" }, - { url = "https://files.pythonhosted.org/packages/01/d6/455ab3fbb2c61c71c8842753b566012e1ed111e7a4c82e0e1c20d0c76b62/torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb", size = 95607867, upload-time = "2025-01-29T16:25:55.649Z" }, - { url = "https://files.pythonhosted.org/packages/18/cf/ae99bd066571656185be0d88ee70abc58467b76f2f7c8bfeb48735a71fe6/torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239", size = 204120469, upload-time = "2025-01-29T16:24:01.821Z" }, - { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538, upload-time = "2025-01-29T16:24:18.976Z" }, -] - -[[package]] -name = "torchcodec" -version = "0.6.0" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/b3/11326a0e7a3c803a95975cfce4ac88fa4ea1a0d432bb876081046c5a5554/torchcodec-0.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fba260145a239b5afe13336e3a5bc1b089c9c31a073e9a7c2026d4cbd853fdd9", size = 3482584, upload-time = "2025-08-07T08:51:32.535Z" }, - { url = "https://files.pythonhosted.org/packages/a7/d1/3f90561df013f6a015ef19de22726b64073fee405f53d3c4b8255ab05a67/torchcodec-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:fdef91a17fb1f1a159ce23710324a9a4e6d6a885275de73700f94a9ad562c6b2", size = 1370954, upload-time = "2025-08-07T08:51:15.021Z" }, -] - -[[package]] -name = "torchvision" -version = "0.21.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, - { name = "pillow" }, - { name = "torch" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/52/5b/76ca113a853b19c7b1da761f8a72cb6429b3bd0bf932537d8df4657f47c3/torchvision-0.21.0-1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:ffa2a16499508fe6798323e455f312c7c55f2a88901c9a7c0fb1efa86cf7e327", size = 2329878, upload-time = "2025-03-18T17:25:50.039Z" }, - { url = "https://files.pythonhosted.org/packages/6e/1b/28f527b22d5e8800184d0bc847f801ae92c7573a8c15979d92b7091c0751/torchvision-0.21.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:97a5814a93c793aaf0179cfc7f916024f4b63218929aee977b645633d074a49f", size = 1784140, upload-time = "2025-01-29T16:28:44.694Z" }, - { url = "https://files.pythonhosted.org/packages/36/63/0722e153fd27d64d5b0af45b5c8cb0e80b35a68cf0130303bc9a8bb095c7/torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b578bcad8a4083b40d34f689b19ca9f7c63e511758d806510ea03c29ac568f7b", size = 7238673, upload-time = "2025-01-29T16:28:27.631Z" }, - { url = "https://files.pythonhosted.org/packages/bb/ea/03541ed901cdc30b934f897060d09bbf7a98466a08ad1680320f9ce0cbe0/torchvision-0.21.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5083a5b1fec2351bf5ea9900a741d54086db75baec4b1d21e39451e00977f1b1", size = 14701186, upload-time = "2025-01-29T16:28:16.491Z" }, - { url = "https://files.pythonhosted.org/packages/4c/6a/c7752603060d076dfed95135b78b047dc71792630cbcb022e3693d6f32ef/torchvision-0.21.0-cp312-cp312-win_amd64.whl", hash = "sha256:6eb75d41e3bbfc2f7642d0abba9383cc9ae6c5a4ca8d6b00628c225e1eaa63b3", size = 1560520, upload-time = "2025-01-29T16:28:42.122Z" }, -] - -[[package]] -name = "tqdm" -version = "4.67.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, -] - -[[package]] -name = "triton" -version = "3.2.0" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c", size = 253159365, upload-time = "2025-01-22T19:13:24.648Z" }, -] - -[[package]] -name = "typing-extensions" -version = "4.15.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, -] - -[[package]] -name = "typing-inspection" -version = "0.4.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f8/b1/0c11f5058406b3af7609f121aaa6b609744687f1d158b3c3a5bf4cc94238/typing_inspection-0.4.1.tar.gz", hash = "sha256:6ae134cc0203c33377d43188d4064e9b357dba58cff3185f22924610e70a9d28", size = 75726, upload-time = "2025-05-21T18:55:23.885Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" }, -] - -[[package]] -name = "tzdata" -version = "2025.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, -] - -[[package]] -name = "urllib3" -version = "2.5.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, -] - -[[package]] -name = "urwid" -version = "3.0.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "wcwidth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/46/2d/71550379ed6b34968e14f73b0cf8574dee160acb6b820a066ab238ef2d4f/urwid-3.0.2.tar.gz", hash = "sha256:e7cb70ba1e7ff45779a5a57e43c57581ee7de6ceefb56c432491a4a6ce81eb78", size = 855353, upload-time = "2025-05-07T10:48:51.381Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b5/ee/2956918f14fd6e4310f7200108b53917f4da713d74c7ccd0d91a2e3a4f18/urwid-3.0.2-py3-none-any.whl", hash = "sha256:94ec1448d0178c881c01845c2b478cdc89f7b71bb65349466dbc99da1965eaac", size = 295994, upload-time = "2025-05-07T10:48:49.173Z" }, -] - -[[package]] -name = "urwid-readline" -version = "0.15.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "urwid" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ad/70/be318554495555eba7d8ff6e489f6f74ddb225b24086ba4af62a82e723fd/urwid_readline-0.15.1.tar.gz", hash = "sha256:9301444b86d58f7d26388506b704f142cefd193888488b4070d3a0fdfcfc0f84", size = 9007, upload-time = "2024-09-22T17:51:55.144Z" } - -[[package]] -name = "wandb" -version = "0.21.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "gitpython" }, - { name = "packaging" }, - { name = "platformdirs" }, - { name = "protobuf" }, - { name = "pydantic" }, - { name = "pyyaml" }, - { name = "requests" }, - { name = "sentry-sdk" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/2f/84/af6ccdf95e56f15aceb360e437fbfcca3dc91ad8ca335fe482083e29f7a5/wandb-0.21.3.tar.gz", hash = "sha256:031e24e2aad0ce735dfdcc74baf2f2c12c106f500ed24798de6ef9b9e63bb432", size = 40146972, upload-time = "2025-08-30T18:21:55.138Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/e8/b5bfbbc7f76c11fd0665b92be8a38c6a83b27f353552233b9959b21be488/wandb-0.21.3-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:f85bac45b4482742ec9ff190af38eb00a877ddeb4875475e7e487dc19300ff03", size = 18820209, upload-time = "2025-08-30T18:21:33.47Z" }, - { url = "https://files.pythonhosted.org/packages/59/a3/03f0fcde49609df1cb3a382fb5053f601b88da448bcd415ed7f75272eee7/wandb-0.21.3-py3-none-macosx_12_0_arm64.whl", hash = "sha256:8a2b3ba419b91d47edead2755f04cef54f9e3c4496ee0c9854c3cfeff4216dd3", size = 18310636, upload-time = "2025-08-30T18:21:37.405Z" }, - { url = "https://files.pythonhosted.org/packages/1d/c3/d6048db30ff2e3c67089ba0e94878572fd26137b146f8e3b27bbdf428b31/wandb-0.21.3-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:35a1972881f3b85755befab004118234593792a9f05e07fd6345780172f4420e", size = 19053277, upload-time = "2025-08-30T18:21:39.389Z" }, - { url = "https://files.pythonhosted.org/packages/ea/7f/805c3d2fa9e3b8b6bf2bc534887c9ed97bdf22007ca8ba59424a1c8bb360/wandb-0.21.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d9cf8588cb090a2a41f589037fda72c57c9e23edfbd2ad829e575f1305d942c", size = 18130850, upload-time = "2025-08-30T18:21:41.573Z" }, - { url = "https://files.pythonhosted.org/packages/5b/af/a3252e5afac98a036f83c65ec92cadf6677ccdaacbbb2151da29f694d136/wandb-0.21.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff24b6b8e0f9da840b6bd5c7f60b0a5507bd998db40c9c2d476f9a340bec8ed", size = 19570305, upload-time = "2025-08-30T18:21:43.811Z" }, - { url = "https://files.pythonhosted.org/packages/4d/f9/4404b5a24bfd4ba027c19d30152b0fc7ebca8c49b202dee6ecb7f316082c/wandb-0.21.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4975dec19e2b343e23ed6e60f7e1290120553719f82e87a22205bede758416ad", size = 18135806, upload-time = "2025-08-30T18:21:46.211Z" }, - { url = "https://files.pythonhosted.org/packages/ff/32/9580f42899e54f3d0b4ea619b6f6a54980a4e36fd0675d58c09f0a08d3f6/wandb-0.21.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:514a0aad40ecc0bdb757b1dc86e4ac98f61d2d760445b6e1f555291562320f2d", size = 19646760, upload-time = "2025-08-30T18:21:48.768Z" }, - { url = "https://files.pythonhosted.org/packages/75/d3/faa6ddb792a158c154fb704b25c96d0478e71eabf96e3f17529fb23b6894/wandb-0.21.3-py3-none-win32.whl", hash = "sha256:45aa3d8ad53c6ee06f37490d7a329ed7d0f5ca4dbd5d05bb0c01d5da22f14691", size = 18709408, upload-time = "2025-08-30T18:21:50.859Z" }, - { url = "https://files.pythonhosted.org/packages/d8/2d/7ef56e25f78786e59fefd9b19867c325f9686317d9f7b93b5cb340360a3e/wandb-0.21.3-py3-none-win_amd64.whl", hash = "sha256:56d5a5697766f552a9933d8c6a564202194768eb0389bd5f9fe9a99cd4cee41e", size = 18709411, upload-time = "2025-08-30T18:21:52.874Z" }, -] - -[[package]] -name = "wcwidth" -version = "0.2.13" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301, upload-time = "2024-01-06T02:10:57.829Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166, upload-time = "2024-01-06T02:10:55.763Z" }, -]