From 53e6e9e827f94ab74eb87bdb9155a3e6d0fa0e27 Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Fri, 27 Jun 2025 10:18:05 +0000 Subject: [PATCH 01/12] post training branch --- .gitignore | 3 +- POST_TRAINING_README.md | 231 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 233 insertions(+), 1 deletion(-) create mode 100644 POST_TRAINING_README.md diff --git a/.gitignore b/.gitignore index 7734c465..305fea71 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,8 @@ __pycache__/ # C extensions *.so - +*.mp4 +*.json* # Distribution / packaging .Python build/ diff --git a/POST_TRAINING_README.md b/POST_TRAINING_README.md new file mode 100644 index 00000000..90292166 --- /dev/null +++ b/POST_TRAINING_README.md @@ -0,0 +1,231 @@ +# Cosmos-Predict2 Post-Training Configuration Guide + +This guide provides a comprehensive overview of the post-training methodology, configuration system, and current settings for Cosmos-Predict2 models. + +## Overview + +Cosmos-Predict2 supports post-training of Video2World models using a flexible configuration system built on Hydra/OmegaConf. The training command used is: + +```bash +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} +``` + +## Configuration System Architecture + +The configuration system is hierarchically organized to support flexible experimentation: + +``` +cosmos_predict2/configs/base/ +├── config.py # Main configuration class and registration +├── defaults/ # Default configuration groups +│ ├── callbacks.py # Training callbacks (logging, monitoring) +│ ├── checkpoint.py # Checkpoint saving/loading settings +│ ├── data.py # Dataset and dataloader configurations +│ ├── ema.py # Exponential Moving Average settings +│ ├── model.py # Model architecture configurations +│ ├── optimizer.py # Optimizer configurations +│ └── scheduler.py # Learning rate scheduler configurations +└── experiment/ # Experiment-specific configurations + ├── groot.py # GR00T robot experiments + ├── cosmos_nemo_assets.py # Cosmos Nemo Asset experiments + └── agibot_head_center_fisheye_color.py # AgiBOT experiments +``` + +## Current Default Settings + +### Base Configuration (`config.py`) + +- **Trainer Type**: `ImaginaireTrainer` +- **Max Iterations**: 400,000 +- **Logging Frequency**: Every 10 iterations +- **Validation Frequency**: Every 100 iterations (disabled by default) +- **Project Name**: `cosmos_predict2` +- **Job Group**: `debug` + +### Model Configurations (`defaults/model.py`) + +#### Cosmos-Predict2-2B-Video2World (`predict2_video2world_fsdp_2b`) +- **Model Path**: `checkpoints/nvidia/Cosmos-Predict2-2B-Video2World/model-720p-16fps.pt` +- **Parallelism**: FSDP (Fully Sharded Data Parallel) +- **FSDP Shard Size**: 8 +- **High Sigma Ratio**: 0.05 + +#### Cosmos-Predict2-14B-Video2World (`predict2_video2world_fsdp_14b`) +- **Model Path**: `checkpoints/nvidia/Cosmos-Predict2-14B-Video2World/model-720p-16fps.pt` +- **Parallelism**: FSDP (Fully Sharded Data Parallel) +- **FSDP Shard Size**: 32 +- **High Sigma Ratio**: 0.05 + +### Optimizer Configuration (`defaults/optimizer.py`) + +#### FusedAdamW (`fusedadamw`) +- **Learning Rate**: 1e-4 (default, often overridden) +- **Weight Decay**: 0.1 +- **Betas**: [0.9, 0.99] +- **Epsilon**: 1e-8 +- **Master Weights**: Enabled +- **Capturable**: Enabled + +### Scheduler Configuration (`defaults/scheduler.py`) + +#### Lambda Linear Scheduler (`lambdalinear`) +- **Warm-up Steps**: [1000] +- **Cycle Lengths**: [10000000000000] (effectively 
infinite) +- **F Start**: [1.0e-6] +- **F Max**: [1.0] +- **F Min**: [1.0] + +#### Constant Scheduler (`constant`) +- **Rate**: 1.0 (constant throughout training) + +## Available Experiments + +### GR00T Experiments (`experiment/groot.py`) + +#### 2B Model GR00T Training +```bash +EXP=predict2_video2world_training_2b_groot_gr1_480 +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} +``` + +**Configuration:** +- **Learning Rate**: 2^(-14.5) ≈ 3.05e-5 +- **Scheduler**: Lambda Linear (f_max=0.2, f_min=0.1, warm_up=1000, cycle=100000) +- **Context Parallel Size**: 1 +- **Save Frequency**: Every 200 iterations +- **Video Size**: 432x768, 93 frames +- **Batch Size**: 1 per GPU + +#### 14B Model GR00T Training +```bash +EXP=predict2_video2world_training_14b_groot_gr1_480 +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} +``` + +**Configuration:** +- **Learning Rate**: 2^(-14.5) ≈ 3.05e-5 +- **Context Parallel Size**: 4 +- **FSDP Shard Size**: 32 +- **Other settings**: Same as 2B variant + +### Other Available Experiments +- **Cosmos Nemo Assets**: Various experiments with asset-based training +- **AgiBOT**: Fisheye camera training configurations + +## Key Training Parameters + +### Common Overrides in Experiments + +1. **Learning Rate**: Typically set to 2^(-14.5) ≈ 3.05e-5 +2. **Context Parallelism**: + - 2B models: 1-2 GPUs + - 14B models: 4 GPUs +3. **Guardrail**: Disabled during training (`guardrail_config.enabled=False`) +4. **EMA**: Often enabled for better model stability +5. **Validation**: Usually set to mock/disabled + +### Parallelism Strategy + +- **Data Parallelism**: Distributed across 8 GPUs (`--nproc_per_node=8`) +- **Model Parallelism**: FSDP with different shard sizes + - 2B: shard_size=8 + - 14B: shard_size=32 +- **Context Parallelism**: Used for longer sequences + - 2B: context_parallel_size=1-2 + - 14B: context_parallel_size=4 + +## Usage Examples + +### Basic Post-Training Command +```bash +# Set experiment name +export EXP=your_experiment_name + +# Run training +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} +``` + +### With LoRA Training +```bash +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} model.config.train_architecture=lora +``` + +### Environment Variables +```bash +# Disable NVTE fused attention if needed +export NVTE_FUSED_ATTN=0 + +# Set CUDA home +export CUDA_HOME=$CONDA_PREFIX +``` + +## Checkpoint Management + +### Save Locations +Checkpoints are saved to: `checkpoints/{PROJECT}/{GROUP}/{NAME}/checkpoints/` + +Example structure: +``` +checkpoints/posttraining/video2world/2b_groot_gr1_480/checkpoints/ +├── model/ +│ ├── iter_000000200.pt +│ ├── iter_000000400.pt +├── optim/ +├── scheduler/ +├── trainer/ +├── latest_checkpoint.txt +``` + +### Loading Checkpoints for Inference +```bash +python examples/video2world.py \ + --model_size 2B \ + --dit_path "checkpoints/posttraining/video2world/your_experiment/checkpoints/model/iter_001000.pt" \ + --prompt "Your descriptive prompt" \ + --input_path "path/to/input.jpg" \ + --save_path "path/to/output.mp4" +``` + +## Performance Considerations + +- **Memory Requirements**: + - 2B model: ~8 A100/H100 GPUs + - 14B model: ~8 A100/H100 GPUs with higher memory usage +- **Context 
Parallelism**: Used to handle longer video sequences +- **FSDP Sharding**: Enables training of large models across multiple GPUs + +## Custom Experiment Creation + +To create a new experiment: + +1. Define your dataset configuration +2. Create experiment config dictionary +3. Register with ConfigStore +4. Run with your experiment name + +Example template: +```python +your_custom_experiment = dict( + defaults=[ + {"override /model": "predict2_video2world_fsdp_2b"}, + {"override /optimizer": "fusedadamw"}, + {"override /scheduler": "lambdalinear"}, + {"override /ckpt_type": "standard"}, + {"override /data_val": "mock"}, + "_self_", + ], + job=dict( + project="your_project", + group="your_group", + name="your_experiment_name", + ), + # ... other configurations +) +``` + +## Troubleshooting + +- **OOM Errors**: Reduce batch size or increase context_parallel_size +- **Slow Training**: Check iter_speed callback hit_thres settings +- **Checkpoint Issues**: Verify checkpoint save paths and permissions +- **NVTE Issues**: Set `NVTE_FUSED_ATTN=0` environment variable \ No newline at end of file From 3ae0154d4fac66193452b9464a8ffb024f885496 Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Fri, 27 Jun 2025 13:58:07 +0000 Subject: [PATCH 02/12] Metropolis specific files --- .../configs/base/experiment/metropolis.py | 164 ++++++++++++++++++ extract_captions.sh | 25 +++ scripts/extract_captions.py | 39 +++++ setup_training_env.sh | 20 +++ 4 files changed, 248 insertions(+) create mode 100644 cosmos_predict2/configs/base/experiment/metropolis.py create mode 100644 extract_captions.sh create mode 100644 scripts/extract_captions.py create mode 100644 setup_training_env.sh diff --git a/cosmos_predict2/configs/base/experiment/metropolis.py b/cosmos_predict2/configs/base/experiment/metropolis.py new file mode 100644 index 00000000..c622c30e --- /dev/null +++ b/cosmos_predict2/configs/base/experiment/metropolis.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
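# Module overview (summary of what this config file provides, drawn from the
# accompanying README_metropolis.md): it builds the Metropolis video dataset and
# dataloader — expecting datasets/metropolis/train/{videos,metas,t5_xxl} — and
# registers two post-training experiments with Hydra's ConfigStore:
#   - predict2_video2world_training_2b_metropolis  (2B model, context_parallel_size=2)
#   - predict2_video2world_training_14b_metropolis (14B model, context_parallel_size=4)
# Either experiment is launched with:
#   torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train \
#     --config=cosmos_predict2/configs/base/config.py -- experiment=<experiment_name>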
+ +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict2.data.dataset_video import Dataset +from imaginaire.lazy_config import LazyCall as L + + +def get_sampler(dataset) -> DistributedSampler: + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +# Metropolis dataset example +example_video_dataset_metropolis = L(Dataset)( + dataset_dir="datasets/metropolis/train", + num_frames=93, # Standard frame count, will be cropped/padded as needed + video_size=(740, 1280), # Height, Width format - matches your 1280x740 videos +) + +dataloader_train_metropolis = L(DataLoader)( + dataset=example_video_dataset_metropolis, + sampler=L(get_sampler)(dataset=example_video_dataset_metropolis), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +# torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=predict2_video2world_training_2b_metropolis +predict2_video2world_training_2b_metropolis = dict( + defaults=[ + {"override /model": "predict2_video2world_fsdp_2b"}, + {"override /optimizer": "fusedadamw"}, + {"override /scheduler": "lambdalinear"}, + {"override /ckpt_type": "standard"}, + {"override /data_val": "mock"}, + "_self_", + ], + job=dict( + project="posttraining", + group="video2world", + name="2b_metropolis", + ), + model=dict( + config=dict( + pipe_config=dict( + ema=dict(enabled=True), # Enable EMA for better stability + guardrail_config=dict(enabled=False), # Disable guardrail during training + ), + ) + ), + model_parallel=dict( + context_parallel_size=2, # Context parallelism for handling video sequences + ), + dataloader_train=dataloader_train_metropolis, + trainer=dict( + distributed_parallelism="fsdp", + callbacks=dict( + iter_speed=dict(hit_thres=10), # Report speed every 10 iterations + ), + max_iter=2000, # Increased iterations for larger dataset (1,191 samples) + ), + checkpoint=dict( + save_iter=500, # Save checkpoint every 500 iterations + ), + optimizer=dict( + lr=2 ** (-14.5), # Learning rate: ~3.05e-5 + ), + scheduler=dict( + warm_up_steps=[2_000], # Warm up steps + cycle_lengths=[400_000], # Cycle length + f_max=[0.6], # Maximum factor + f_min=[0.3], # Minimum factor + ), +) + +# Optional: 14B model configuration (requires more GPU memory) +predict2_video2world_training_14b_metropolis = dict( + defaults=[ + {"override /model": "predict2_video2world_fsdp_14b"}, + {"override /optimizer": "fusedadamw"}, + {"override /scheduler": "lambdalinear"}, + {"override /ckpt_type": "standard"}, + {"override /data_val": "mock"}, + "_self_", + ], + job=dict( + project="posttraining", + group="video2world", + name="14b_metropolis", + ), + model=dict( + config=dict( + pipe_config=dict( + ema=dict(enabled=True), + guardrail_config=dict(enabled=False), + ), + ) + ), + model_parallel=dict( + context_parallel_size=4, # Higher context parallelism for 14B model + ), + dataloader_train=dataloader_train_metropolis, + trainer=dict( + distributed_parallelism="fsdp", + callbacks=dict( + iter_speed=dict(hit_thres=10), + ), + max_iter=2000, + ), + checkpoint=dict( + save_iter=500, + ), + optimizer=dict( + lr=2 ** (-14.5), + weight_decay=0.2, + ), + scheduler=dict( + warm_up_steps=[2_000], + cycle_lengths=[40_000], + f_max=[0.25], + 
f_min=[0.1], + ), +) + +# Register configurations +for _item in [ + # 2b, metropolis + predict2_video2world_training_2b_metropolis, + # 14b, metropolis (optional) + predict2_video2world_training_14b_metropolis, +]: + # Get the experiment name from the global variable + experiment_name = [name.lower() for name, value in globals().items() if value is _item][0] + + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) \ No newline at end of file diff --git a/extract_captions.sh b/extract_captions.sh new file mode 100644 index 00000000..5c5df90b --- /dev/null +++ b/extract_captions.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Create output directory +mkdir -p /home/arslana/codes/github/cosmos-predict2/datasets/metropolis/train/metas + +# Source directory containing JSON files +SOURCE_DIR="/home/arslana/codes/github/cosmos-predict2/assets/watermarked/metas/v0" + +# Output directory for text files +OUTPUT_DIR="/home/arslana/codes/github/cosmos-predict2/datasets/metropolis/train/metas" + +# Process each JSON file +for json_file in "$SOURCE_DIR"/*.json; do + if [ -f "$json_file" ]; then + # Get base filename without extension + basename=$(basename "$json_file" .json) + + # Extract qwen_caption using jq and save to text file + jq -r '.windows[0].qwen_caption' "$json_file" > "$OUTPUT_DIR/${basename}.txt" + + echo "Processed: $basename" + fi +done + +echo "All captions extracted successfully!" \ No newline at end of file diff --git a/scripts/extract_captions.py b/scripts/extract_captions.py new file mode 100644 index 00000000..cb72ba55 --- /dev/null +++ b/scripts/extract_captions.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +import json +import os +from pathlib import Path + +# Define directories +source_dir = "/home/arslana/codes/github/cosmos-predict2/assets/watermarked/metas/v0" +output_dir = "/home/arslana/codes/github/cosmos-predict2/datasets/metropolis/train/metas" + +# Create output directory +os.makedirs(output_dir, exist_ok=True) + +# Process each JSON file +for json_file in Path(source_dir).glob("*.json"): + try: + # Read and parse JSON file + with open(json_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + # Extract qwen_caption from first window + if 'windows' in data and len(data['windows']) > 0: + caption = data['windows'][0].get('qwen_caption', '') + + # Create output text file + output_file = Path(output_dir) / f"{json_file.stem}.txt" + + # Write caption to text file + with open(output_file, 'w', encoding='utf-8') as f: + f.write(caption) + + print(f"Processed: {json_file.stem}") + else: + print(f"Warning: No windows found in {json_file.name}") + + except Exception as e: + print(f"Error processing {json_file.name}: {e}") + +print("All captions extracted successfully!") \ No newline at end of file diff --git a/setup_training_env.sh b/setup_training_env.sh new file mode 100644 index 00000000..674a290d --- /dev/null +++ b/setup_training_env.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Set up environment variables for Cosmos-Predict2 training +# Recommended settings to avoid common issues + +# CUDA settings +export CUDA_HOME=$CONDA_PREFIX + +# NCCL settings (to avoid timeout issues) +export NCCL_TIMEOUT=3600 # 1 hour timeout +export NCCL_ASYNC_ERROR_HANDLING=1 +export NCCL_DEBUG=WARN # Set to INFO for more detailed logs + +# Optional: Disable NVTE fused attention if needed +# export NVTE_FUSED_ATTN=0 + +# PyTorch settings +export PYTHONPATH=$(pwd):$PYTHONPATH + +echo "Environment variables set for Cosmos-Predict2 training" \ No newline at end of 
file From 5a6ff5c20db195f59da43f6be702be2df22968ed Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Fri, 27 Jun 2025 14:03:44 +0000 Subject: [PATCH 03/12] Metropolis README --- README_metropolis.md | 331 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 331 insertions(+) create mode 100644 README_metropolis.md diff --git a/README_metropolis.md b/README_metropolis.md new file mode 100644 index 00000000..c8f2c578 --- /dev/null +++ b/README_metropolis.md @@ -0,0 +1,331 @@ +# Cosmos-Predict2 Post-Training: Metropolis Dataset + +This README provides complete instructions for post-training Cosmos-Predict2 Video2World models on the Metropolis traffic dataset. + +## Dataset Overview + +- **Dataset Name**: Metropolis Traffic Dataset +- **Location**: `/home/arslana/codes/github/cosmos-predict2/datasets/metropolis/train` +- **Size**: 1,191 video-text pairs +- **Video Format**: 1280×740 resolution, ~8.2 seconds duration, ~246 frames +- **Text Format**: Traffic scene descriptions extracted from qwen_caption + +## Dataset Structure + +``` +datasets/metropolis/train/ +├── metas/ # Text descriptions (.txt files) +├── videos/ # Video files (.mp4 files) +└── t5_xxl/ # Pre-computed T5-XXL embeddings (.pickle files) +``` + +## Configuration + +### Model Configuration +- **Model**: Cosmos-Predict2-2B-Video2World (FSDP) +- **Input Frames**: 93 (cropped/padded from original ~246 frames) +- **Video Size**: 740×1280 (Height×Width format) +- **Batch Size**: 1 per GPU +- **Context Parallelism**: 2 (2B model) / 4 (14B model) + +### Training Configuration +- **Learning Rate**: 3.05e-5 (2^-14.5) +- **Max Iterations**: 2,000 +- **Save Frequency**: Every 500 iterations +- **Optimizer**: FusedAdamW +- **Scheduler**: Lambda Linear +- **EMA**: Enabled for stability +- **Guardrails**: Disabled during training + +## Quick Start + +### Step 1: Set Up Environment +```bash +# Make environment setup script executable +chmod +x setup_training_env.sh + +# Source environment variables +source setup_training_env.sh +``` + +### Step 2: Run Training (2B Model) +```bash +# Set experiment name +export EXP=predict2_video2world_training_2b_metropolis + +# Run training +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} +``` + +### Step 3: Monitor Training Progress +Watch for: +- Iteration speed (~20-30 seconds per iteration) +- Loss decreasing over time +- No NCCL timeout errors +- GPU memory usage + +## Training Options + +### Option 1: Standard Training (2B Model) +```bash +source setup_training_env.sh +export EXP=predict2_video2world_training_2b_metropolis +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} +``` + +### Option 2: LoRA Training (Parameter-Efficient) +```bash +source setup_training_env.sh +export EXP=predict2_video2world_training_2b_metropolis +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} model.config.train_architecture=lora +``` + +### Option 3: 14B Model (Requires More GPU Memory) +```bash +source setup_training_env.sh +export EXP=predict2_video2world_training_14b_metropolis +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} +``` + +## File Structure After Training + +### Configuration Files +- `cosmos_predict2/configs/base/experiment/metropolis.py` - Training configuration +- 
`setup_training_env.sh` - Environment setup script + +### Checkpoint Structure +``` +checkpoints/posttraining/video2world/2b_metropolis/checkpoints/ +├── model/ +│ ├── iter_000000500.pt # Checkpoint at 500 iterations +│ ├── iter_000001000.pt # Checkpoint at 1000 iterations +│ ├── iter_000001500.pt # Checkpoint at 1500 iterations +│ ├── iter_000002000.pt # Final checkpoint +├── optim/ # Optimizer states +├── scheduler/ # Scheduler states +├── trainer/ # Trainer states +└── latest_checkpoint.txt # Points to latest checkpoint +``` + +## Training Parameters Details + +### Dataset Configuration +```python +example_video_dataset_metropolis = L(Dataset)( + dataset_dir="datasets/metropolis/train", + num_frames=93, # Standard frame count + video_size=(740, 1280), # Height, Width format +) +``` + +### DataLoader Configuration +```python +dataloader_train_metropolis = L(DataLoader)( + dataset=example_video_dataset_metropolis, + sampler=L(get_sampler)(dataset=example_video_dataset_metropolis), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +``` + +### Training Configuration +```python +trainer=dict( + distributed_parallelism="fsdp", + max_iter=2000, # Total training iterations + callbacks=dict( + iter_speed=dict(hit_thres=10), # Report every 10 iterations + ), +) +``` + +### Optimizer Configuration +```python +optimizer=dict( + lr=2 ** (-14.5), # Learning rate: ~3.05e-5 +) +``` + +### Scheduler Configuration +```python +scheduler=dict( + warm_up_steps=[2_000], # Warm-up steps + cycle_lengths=[400_000], # Cycle length + f_max=[0.6], # Maximum factor + f_min=[0.3], # Minimum factor +) +``` + +## Environment Variables + +The `setup_training_env.sh` script sets these important variables: + +```bash +export CUDA_HOME=$CONDA_PREFIX # CUDA path +export NCCL_TIMEOUT=3600 # 1 hour timeout +export NCCL_ASYNC_ERROR_HANDLING=1 # Better error handling +export NCCL_DEBUG=WARN # NCCL debug level +export PYTHONPATH=$(pwd):$PYTHONPATH # Python path +``` + +## Inference After Training + +### Test Your Trained Model +```bash +# Test with a sample from your dataset +python examples/video2world.py \ + --model_size 2B \ + --dit_path "checkpoints/posttraining/video2world/2b_metropolis/checkpoints/model/iter_000002000.pt" \ + --prompt "A traffic scene with vehicles navigating through an urban intersection" \ + --input_path "datasets/metropolis/train/videos/000f9c15-06af-58de-833b-606042201f53.mp4" \ + --save_path "results/metropolis_test_output.mp4" +``` + +### Custom Inference +```bash +# Use your own input image/video and prompt +python examples/video2world.py \ + --model_size 2B \ + --dit_path "checkpoints/posttraining/video2world/2b_metropolis/checkpoints/model/iter_000002000.pt" \ + --prompt "Your custom traffic scene description" \ + --input_path "path/to/your/input.jpg" \ + --save_path "path/to/your/output.mp4" +``` + +## Performance Expectations + +### Hardware Requirements +- **2B Model**: 8× A100/H100 GPUs (recommended) +- **14B Model**: 8× A100/H100 GPUs with higher memory +- **Memory**: ~50-70GB GPU memory per GPU during training +- **Storage**: Ensure sufficient space for checkpoints (~2-5GB per checkpoint) + +### Training Time Estimates +- **2B Model**: ~11-17 hours for 2,000 iterations (based on 20-30 sec/iter) +- **14B Model**: ~17-28 hours for 2,000 iterations +- **Checkpointing**: Additional time for saving every 500 iterations + +### Expected Metrics +- **Iteration Speed**: 20-30 seconds per iteration +- **Loss**: Should decrease from ~0.5-1.0 to ~0.1-0.3 +- **GPU Utilization**: 
~90-100% +- **Memory Usage**: ~50-70GB per GPU + +## Troubleshooting + +### Common Issues and Solutions + +#### 1. NCCL Timeout Errors +```bash +# Already handled in setup_training_env.sh, but if issues persist: +export NCCL_TIMEOUT=7200 # Increase to 2 hours +export NCCL_DEBUG=INFO # Get more detailed logs +``` + +#### 2. Out of Memory (OOM) Errors +```bash +# Use LoRA training instead +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} model.config.train_architecture=lora + +# Or reduce context parallelism in config: +# context_parallel_size=1 # Instead of 2 +``` + +#### 3. Configuration Not Found +```bash +# Verify the metropolis.py file exists: +ls -la cosmos_predict2/configs/base/experiment/metropolis.py + +# Check for syntax errors: +python -m py_compile cosmos_predict2/configs/base/experiment/metropolis.py +``` + +#### 4. Dataset Loading Issues +```bash +# Verify dataset structure: +ls -la datasets/metropolis/train/ +ls datasets/metropolis/train/videos/ | wc -l # Should show 1191 +ls datasets/metropolis/train/metas/ | wc -l # Should show 1191 +ls datasets/metropolis/train/t5_xxl/ | wc -l # Should show 1191 +``` + +#### 5. Slow Training +```bash +# Check GPU utilization: +nvidia-smi -l 1 + +# Check I/O performance: +iotop -ao + +# Increase number of workers if CPU/I/O bound: +# num_workers=16 # Instead of 8 in config +``` + +## Monitoring Training + +### Key Logs to Watch +```bash +# Monitor training progress +tail -f + +# Key indicators: +# - "iter_speed X.XX seconds per iteration" +# - "Loss: X.XXXX" (should decrease) +# - "DeviceMonitor Stats" (GPU usage) +# - No NCCL timeout errors +``` + +### Training Success Indicators +- ✅ Consistent iteration speed (20-30 sec/iter) +- ✅ Decreasing loss values +- ✅ Regular checkpoint saves +- ✅ No NCCL communication errors +- ✅ High GPU utilization (>90%) + +## Customization Options + +### Modify Training Parameters +Edit `cosmos_predict2/configs/base/experiment/metropolis.py`: + +```python +# Change training iterations +max_iter=5000, # Instead of 2000 + +# Change learning rate +lr=2 ** (-15), # Lower learning rate + +# Change save frequency +save_iter=1000, # Save every 1000 iterations + +# Change batch size (if memory allows) +batch_size=2, # Instead of 1 +``` + +### Add Custom Callbacks +```python +# In the training config +callbacks=dict( + iter_speed=dict(hit_thres=10), + # Add custom callbacks here +), +``` + +## Next Steps + +1. **Monitor Training**: Watch logs for progress and issues +2. **Evaluate Results**: Test inference after training completes +3. **Fine-tune Parameters**: Adjust learning rate, iterations based on results +4. **Scale Up**: Try 14B model if 2B results are promising +5. **Production**: Deploy trained model for traffic scene generation + +## Support + +For issues specific to this Metropolis dataset setup: +1. Check this README troubleshooting section +2. Verify dataset integrity and structure +3. Check environment variables and dependencies +4. Monitor GPU resources and NCCL communications + +For general Cosmos-Predict2 issues, refer to the main documentation in `/documentations/`. 
\ No newline at end of file From c8a9785ce1e1560fa49fd24c456ad283aa0f922d Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Tue, 1 Jul 2025 11:27:39 +0000 Subject: [PATCH 04/12] path updates --- cosmos_predict2/configs/base/experiment/metropolis.py | 2 +- setup_training_env.sh | 0 2 files changed, 1 insertion(+), 1 deletion(-) mode change 100644 => 100755 setup_training_env.sh diff --git a/cosmos_predict2/configs/base/experiment/metropolis.py b/cosmos_predict2/configs/base/experiment/metropolis.py index c622c30e..9f176688 100644 --- a/cosmos_predict2/configs/base/experiment/metropolis.py +++ b/cosmos_predict2/configs/base/experiment/metropolis.py @@ -35,7 +35,7 @@ def get_sampler(dataset) -> DistributedSampler: # Metropolis dataset example example_video_dataset_metropolis = L(Dataset)( - dataset_dir="datasets/metropolis/train", + dataset_dir="datasets/watermark/train", num_frames=93, # Standard frame count, will be cropped/padded as needed video_size=(740, 1280), # Height, Width format - matches your 1280x740 videos ) diff --git a/setup_training_env.sh b/setup_training_env.sh old mode 100644 new mode 100755 From bfd4c070048c3c49187f016d8d5da439c197a479 Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Tue, 1 Jul 2025 13:11:55 +0000 Subject: [PATCH 05/12] wandb and 14b model parameters --- .gitignore | 2 + README_metropolis.md | 10 ++++- cosmos_predict2/callbacks/loss_log.py | 54 +++++++++++++++++++++++++++ scripts/train.py | 22 +++++++++++ 4 files changed, 87 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 305fea71..0cd30444 100644 --- a/.gitignore +++ b/.gitignore @@ -79,6 +79,8 @@ target/ # Jupyter Notebook .ipynb_checkpoints +/wandb/* + # IPython profile_default/ ipython_config.py diff --git a/README_metropolis.md b/README_metropolis.md index c8f2c578..4f4c5679 100644 --- a/README_metropolis.md +++ b/README_metropolis.md @@ -83,8 +83,16 @@ torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos ### Option 3: 14B Model (Requires More GPU Memory) ```bash source setup_training_env.sh +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True export EXP=predict2_video2world_training_14b_metropolis -torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} +export WANDB_ENTITY=nvidia-dir +export WANDB_API_KEY=d06eb702b1a5d31496a0ab288f8c84e2ae55510d +export WANDB_PROJECT="cosmos_predict2" +export WANDB_RUN_NAME="14b_metropolis_reduced_res" +export WANDB_MODE="online" + +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True && export EXP=predict2_video2world_training_14b_metropolis && torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} dataloader_train.dataset.video_size=[480,640] + ``` ## File Structure After Training diff --git a/cosmos_predict2/callbacks/loss_log.py b/cosmos_predict2/callbacks/loss_log.py index 4a4adbd9..3c7e6210 100644 --- a/cosmos_predict2/callbacks/loss_log.py +++ b/cosmos_predict2/callbacks/loss_log.py @@ -18,6 +18,7 @@ import torch import torch.distributed as dist +import wandb from imaginaire.model import ImaginaireModel from imaginaire.utils import distributed, log @@ -54,6 +55,7 @@ def __init__( self.name = self.__class__.__name__ self.train_video_log = _LossRecord() + self.val_video_log = _LossRecord() def on_before_backward( self, @@ -65,7 +67,13 @@ def on_before_backward( if iteration % (self.config.trainer.logging_iter * self.logging_iter_multipler) == 0 and 
distributed.is_rank0(): info = { "train_loss_step": loss.detach().item(), + "iteration": iteration, } + # Log to wandb if available + if wandb.run: + wandb.log(info, step=iteration) + # Also log to console + log.info(f"Iteration {iteration}: train_loss_step = {loss.detach().item():.6f}") def on_training_step_end( self, @@ -98,3 +106,49 @@ def on_training_step_end( info = {} if iter_count > 0: info[f"train@{self.logging_iter_multipler}/loss"] = loss + info[f"train@{self.logging_iter_multipler}/iter_count"] = iter_count + info["iteration"] = iteration + + # Log to wandb if available + if wandb.run: + wandb.log(info, step=iteration) + # Also log to console + log.info(f"Iteration {iteration}: train_loss_avg = {loss:.6f}, iter_count = {iter_count}") + + def on_validation_step_end( + self, + model: ImaginaireModel, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ): + """Log validation loss at each validation step.""" + if not (torch.isnan(loss) or torch.isinf(loss)): + _loss = output_batch.get("loss", loss).detach().mean(dim=0) + self.val_video_log.iter_count += 1 + self.val_video_log.loss += _loss + + def on_validation_end( + self, + model: ImaginaireModel, + iteration: int = 0, + ): + """Log averaged validation loss at the end of validation.""" + if distributed.is_rank0(): + world_size = dist.get_world_size() + loss, iter_count = self.val_video_log.get_stat() + iter_count *= world_size + + if iter_count > 0: + info = { + "val/loss": loss, + "val/iter_count": iter_count, + "iteration": iteration, + } + + # Log to wandb if available + if wandb.run: + wandb.log(info, step=iteration) + # Also log to console + log.info(f"Validation at iteration {iteration}: val_loss = {loss:.6f}, iter_count = {iter_count}") diff --git a/scripts/train.py b/scripts/train.py index 2b93d684..121c5472 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -18,6 +18,7 @@ import os from loguru import logger as logging +import wandb from imaginaire.config import Config, pretty_print_overrides from imaginaire.lazy_config import instantiate @@ -33,6 +34,22 @@ def launch(config: Config, args: argparse.Namespace) -> None: # check doesn't actually do anything. distributed.init() + # Initialize wandb for logging (only on rank 0) + if distributed.is_rank0(): + wandb.init( + project=config.job.project, + group=config.job.group, + name=config.job.name, + config={ + "max_iter": config.trainer.max_iter, + "logging_iter": config.trainer.logging_iter, + "validation_iter": config.trainer.validation_iter, + "learning_rate": getattr(config.optimizer, "lr", "unknown"), + "batch_size": getattr(config.dataloader_train, "batch_size", "unknown"), + } + ) + logging.info(f"Initialized wandb with project: {config.job.project}, group: {config.job.group}, name: {config.job.name}") + # Check that the config is valid config.validate() # Freeze the config so developers don't change it during training. 
@@ -49,6 +66,11 @@ def launch(config: Config, args: argparse.Namespace) -> None: dataloader_train, dataloader_val, ) + + # Clean up wandb + if distributed.is_rank0() and wandb.run: + wandb.finish() + logging.info("Finished wandb logging") if __name__ == "__main__": From 539520528511b34aade197eaafc73681e633aaab Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Thu, 3 Jul 2025 15:55:03 +0000 Subject: [PATCH 06/12] Metropolis batch inference --- customers/metropolis/run_batch_metropolis.py | 142 +++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 customers/metropolis/run_batch_metropolis.py diff --git a/customers/metropolis/run_batch_metropolis.py b/customers/metropolis/run_batch_metropolis.py new file mode 100644 index 00000000..b3077e12 --- /dev/null +++ b/customers/metropolis/run_batch_metropolis.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# run_batch_metropolis.py +""" +Batch-launch video2world jobs for the Metropolis benchmark. + +Automatically processes all frames from datasets/metropolis/benchmark/v2/frames_480_640 +and matches them with corresponding prompts from datasets/metropolis/benchmark/v2/prompts. + +Assumes: + • torchrun is in PATH + • NUM_GPUS is set below (or override via env) + • Frame files are .jpg and prompt files are .txt with matching base names +""" + +import os +import shlex +import subprocess +from pathlib import Path +import glob + +# ------------------------------------------------------------------ +# CONFIG – adjust only if your paths / flags change +# ------------------------------------------------------------------ +FRAMES_DIR = Path("datasets/metropolis/benchmark/v2/frames_480_640") +PROMPTS_DIR = Path("datasets/metropolis/benchmark/v2/prompts") +OUTPUT_DIR = Path("output") +NUM_GPUS = int(os.getenv("NUM_GPUS", 8)) +MODEL_SIZE = "14B" # Changed from 14B to 2B based on the checkpoint path in original +DIT_PATH = ( + "checkpoints/posttraining/video2world/14b_metropolis/" + "checkpoints/model/iter_000002000.pt" +) +TORCHRUN = "torchrun" # or full path if needed +SCRIPT = "examples/video2world.py" # entry script + +# ------------------------------------------------------------------ +def load_prompt(prompt_file: Path) -> str: + """Load prompt text from file, returning first non-empty line.""" + try: + with prompt_file.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: # Return first non-empty line + return line + return "" # Return empty string if no non-empty lines found + except Exception as e: + print(f"Warning: Could not read prompt file {prompt_file}: {e}") + return "" + +def main() -> None: + if not FRAMES_DIR.exists(): + raise FileNotFoundError(f"Frames directory not found: {FRAMES_DIR}") + + if not PROMPTS_DIR.exists(): + raise FileNotFoundError(f"Prompts directory not found: {PROMPTS_DIR}") + + # Create output directory + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + # Find all JPG files in the frames directory + frame_files = list(FRAMES_DIR.glob("*.jpg")) + + if not frame_files: + raise FileNotFoundError(f"No .jpg files found in {FRAMES_DIR}") + + print(f"Found {len(frame_files)} frame files to process") + + # Process each frame file + successful_jobs = 0 + failed_jobs = [] + + for idx, frame_file in enumerate(sorted(frame_files), 1): + # Get base name without extension + base_name = frame_file.stem + + # Find corresponding prompt file + prompt_file = PROMPTS_DIR / f"{base_name}.txt" + + if not prompt_file.exists(): + print(f"Warning: No prompt file found for {base_name}, skipping...") + 
failed_jobs.append(f"{base_name} - missing prompt file") + continue + + # Load prompt + prompt = load_prompt(prompt_file) + if not prompt: + print(f"Warning: Empty or unreadable prompt for {base_name}, skipping...") + failed_jobs.append(f"{base_name} - empty/unreadable prompt") + continue + + # Generate output path + output_file = OUTPUT_DIR / f"{base_name}.mp4" + + # Compose the torchrun command + cmd = [ + TORCHRUN, + f"--nproc_per_node={NUM_GPUS}", + SCRIPT, + "--model_size", MODEL_SIZE, + "--dit_path", DIT_PATH, + "--input_path", str(frame_file), + "--prompt", prompt, + "--save_path", str(output_file), + "--num_gpus", str(NUM_GPUS), + "--offload_guardrail", + "--offload_prompt_refiner", + ] + + # Print nicely for the log + pretty = " \\\n ".join(map(shlex.quote, cmd)) + border = f"[{idx:02}/{len(frame_files)}]" + print(f"\n{border} Processing: {base_name}") + print(f"Input: {frame_file}") + print(f"Prompt: {prompt[:100]}{'...' if len(prompt) > 100 else ''}") + print(f"Output: {output_file}") + print(f"Command:\n {pretty}\n") + + # Launch + try: + subprocess.run(cmd, check=True) + successful_jobs += 1 + print(f"✓ Successfully completed {base_name}") + except subprocess.CalledProcessError as e: + print(f"✗ Failed to process {base_name}: {e}") + failed_jobs.append(f"{base_name} - subprocess error: {e}") + continue + + # Print summary + print(f"\n{'='*60}") + print(f"BATCH PROCESSING COMPLETE") + print(f"{'='*60}") + print(f"Total files found: {len(frame_files)}") + print(f"Successfully processed: {successful_jobs}") + print(f"Failed: {len(failed_jobs)}") + + if failed_jobs: + print(f"\nFailed jobs:") + for failure in failed_jobs: + print(f" - {failure}") + +if __name__ == "__main__": + main() From d6920595336fc8652c6c87c419fbc32e44f5b7bc Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Fri, 4 Jul 2025 10:39:46 +0000 Subject: [PATCH 07/12] higher training iterations --- cosmos_predict2/configs/base/experiment/metropolis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos_predict2/configs/base/experiment/metropolis.py b/cosmos_predict2/configs/base/experiment/metropolis.py index 9f176688..eaa860c3 100644 --- a/cosmos_predict2/configs/base/experiment/metropolis.py +++ b/cosmos_predict2/configs/base/experiment/metropolis.py @@ -81,7 +81,7 @@ def get_sampler(dataset) -> DistributedSampler: callbacks=dict( iter_speed=dict(hit_thres=10), # Report speed every 10 iterations ), - max_iter=2000, # Increased iterations for larger dataset (1,191 samples) + max_iter=10000, # Increased iterations for larger dataset (1,191 samples) ), checkpoint=dict( save_iter=500, # Save checkpoint every 500 iterations From 2b452e12a79a2f99238e333bfea69c9c866daae4 Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Fri, 4 Jul 2025 11:58:40 +0000 Subject: [PATCH 08/12] Lora scripts --- documentations/inference_video2world_lora.md | 227 +++++++ examples/lora_inference_example.sh | 25 + examples/video2world_lora.py | 613 +++++++++++++++++++ 3 files changed, 865 insertions(+) create mode 100644 documentations/inference_video2world_lora.md create mode 100755 examples/lora_inference_example.sh create mode 100644 examples/video2world_lora.py diff --git a/documentations/inference_video2world_lora.md b/documentations/inference_video2world_lora.md new file mode 100644 index 00000000..74df0d88 --- /dev/null +++ b/documentations/inference_video2world_lora.md @@ -0,0 +1,227 @@ +# LoRA Inference for Video2World Generation + +This document describes how to perform inference with LoRA (Low-Rank Adaptation) 
post-trained models in Cosmos Predict2. + +## Overview + +LoRA is a parameter-efficient fine-tuning technique that allows you to adapt large pre-trained models to specific domains or tasks by training only a small number of additional parameters. The `video2world_lora.py` script enables inference with LoRA-trained checkpoints. + +## Key Differences from Standard Inference + +### Standard Inference (`video2world.py`) +- Loads the full model checkpoint +- Uses all model parameters as they were trained +- Suitable for base models or fully fine-tuned models + +### LoRA Inference (`video2world_lora.py`) +- Loads base model + LoRA adapters +- Only LoRA parameters were trained during post-training +- Memory efficient and faster training +- Maintains base model capabilities while adding domain-specific adaptations + +## Prerequisites + +1. **Post-trained LoRA checkpoint**: You need a checkpoint created with LoRA training (using `model.config.train_architecture=lora`) +2. **PEFT library**: The script uses the PEFT (Parameter-Efficient Fine-Tuning) library +3. **Base model**: The underlying base model should be available (automatically handled by the script) + +## Usage + +### Basic Command Structure + +```bash +export NUM_GPUS=8 +export PYTHONPATH=$(pwd) + +torchrun --nproc_per_node=${NUM_GPUS} examples/video2world_lora.py \ + --model_size 2B \ + --dit_path \ + --input_path \ + --prompt "" \ + --save_path \ + --use_lora \ + --lora_rank 16 \ + --lora_alpha 16 \ + --lora_target_modules "q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2" +``` + +### LoRA-Specific Parameters + +| Parameter | Description | Default | Notes | +|-----------|-------------|---------|-------| +| `--use_lora` | Enable LoRA inference mode | False | **Required** for LoRA inference | +| `--lora_rank` | Rank of LoRA adaptation | 16 | Should match training config | +| `--lora_alpha` | LoRA alpha parameter | 16 | Should match training config | +| `--lora_target_modules` | Target modules for LoRA | "q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2" | Should match training config | +| `--init_lora_weights` | Initialize LoRA weights | True | Usually keep as True | + +### Example: Metropolis Post-Training + +```bash +#!/bin/bash + +export NUM_GPUS=8 +export PYTHONPATH=$(pwd) + +torchrun --nproc_per_node=${NUM_GPUS} examples/video2world_lora.py \ + --model_size 2B \ + --dit_path checkpoints/posttraining/video2world/2b_metropolis/checkpoints/model/iter_000002000.pt \ + --input_path "/path/to/your/input.jpg" \ + --prompt "A static traffic cctv camera captures an urban intersection where multiple vehicles navigate through the busy streets." \ + --save_path results/2b_metropolis_lora/generated_video.mp4 \ + --num_gpus ${NUM_GPUS} \ + --offload_guardrail \ + --offload_prompt_refiner \ + --use_lora \ + --lora_rank 16 \ + --lora_alpha 16 \ + --lora_target_modules "q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2" +``` + +## How It Works + +### 1. Model Initialization +The script initializes the base model architecture without loading any weights: + +```python +pipe.dit = instantiate(dit_config).eval() +``` + +### 2. LoRA Addition +LoRA adapters are added to the specified target modules: + +```python +pipe.dit = add_lora_to_model( + pipe.dit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_target_modules=args.lora_target_modules, + init_lora_weights=args.init_lora_weights, +) +``` + +### 3. 
Checkpoint Loading +The LoRA-trained checkpoint is loaded with `strict=False` to accommodate the additional LoRA parameters: + +```python +missing_keys = pipe.dit.load_state_dict(state_dict_dit_regular, strict=False, assign=True) +``` + +### 4. Parameter Statistics +The script reports parameter counts to show the efficiency of LoRA: + +``` +Total parameters: 2,345,678,901 +Trainable LoRA parameters: 1,234,567 +LoRA parameter ratio: 0.05% +``` + +## LoRA Configuration Matching + +⚠️ **Important**: The LoRA parameters used during inference **must match** those used during training: + +### From Training Command +```bash +torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train \ + --config=cosmos_predict2/configs/base/config.py \ + -- experiment=predict2_video2world_training_2b_metropolis \ + model.config.train_architecture=lora +``` + +### Default LoRA Parameters (in `video2world_model.py`) +```python +lora_rank: int = 16 +lora_alpha: int = 16 +lora_target_modules: str = "q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2" +init_lora_weights: bool = True +``` + +## Troubleshooting + +### Common Issues + +1. **Missing LoRA Parameters** + ``` + Error: Missing keys in regular model: ['blocks.0.self_attn.q_proj.lora_A.default.weight', ...] + ``` + **Solution**: Ensure `--use_lora` flag is set and LoRA parameters match training configuration. + +2. **Unexpected Keys** + ``` + Warning: Unexpected keys in regular model: ['blocks.0.self_attn.q_proj.lora_B.default.weight', ...] + ``` + **Solution**: This usually means the checkpoint contains LoRA weights but `--use_lora` is not enabled. + +3. **Parameter Mismatch** + ``` + Error: size mismatch for blocks.0.self_attn.q_proj.lora_A.default.weight + ``` + **Solution**: Check that `--lora_rank` matches the value used during training. + +4. **Module Not Found** + ``` + Warning: No module named 'blocks.0.self_attn.k_proj' found for LoRA + ``` + **Solution**: Verify `--lora_target_modules` matches the training configuration. + +### Debug Mode + +To debug LoRA loading, you can check the model structure: + +```python +# Print all parameter names +for name, param in pipe.dit.named_parameters(): + if 'lora' in name.lower(): + print(f"LoRA parameter: {name}, shape: {param.shape}, requires_grad: {param.requires_grad}") +``` + +## Performance Considerations + +### Memory Usage +- LoRA inference uses slightly more memory than base model inference due to additional LoRA parameters +- Memory overhead is typically < 1% of base model size + +### Speed +- LoRA inference speed is comparable to base model inference +- Additional computation from LoRA layers is minimal + +### Quality +- LoRA models retain base model capabilities while adding domain-specific improvements +- Quality depends on the quality of post-training data and LoRA configuration + +## Comparison with Standard Inference + +| Aspect | Standard Inference | LoRA Inference | +|--------|-------------------|----------------| +| **Checkpoint Size** | Full model (~5-50GB) | Base + LoRA (~5-50GB + 10-100MB) | +| **Training Time** | Full model retraining | Only LoRA parameters | +| **Memory Usage** | Base memory | Base + small LoRA overhead | +| **Flexibility** | Fixed model | Can switch LoRA adapters | +| **Domain Adaptation** | Requires full fine-tuning | Efficient adaptation | + +## Best Practices + +1. **Always match LoRA parameters** between training and inference +2. **Use same target modules** as specified in training configuration +3. 
**Monitor parameter counts** to ensure LoRA is loaded correctly +4. **Test with known inputs** to validate model behavior +5. **Keep LoRA parameters small** (rank 8-64) for efficiency + +## Integration with Existing Workflows + +The LoRA inference script is fully compatible with existing workflows: + +- **Batch processing**: Use `--batch_input_json` for multiple inputs +- **Guardrails**: All guardrail functionality works as normal +- **Prompt refinement**: Compatible with prompt refiner +- **Multi-GPU**: Supports context parallelism for large videos +- **Different resolutions**: Support for 480p and 720p outputs + +## Future Enhancements + +Potential improvements for LoRA inference: + +1. **Multiple LoRA adapters**: Load multiple LoRA adapters for different domains +2. **Dynamic LoRA switching**: Switch between different LoRA adapters during inference +3. **LoRA composition**: Combine multiple LoRA adapters for multi-domain capabilities +4. **Quantized LoRA**: Support for quantized LoRA parameters for further efficiency \ No newline at end of file diff --git a/examples/lora_inference_example.sh b/examples/lora_inference_example.sh new file mode 100755 index 00000000..73eaaf4c --- /dev/null +++ b/examples/lora_inference_example.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Example script for running LoRA inference with Cosmos Predict2 +# This script demonstrates how to use the post-trained LoRA checkpoint + +# Set environment variables + export NUM_GPUS=8 + export PYTHONPATH=$(pwd) + +# Run LoRA inference with your post-trained checkpoint +torchrun --nproc_per_node=${NUM_GPUS} examples/video2world_lora.py \ + --model_size 2B \ + --dit_path checkpoints/posttraining/video2world/2b_metropolis/checkpoints/model/iter_000002000.pt \ + --input_path "/home/arslana/codes/cosmos-predict2/datasets/metropolis/benchmark/v2/frames/0a2cac1d-96ef-5132-9b34-67921af97ffb.jpg" \ + --prompt "A static traffic cctv camera captures an urban intersection where multiple vehicles navigate through the busy streets. A motorcycle enters the scene, and it crashes with a pedestrian which is crossing the street. Other cars in the scene drives normally without interruption." \ + --save_path results/2b_metropolis_lora/generated_test_lora.mp4 \ + --num_gpus ${NUM_GPUS} \ + --offload_guardrail \ + --offload_prompt_refiner \ + --use_lora \ + --lora_rank 16 \ + --lora_alpha 16 \ + --lora_target_modules "q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2" + +echo "LoRA inference completed. Check results/2b_metropolis_lora/generated_test_lora.mp4" \ No newline at end of file diff --git a/examples/video2world_lora.py b/examples/video2world_lora.py new file mode 100644 index 00000000..14ad1ab5 --- /dev/null +++ b/examples/video2world_lora.py @@ -0,0 +1,613 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
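# Module overview (summary of the LoRA inference flow documented in
# documentations/inference_video2world_lora.md): unlike examples/video2world.py,
# this script first instantiates the base DiT, injects LoRA adapters into the
# target projection/MLP modules via PEFT, and only then loads the post-trained
# checkpoint with strict=False so the extra lora_A/lora_B weights are accepted.
# The --lora_rank / --lora_alpha / --lora_target_modules values must match the
# post-training configuration. Example launch (see examples/lora_inference_example.sh):
#   torchrun --nproc_per_node=${NUM_GPUS} examples/video2world_lora.py \
#     --model_size 2B --dit_path <path_to_lora_checkpoint> --use_lora \
#     --lora_rank 16 --lora_alpha 16 \
#     --lora_target_modules "q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2"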
+ +import argparse +import json +import os + +# Set TOKENIZERS_PARALLELISM environment variable to avoid deadlocks with multiprocessing +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +import time + +import torch +from megatron.core import parallel_state +from tqdm import tqdm + +from cosmos_predict2.configs.base.config_video2world import ( + PREDICT2_VIDEO2WORLD_PIPELINE_2B, + PREDICT2_VIDEO2WORLD_PIPELINE_14B, +) +from cosmos_predict2.pipelines.video2world import _IMAGE_EXTENSIONS, _VIDEO_EXTENSIONS, Video2WorldPipeline +from imaginaire.utils import distributed, log, misc +from imaginaire.utils.io import save_image_or_video + +_DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + + +def add_lora_to_model( + model, + lora_rank=16, + lora_alpha=16, + lora_target_modules="q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2", + init_lora_weights=True, +): + """ + Add LoRA to a model using PEFT library. + + Args: + model: The model to add LoRA to + lora_rank: Rank of the LoRA adaptation + lora_alpha: Alpha parameter for LoRA + lora_target_modules: Comma-separated list of target modules + init_lora_weights: Whether to initialize LoRA weights + """ + from peft import LoraConfig, inject_adapter_in_model + + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + init_lora_weights=init_lora_weights, + target_modules=lora_target_modules.split(","), + ) + model = inject_adapter_in_model(lora_config, model) + + # Upcast LoRA parameters to fp32 for better stability + for param in model.parameters(): + if param.requires_grad: + param.data = param.to(torch.float32) + + return model + + +def setup_lora_pipeline(config, dit_path, text_encoder_path, args): + """ + Set up a pipeline with LoRA support. + This function creates the pipeline, adds LoRA, then loads the checkpoint. + """ + from cosmos_predict2.auxiliary.cosmos_reason1 import CosmosReason1 + from cosmos_predict2.auxiliary.text_encoder import CosmosT5TextEncoder + from cosmos_predict2.models.utils import init_weights_on_device, load_state_dict + from cosmos_predict2.module.denoiser_scaling import RectifiedFlowScaling + from cosmos_predict2.schedulers.rectified_flow_scheduler import RectifiedFlowAB2Scheduler + from imaginaire.lazy_config import instantiate + from imaginaire.utils.ema import FastEmaModelUpdater + import numpy as np + + # Create a pipe + pipe = Video2WorldPipeline(device="cuda", torch_dtype=torch.bfloat16) + pipe.config = config + pipe.precision = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[config.precision] + pipe.tensor_kwargs = {"device": "cuda", "dtype": pipe.precision} + log.warning(f"precision {pipe.precision}") + + # 1. set data keys and data information + pipe.sigma_data = config.sigma_data + pipe.setup_data_key() + + # 2. 
setup up diffusion processing and scaling~(pre-condition) + pipe.scheduler = RectifiedFlowAB2Scheduler( + sigma_min=config.timestamps.t_min, + sigma_max=config.timestamps.t_max, + order=config.timestamps.order, + t_scaling_factor=config.rectified_flow_t_scaling_factor, + ) + + pipe.scaling = RectifiedFlowScaling(pipe.sigma_data, config.rectified_flow_t_scaling_factor) + + # 3. Set up tokenizer + pipe.tokenizer = instantiate(config.tokenizer) + assert ( + pipe.tokenizer.latent_ch == pipe.config.state_ch + ), f"latent_ch {pipe.tokenizer.latent_ch} != state_shape {pipe.config.state_ch}" + + # 4. Load text encoder + if text_encoder_path: + # inference + pipe.text_encoder = CosmosT5TextEncoder(device="cuda", cache_dir=text_encoder_path) + pipe.text_encoder.to("cuda") + else: + # training + pipe.text_encoder = None + + # 5. Initialize conditioner + pipe.conditioner = instantiate(config.conditioner) + assert ( + sum(p.numel() for p in pipe.conditioner.parameters() if p.requires_grad) == 0 + ), "conditioner should not have learnable parameters" + + # Load prompt refiner + pipe.prompt_refiner = CosmosReason1( + checkpoint_dir=config.prompt_refiner_config.checkpoint_dir, + offload_model_to_cpu=config.prompt_refiner_config.offload_model_to_cpu, + enabled=config.prompt_refiner_config.enabled, + ) + + if config.guardrail_config.enabled: + from cosmos_predict2.auxiliary.guardrail.common import presets as guardrail_presets + + pipe.text_guardrail_runner = guardrail_presets.create_text_guardrail_runner( + config.guardrail_config.checkpoint_dir, config.guardrail_config.offload_model_to_cpu + ) + pipe.video_guardrail_runner = guardrail_presets.create_video_guardrail_runner( + config.guardrail_config.checkpoint_dir, config.guardrail_config.offload_model_to_cpu + ) + else: + pipe.text_guardrail_runner = None + pipe.video_guardrail_runner = None + + # 6. Set up DiT WITHOUT loading checkpoint first + log.info("Initializing DiT model...") + with init_weights_on_device(): + dit_config = config.net + pipe.dit = instantiate(dit_config).eval() # inference + + # 7. Add LoRA to the DiT model BEFORE loading checkpoint + log.info("Adding LoRA to the DiT model...") + log.info(f"LoRA parameters: rank={args.lora_rank}, alpha={args.lora_alpha}, target_modules={args.lora_target_modules}") + + pipe.dit = add_lora_to_model( + pipe.dit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_target_modules=args.lora_target_modules, + init_lora_weights=args.init_lora_weights, + ) + + # 8. Handle EMA model if enabled + if config.ema.enabled: + log.info("Setting up EMA model...") + pipe.dit_ema = instantiate(dit_config).eval() + pipe.dit_ema.requires_grad_(False) + + # Add LoRA to EMA model + log.info("Adding LoRA to the EMA DiT model...") + pipe.dit_ema = add_lora_to_model( + pipe.dit_ema, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_target_modules=args.lora_target_modules, + init_lora_weights=args.init_lora_weights, + ) + + pipe.dit_ema_worker = FastEmaModelUpdater() # default when not using FSDP + + s = config.ema.rate + pipe.ema_exp_coefficient = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() + # copying is only necessary when starting the training at iteration 0. + # Actual state_dict should be loaded after the pipe is created. + pipe.dit_ema_worker.copy_to(src_model=pipe.dit, tgt_model=pipe.dit_ema) + + # 9. 
NOW load the LoRA checkpoint with strict=False + if dit_path: + log.info(f"Loading LoRA checkpoint from {dit_path}") + state_dict = load_state_dict(dit_path) + + # Split state dict for regular and EMA models + state_dict_dit_regular = dict() + state_dict_dit_ema = dict() + + for k, v in state_dict.items(): + if k.startswith("net."): + state_dict_dit_regular[k[4:]] = v + elif k.startswith("net_ema."): + state_dict_dit_ema[k[4:]] = v + + # Load regular model with strict=False to allow LoRA weights + log.info("Loading regular DiT model weights...") + missing_keys = pipe.dit.load_state_dict(state_dict_dit_regular, strict=False, assign=True) + if missing_keys.missing_keys: + log.warning(f"Missing keys in regular model: {missing_keys.missing_keys}") + if missing_keys.unexpected_keys: + log.warning(f"Unexpected keys in regular model: {missing_keys.unexpected_keys}") + + # Load EMA model if enabled + if config.ema.enabled and state_dict_dit_ema: + log.info("Loading EMA DiT model weights...") + missing_keys_ema = pipe.dit_ema.load_state_dict(state_dict_dit_ema, strict=False, assign=True) + if missing_keys_ema.missing_keys: + log.warning(f"Missing keys in EMA model: {missing_keys_ema.missing_keys}") + if missing_keys_ema.unexpected_keys: + log.warning(f"Unexpected keys in EMA model: {missing_keys_ema.unexpected_keys}") + + del state_dict, state_dict_dit_regular, state_dict_dit_ema + log.success(f"Successfully loaded LoRA checkpoint from {dit_path}") + else: + log.warning("No checkpoint path provided, using random weights") + + # 10. Move models to device + pipe.dit = pipe.dit.to(device="cuda", dtype=torch.bfloat16) + if config.ema.enabled: + pipe.dit_ema = pipe.dit_ema.to(device="cuda", dtype=torch.bfloat16) + + torch.cuda.empty_cache() + + # 11. Set up training states + if parallel_state.is_initialized(): + pipe.data_parallel_size = parallel_state.get_data_parallel_world_size() + else: + pipe.data_parallel_size = 1 + + # Print parameter counts + total_params = sum(p.numel() for p in pipe.dit.parameters()) + trainable_params = sum(p.numel() for p in pipe.dit.parameters() if p.requires_grad) + log.info(f"Total parameters: {total_params:,}") + log.info(f"Trainable LoRA parameters: {trainable_params:,}") + log.info(f"LoRA parameter ratio: {trainable_params/total_params*100:.2f}%") + + return pipe + + +def validate_input_file(input_path: str, num_conditional_frames: int) -> bool: + if not os.path.exists(input_path): + log.warning(f"Input file does not exist, skipping: {input_path}") + return False + + ext = os.path.splitext(input_path)[1].lower() + + if num_conditional_frames == 1: + # Single frame conditioning: accept both images and videos + if ext not in _IMAGE_EXTENSIONS and ext not in _VIDEO_EXTENSIONS: + log.warning( + f"Skipping file with unsupported extension for single frame conditioning: {input_path} " + f"(expected: {_IMAGE_EXTENSIONS + _VIDEO_EXTENSIONS})" + ) + return False + elif num_conditional_frames == 5: + # Multi-frame conditioning: only accept videos + if ext not in _VIDEO_EXTENSIONS: + log.warning( + f"Skipping file for multi-frame conditioning (requires video): {input_path} " + f"(expected: {_VIDEO_EXTENSIONS}, got: {ext})" + ) + return False + else: + log.error(f"Invalid num_conditional_frames: {num_conditional_frames} (must be 1 or 5)") + return False + + return True + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video-to-World Generation with Cosmos Predict2 (LoRA Support)") + parser.add_argument( + "--model_size", + choices=["2B", 
"14B"], + default="2B", + help="Size of the model to use for video-to-world generation", + ) + parser.add_argument( + "--resolution", + choices=["480", "720"], + default="720", + type=str, + help="Resolution of the model to use for video-to-world generation", + ) + parser.add_argument( + "--fps", + choices=[10, 16], + default=16, + type=int, + help="FPS of the model to use for video-to-world generation", + ) + parser.add_argument( + "--dit_path", + type=str, + default="", + help="Custom path to the DiT model checkpoint for post-trained models.", + ) + + # LoRA-specific arguments + parser.add_argument( + "--use_lora", + action="store_true", + help="Enable LoRA inference mode", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=16, + help="Rank of the LoRA adaptation", + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=16, + help="Alpha parameter for LoRA", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default="q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2", + help="Comma-separated list of target modules for LoRA", + ) + parser.add_argument( + "--init_lora_weights", + action="store_true", + default=True, + help="Whether to initialize LoRA weights", + ) + + parser.add_argument( + "--prompt", + type=str, + default="", + help="Text prompt for video generation", + ) + parser.add_argument( + "--input_path", + type=str, + default="assets/video2world/input0.jpg", + help="Path to input image or video for conditioning (include file extension)", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default=_DEFAULT_NEGATIVE_PROMPT, + help="Negative text prompt for video-to-world generation", + ) + parser.add_argument( + "--num_conditional_frames", + type=int, + default=1, + choices=[1, 5], + help="Number of frames to condition on (1 for single frame, 5 for multi-frame conditioning)", + ) + parser.add_argument( + "--batch_input_json", + type=str, + default=None, + help="Path to JSON file containing batch inputs. Each entry should have 'input_video', 'prompt', and 'output_video' fields.", + ) + parser.add_argument("--guidance", type=float, default=7, help="Guidance value") + parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility") + parser.add_argument( + "--save_path", + type=str, + default="output/generated_video.mp4", + help="Path to save the generated video (include file extension)", + ) + parser.add_argument( + "--num_gpus", + type=int, + default=1, + help="Number of GPUs to use for context parallel inference (should be a divisor of the total frames)", + ) + parser.add_argument("--disable_guardrail", action="store_true", help="Disable guardrail checks on prompts") + parser.add_argument("--offload_guardrail", action="store_true", help="Offload guardrail to CPU to save GPU memory") + parser.add_argument( + "--disable_prompt_refiner", action="store_true", help="Disable prompt refiner that enhances short prompts" + ) + parser.add_argument( + "--offload_prompt_refiner", action="store_true", help="Offload prompt refiner to CPU to save GPU memory" + ) + parser.add_argument( + "--benchmark", + action="store_true", + help="Run the generation in benchmark mode. 
It means that generation will be rerun a few times and the average generation time will be shown.", + ) + return parser.parse_args() + + +def setup_pipeline(args: argparse.Namespace, text_encoder=None): + log.info(f"Using model size: {args.model_size}") + if args.model_size == "2B": + config = PREDICT2_VIDEO2WORLD_PIPELINE_2B + + config.resolution = args.resolution + if args.fps == 10: # default is 16 so no need to change config + config.state_t = 16 + + dit_path = f"checkpoints/nvidia/Cosmos-Predict2-2B-Video2World/model-{args.resolution}p-{args.fps}fps.pt" + elif args.model_size == "14B": + config = PREDICT2_VIDEO2WORLD_PIPELINE_14B + + config.resolution = args.resolution + if args.fps == 10: # default is 16 so no need to change config + config.state_t = 16 + + dit_path = f"checkpoints/nvidia/Cosmos-Predict2-14B-Video2World/model-{args.resolution}p-{args.fps}fps.pt" + else: + raise ValueError("Invalid model size. Choose either '2B' or '14B'.") + + if hasattr(args, "dit_path") and args.dit_path: + dit_path = args.dit_path + + # Only set up text encoder path if no encoder is provided + text_encoder_path = None if text_encoder is not None else "checkpoints/google-t5/t5-11b" + log.info(f"Using dit_path: {dit_path}") + if text_encoder is not None: + log.info("Using provided text encoder") + else: + log.info(f"Using text encoder from: {text_encoder_path}") + + misc.set_random_seed(seed=args.seed, by_rank=True) + # Initialize cuDNN. + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + # Floating-point precision settings. + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + # Initialize distributed environment for multi-GPU inference + if hasattr(args, "num_gpus") and args.num_gpus > 1: + log.info(f"Initializing distributed environment with {args.num_gpus} GPUs for context parallelism") + + # Check if distributed environment is already initialized + if not parallel_state.is_initialized(): + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + log.info(f"Context parallel group initialized with {args.num_gpus} GPUs") + else: + log.info("Distributed environment already initialized, skipping initialization") + # Check if we need to reinitialize with different context parallel size + current_cp_size = parallel_state.get_context_parallel_world_size() + if current_cp_size != args.num_gpus: + log.warning(f"Context parallel size mismatch: current={current_cp_size}, requested={args.num_gpus}") + log.warning("Using existing context parallel configuration") + else: + log.info(f"Using existing context parallel group with {current_cp_size} GPUs") + + # Disable guardrail if requested + if args.disable_guardrail: + log.warning("Guardrail checks are disabled") + config.guardrail_config.enabled = False + config.guardrail_config.offload_model_to_cpu = args.offload_guardrail + + # Disable prompt refiner if requested + if args.disable_prompt_refiner: + log.warning("Prompt refiner is disabled") + config.prompt_refiner_config.enabled = False + config.prompt_refiner_config.offload_model_to_cpu = args.offload_prompt_refiner + + # Load models - for LoRA, we need to handle this differently + log.info(f"Initializing Video2WorldPipeline with model size: {args.model_size}") + + if args.use_lora: + # For LoRA inference, we need to add LoRA before loading the checkpoint + log.info("LoRA inference mode detected - using custom pipeline loading") + pipe = setup_lora_pipeline(config, dit_path, text_encoder_path, 
args) + else: + # Standard inference + pipe = Video2WorldPipeline.from_config( + config=config, + dit_path=dit_path, + text_encoder_path=text_encoder_path, + device="cuda", + torch_dtype=torch.bfloat16, + load_prompt_refiner=True, + ) + + # Set the provided text encoder if one was passed + if text_encoder is not None: + pipe.text_encoder = text_encoder + + return pipe + + +def process_single_generation( + pipe, input_path, prompt, output_path, negative_prompt, num_conditional_frames, guidance, seed, benchmark +): + # Validate input file + if not validate_input_file(input_path, num_conditional_frames): + log.warning(f"Input file validation failed: {input_path}") + return False + + log.info(f"Running Video2WorldPipeline\ninput: {input_path}\nprompt: {prompt}") + + num_repeats = 4 if benchmark else 1 + time_sum = 0 + for i in range(num_repeats): + if benchmark and i > 0: + torch.cuda.synchronize() + start_time = time.time() + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + input_path=input_path, + num_conditional_frames=num_conditional_frames, + guidance=guidance, + seed=seed, + ) + if benchmark and i > 0: + torch.cuda.synchronize() + time_sum += time.time() - start_time + if benchmark: + log.critical(f"The benchmarked generation time for Video2WorldPipeline is {time_sum / 3} seconds.") + + if video is not None: + # save the generated video + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + log.info(f"Saving generated video to: {output_path}") + if pipe.config.state_t == 16: + fps = 10 + else: + fps = 16 + save_image_or_video(video, output_path, fps=fps) + log.success(f"Successfully saved video to: {output_path}") + return True + return False + + +def generate_video(args: argparse.Namespace, pipe: Video2WorldPipeline) -> None: + if args.benchmark: + log.warning( + "Running in benchmark mode. Each generation will be rerun a couple of times and the average generation time will be shown." 
+ ) + # Video-to-World + if args.batch_input_json is not None: + # Process batch inputs from JSON file + log.info(f"Loading batch inputs from JSON file: {args.batch_input_json}") + with open(args.batch_input_json, "r") as f: + batch_inputs = json.load(f) + + for idx, item in enumerate(tqdm(batch_inputs)): + input_video = item.get("input_video", "") + prompt = item.get("prompt", "") + output_video = item.get("output_video", f"output_{idx}.mp4") + + if not input_video or not prompt: + log.warning(f"Skipping item {idx}: Missing input_video or prompt") + continue + + process_single_generation( + pipe=pipe, + input_path=input_video, + prompt=prompt, + output_path=output_video, + negative_prompt=args.negative_prompt, + num_conditional_frames=args.num_conditional_frames, + guidance=args.guidance, + seed=args.seed, + benchmark=args.benchmark, + ) + else: + process_single_generation( + pipe=pipe, + input_path=args.input_path, + prompt=args.prompt, + output_path=args.save_path, + negative_prompt=args.negative_prompt, + num_conditional_frames=args.num_conditional_frames, + guidance=args.guidance, + seed=args.seed, + benchmark=args.benchmark, + ) + + return + + +def cleanup_distributed(): + """Clean up the distributed environment if initialized.""" + if parallel_state.is_initialized(): + parallel_state.destroy_model_parallel() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + args = parse_args() + try: + pipe = setup_pipeline(args) + generate_video(args, pipe) + finally: + # Make sure to clean up the distributed environment + cleanup_distributed() \ No newline at end of file From 33f84144efd43c7e113a466ba697c6c9034cecf6 Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Fri, 4 Jul 2025 16:22:49 +0000 Subject: [PATCH 09/12] Metropolis files --- customers/metropolis/README_BATCH_LORA.md | 157 ++++++++++++++++ customers/metropolis/run_batch_examples.sh | 49 +++++ customers/metropolis/run_batch_metropolis.py | 184 ++++++++++++++++--- extract_first_frames.sh | 58 ++++++ 4 files changed, 422 insertions(+), 26 deletions(-) create mode 100644 customers/metropolis/README_BATCH_LORA.md create mode 100755 customers/metropolis/run_batch_examples.sh mode change 100644 => 100755 customers/metropolis/run_batch_metropolis.py create mode 100755 extract_first_frames.sh diff --git a/customers/metropolis/README_BATCH_LORA.md b/customers/metropolis/README_BATCH_LORA.md new file mode 100644 index 00000000..c465cb69 --- /dev/null +++ b/customers/metropolis/README_BATCH_LORA.md @@ -0,0 +1,157 @@ +# Batch Inference with LoRA Support for Metropolis Dataset + +This directory contains scripts for running batch inference on the Metropolis dataset with both standard and LoRA inference modes. 
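+
+Both modes ultimately shell out to the single-sample entry points (`examples/video2world.py` for standard inference, `examples/video2world_lora.py` for LoRA) once per frame/prompt pair. As a rough sketch, a direct single-sample LoRA call looks like the command below; the checkpoint path, input frame, and prompt are placeholders and should be replaced with your own post-trained LoRA checkpoint and conditioning inputs:
+
+```bash
+torchrun --nproc_per_node=1 examples/video2world_lora.py \
+    --model_size 2B \
+    --use_lora \
+    --lora_rank 16 \
+    --lora_alpha 16 \
+    --dit_path "checkpoints/posttraining/video2world/2b_metropolis/checkpoints/model/iter_000002000.pt" \
+    --input_path "assets/video2world/input0.jpg" \
+    --prompt "A descriptive prompt for the target scene" \
+    --save_path "output/example_lora.mp4"
+```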
+ +## Files + +- `run_batch_metropolis.py` - Main batch inference script with LoRA support +- `run_batch_examples.sh` - Interactive script with usage examples +- `README_BATCH_LORA.md` - This documentation file + +## Features + +### Standard Inference +- Uses `examples/video2world.py` script +- Default: 14B model with standard checkpoint +- Full model inference + +### LoRA Inference +- Uses `examples/video2world_lora.py` script +- Default: 2B model with LoRA checkpoint +- Parameter-efficient inference with LoRA adapters + +## Usage + +### Basic Commands + +```bash +# Standard inference (14B model) +python customers/metropolis/run_batch_metropolis.py + +# LoRA inference (2B model) +python customers/metropolis/run_batch_metropolis.py --lora +``` + +### Advanced Options + +```bash +# LoRA inference with custom parameters +python customers/metropolis/run_batch_metropolis.py \ + --lora \ + --model_size 2B \ + --lora_rank 16 \ + --lora_alpha 16 \ + --lora_target_modules "q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2" + +# Custom checkpoint path +python customers/metropolis/run_batch_metropolis.py \ + --lora \ + --dit_path "path/to/your/lora_checkpoint.pt" +``` + +### Interactive Examples + +```bash +# Run interactive examples script +./customers/metropolis/run_batch_examples.sh +``` + +## Command Line Arguments + +| Argument | Description | Default | +|----------|-------------|---------| +| `--lora` | Enable LoRA inference mode | False | +| `--model_size` | Model size (2B or 14B) | 14B (standard), 2B (LoRA) | +| `--dit_path` | Custom checkpoint path | Auto-selected based on mode | +| `--lora_rank` | LoRA rank parameter | 16 | +| `--lora_alpha` | LoRA alpha parameter | 16 | +| `--lora_target_modules` | LoRA target modules | Default attention and MLP modules | + +## Input/Output Structure + +### Input Requirements +- **Frames**: `datasets/metropolis/benchmark/v2/frames_480_640/*.jpg` +- **Prompts**: `datasets/metropolis/benchmark/v2/prompts/*.txt` +- Frame and prompt files must have matching base names + +### Output Structure +``` +output/ +├── metropolis_batch_standard/ # Standard inference outputs +│ ├── frame1.mp4 +│ ├── frame2.mp4 +│ └── ... +└── metropolis_batch_lora/ # LoRA inference outputs + ├── frame1.mp4 + ├── frame2.mp4 + └── ... +``` + +## Default Configurations + +### Standard Inference +- **Script**: `examples/video2world.py` +- **Model**: 14B +- **Checkpoint**: `checkpoints/posttraining/video2world/14b_metropolis/checkpoints/model/iter_000002000.pt` + +### LoRA Inference +- **Script**: `examples/video2world_lora.py` +- **Model**: 2B +- **Checkpoint**: `checkpoints/posttraining/video2world/2b_metropolis/checkpoints/model/iter_000002000.pt` +- **LoRA Parameters**: + - Rank: 16 + - Alpha: 16 + - Target Modules: `q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2` + +## Environment Setup + +```bash +export NUM_GPUS=8 +export PYTHONPATH=$(pwd) +``` + +## Performance Comparison + +| Mode | Model Size | Memory Usage | Speed | Quality | +|------|------------|--------------|-------|---------| +| Standard | 14B | High | Slower | Full model quality | +| LoRA | 2B + adapters | Lower | Faster | Domain-adapted quality | + +## Troubleshooting + +### Common Issues + +1. **Missing checkpoint files**: Ensure checkpoint paths exist +2. **Missing input files**: Verify frames and prompts directories exist +3. **Memory issues**: Reduce `NUM_GPUS` or use LoRA mode for lower memory usage +4. 
**LoRA parameter mismatch**: Ensure LoRA parameters match training configuration + +### Getting Help + +```bash +python customers/metropolis/run_batch_metropolis.py --help +``` + +## Examples + +### Quick Start - LoRA Inference +```bash +# Set environment +export NUM_GPUS=8 +export PYTHONPATH=$(pwd) + +# Run LoRA batch inference +python customers/metropolis/run_batch_metropolis.py --lora +``` + +### Custom LoRA Configuration +```bash +python customers/metropolis/run_batch_metropolis.py \ + --lora \ + --model_size 2B \ + --dit_path "checkpoints/my_custom_lora.pt" \ + --lora_rank 32 \ + --lora_alpha 32 +``` + +This setup provides flexible batch inference with both standard and LoRA modes while maintaining all the original functionality of the batch processing system. \ No newline at end of file diff --git a/customers/metropolis/run_batch_examples.sh b/customers/metropolis/run_batch_examples.sh new file mode 100755 index 00000000..6ff9c424 --- /dev/null +++ b/customers/metropolis/run_batch_examples.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# Examples for running batch inference with Metropolis dataset + +echo "===================================================================" +echo "Batch Inference Examples for Metropolis Dataset" +echo "===================================================================" + +# Set environment variables +export NUM_GPUS=8 +export PYTHONPATH=$(pwd) + +echo "Choose an option:" +echo "1. Standard inference (14B model)" +echo "2. LoRA inference (2B model)" +echo "3. LoRA inference with custom parameters" +echo "4. Show help" + +read -p "Enter your choice (1-4): " choice + +case $choice in + 1) + echo "Running standard inference with 14B model..." + python customers/metropolis/run_batch_metropolis.py + ;; + 2) + echo "Running LoRA inference with 2B model..." + python customers/metropolis/run_batch_metropolis.py --lora + ;; + 3) + echo "Running LoRA inference with custom parameters..." + python customers/metropolis/run_batch_metropolis.py \ + --lora \ + --model_size 2B \ + --lora_rank 16 \ + --lora_alpha 16 \ + --lora_target_modules "q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2" + ;; + 4) + echo "Showing help..." + python customers/metropolis/run_batch_metropolis.py --help + ;; + *) + echo "Invalid choice. Please run the script again." + exit 1 + ;; +esac + +echo "Batch inference completed!" 
\ No newline at end of file diff --git a/customers/metropolis/run_batch_metropolis.py b/customers/metropolis/run_batch_metropolis.py old mode 100644 new mode 100755 index b3077e12..3c1df5cc --- a/customers/metropolis/run_batch_metropolis.py +++ b/customers/metropolis/run_batch_metropolis.py @@ -10,30 +10,101 @@ • torchrun is in PATH • NUM_GPUS is set below (or override via env) • Frame files are .jpg and prompt files are .txt with matching base names + +Usage: + python run_batch_metropolis.py # Standard inference + python run_batch_metropolis.py --lora # LoRA inference + python run_batch_metropolis.py --lora --model_size 2B # LoRA with 2B model """ import os import shlex import subprocess +import argparse from pathlib import Path import glob # ------------------------------------------------------------------ # CONFIG – adjust only if your paths / flags change # ------------------------------------------------------------------ -FRAMES_DIR = Path("datasets/metropolis/benchmark/v2/frames_480_640") +FRAMES_DIR = Path("datasets/metropolis/benchmark/v2/frames_1280x704") PROMPTS_DIR = Path("datasets/metropolis/benchmark/v2/prompts") OUTPUT_DIR = Path("output") NUM_GPUS = int(os.getenv("NUM_GPUS", 8)) -MODEL_SIZE = "14B" # Changed from 14B to 2B based on the checkpoint path in original -DIT_PATH = ( - "checkpoints/posttraining/video2world/14b_metropolis/" - "checkpoints/model/iter_000002000.pt" -) TORCHRUN = "torchrun" # or full path if needed -SCRIPT = "examples/video2world.py" # entry script + +# Standard inference configuration +STANDARD_CONFIG = { + "script": "examples/video2world.py", + "model_size": "2B", + "dit_path": "checkpoints/posttraining/video2world/2b_metropolis/checkpoints/model/iter_000002000.pt" +} + +# LoRA inference configuration +LORA_CONFIG = { + "script": "examples/video2world_lora.py", + "model_size": "2B", + "dit_path": "checkpoints/posttraining/video2world/2b_metropolis/checkpoints/model/iter_000003000.pt", + "lora_rank": 16, + "lora_alpha": 16, + "lora_target_modules": "q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2" +} # ------------------------------------------------------------------ +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Batch inference for Metropolis dataset", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python run_batch_metropolis.py # Standard inference with 14B model + python run_batch_metropolis.py --lora # LoRA inference with 2B model + python run_batch_metropolis.py --lora --model_size 2B # LoRA with specific model size + """ + ) + + parser.add_argument( + "--lora", + action="store_true", + help="Enable LoRA inference mode" + ) + + parser.add_argument( + "--model_size", + choices=["2B", "14B"], + help="Model size (overrides default based on inference type)" + ) + + parser.add_argument( + "--dit_path", + type=str, + help="Custom path to DiT checkpoint (overrides default)" + ) + + parser.add_argument( + "--lora_rank", + type=int, + default=16, + help="LoRA rank (only used with --lora flag)" + ) + + parser.add_argument( + "--lora_alpha", + type=int, + default=16, + help="LoRA alpha (only used with --lora flag)" + ) + + parser.add_argument( + "--lora_target_modules", + type=str, + default="q_proj,k_proj,v_proj,output_proj,mlp.layer1,mlp.layer2", + help="LoRA target modules (only used with --lora flag)" + ) + + return parser.parse_args() + def load_prompt(prompt_file: Path) -> str: """Load prompt text from file, returning first non-empty line.""" try: 
@@ -47,15 +118,86 @@ def load_prompt(prompt_file: Path) -> str: print(f"Warning: Could not read prompt file {prompt_file}: {e}") return "" +def get_inference_config(args): + """Get inference configuration based on arguments.""" + if args.lora: + config = LORA_CONFIG.copy() + # Override with command line arguments if provided + if args.model_size: + config["model_size"] = args.model_size + if args.dit_path: + config["dit_path"] = args.dit_path + config["lora_rank"] = args.lora_rank + config["lora_alpha"] = args.lora_alpha + config["lora_target_modules"] = args.lora_target_modules + else: + config = STANDARD_CONFIG.copy() + # Override with command line arguments if provided + if args.model_size: + config["model_size"] = args.model_size + if args.dit_path: + config["dit_path"] = args.dit_path + + return config + +def build_command(config, frame_file, prompt, output_file, args): + """Build the torchrun command based on configuration.""" + cmd = [ + TORCHRUN, + f"--nproc_per_node={NUM_GPUS}", + config["script"], + "--model_size", config["model_size"], + "--dit_path", config["dit_path"], + "--input_path", str(frame_file), + "--prompt", prompt, + "--save_path", str(output_file), + "--num_gpus", str(NUM_GPUS), + "--offload_guardrail", + "--offload_prompt_refiner", + ] + + # Add LoRA-specific parameters if using LoRA inference + if args.lora: + cmd.extend([ + "--use_lora", + "--lora_rank", str(config["lora_rank"]), + "--lora_alpha", str(config["lora_alpha"]), + "--lora_target_modules", config["lora_target_modules"] + ]) + + return cmd + def main() -> None: + # Parse command line arguments + args = parse_args() + + # Get inference configuration + config = get_inference_config(args) + + # Print configuration + inference_type = "LoRA" if args.lora else "Standard" + print(f"{'='*60}") + print(f"BATCH METROPOLIS INFERENCE - {inference_type} Mode") + print(f"{'='*60}") + print(f"Script: {config['script']}") + print(f"Model Size: {config['model_size']}") + print(f"Checkpoint: {config['dit_path']}") + if args.lora: + print(f"LoRA Rank: {config['lora_rank']}") + print(f"LoRA Alpha: {config['lora_alpha']}") + print(f"LoRA Target Modules: {config['lora_target_modules']}") + print(f"{'='*60}\n") + if not FRAMES_DIR.exists(): raise FileNotFoundError(f"Frames directory not found: {FRAMES_DIR}") if not PROMPTS_DIR.exists(): raise FileNotFoundError(f"Prompts directory not found: {PROMPTS_DIR}") - # Create output directory - OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + # Create output directory with inference type suffix + output_suffix = "_lora" if args.lora else "_standard" + output_dir = OUTPUT_DIR / f"metropolis_batch{output_suffix}" + output_dir.mkdir(parents=True, exist_ok=True) # Find all JPG files in the frames directory frame_files = list(FRAMES_DIR.glob("*.jpg")) @@ -64,6 +206,7 @@ def main() -> None: raise FileNotFoundError(f"No .jpg files found in {FRAMES_DIR}") print(f"Found {len(frame_files)} frame files to process") + print(f"Output directory: {output_dir}\n") # Process each frame file successful_jobs = 0 @@ -89,27 +232,15 @@ def main() -> None: continue # Generate output path - output_file = OUTPUT_DIR / f"{base_name}.mp4" + output_file = output_dir / f"{base_name}.mp4" - # Compose the torchrun command - cmd = [ - TORCHRUN, - f"--nproc_per_node={NUM_GPUS}", - SCRIPT, - "--model_size", MODEL_SIZE, - "--dit_path", DIT_PATH, - "--input_path", str(frame_file), - "--prompt", prompt, - "--save_path", str(output_file), - "--num_gpus", str(NUM_GPUS), - "--offload_guardrail", - "--offload_prompt_refiner", 
- ] + # Build command + cmd = build_command(config, frame_file, prompt, output_file, args) # Print nicely for the log pretty = " \\\n ".join(map(shlex.quote, cmd)) border = f"[{idx:02}/{len(frame_files)}]" - print(f"\n{border} Processing: {base_name}") + print(f"\n{border} Processing: {base_name} ({inference_type})") print(f"Input: {frame_file}") print(f"Prompt: {prompt[:100]}{'...' if len(prompt) > 100 else ''}") print(f"Output: {output_file}") @@ -127,11 +258,12 @@ def main() -> None: # Print summary print(f"\n{'='*60}") - print(f"BATCH PROCESSING COMPLETE") + print(f"BATCH PROCESSING COMPLETE - {inference_type} Mode") print(f"{'='*60}") print(f"Total files found: {len(frame_files)}") print(f"Successfully processed: {successful_jobs}") print(f"Failed: {len(failed_jobs)}") + print(f"Output directory: {output_dir}") if failed_jobs: print(f"\nFailed jobs:") diff --git a/extract_first_frames.sh b/extract_first_frames.sh new file mode 100755 index 00000000..b76f30cf --- /dev/null +++ b/extract_first_frames.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# Script to extract first frame from all videos in high quality +# Source: datasets/metropolis/benchmark/v2/videos_1280x704/ +# Output: /home/arslana/codes/cosmos-predict2/datasets/metropolis/benchmark/v2/frames_1280x704/ + +SOURCE_DIR="datasets/metropolis/benchmark/v2/videos_1280x704" +OUTPUT_DIR="/home/arslana/codes/cosmos-predict2/datasets/metropolis/benchmark/v2/frames_1280x704" + +echo "==================================" +echo "Extracting First Frames (High Quality)" +echo "==================================" +echo "Source: $SOURCE_DIR" +echo "Output: $OUTPUT_DIR" +echo "==================================" + +# Create output directory if it doesn't exist +mkdir -p "$OUTPUT_DIR" + +# Counter for progress +count=0 +total=$(find "$SOURCE_DIR" -name "*.mp4" | wc -l) + +# Process each MP4 file +for video_file in "$SOURCE_DIR"/*.mp4; do + if [ -f "$video_file" ]; then + # Get base filename without extension + base_name=$(basename "$video_file" .mp4) + + # Define output path + output_file="$OUTPUT_DIR/${base_name}.jpg" + + # Increment counter + ((count++)) + + echo "[$count/$total] Processing: $base_name" + + # Extract first frame with high quality + ffmpeg -i "$video_file" \ + -vframes 1 \ + -q:v 2 \ + -y \ + "$output_file" \ + -loglevel quiet + + if [ $? -eq 0 ]; then + echo "✓ Successfully extracted: $base_name.jpg" + else + echo "✗ Failed to extract: $base_name" + fi + fi +done + +echo "==================================" +echo "Extraction Complete!" 
+echo "Processed: $count videos" +echo "Output directory: $OUTPUT_DIR" +echo "==================================" \ No newline at end of file From a8c6c09144f6dfdd5cce4e0eadd3e65659174a6b Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Mon, 7 Jul 2025 13:41:37 +0000 Subject: [PATCH 10/12] pre processing scripts + folders alignemnt --- .../metropolis/POST_TRAINING_README.md | 0 .../metropolis/README_metropolis.md | 0 .../metropolis/extract_captions.sh | 0 .../metropolis/extract_first_frames.sh | 0 customers/metropolis/extract_qwen_captions.py | 74 +++++++++++++++++++ customers/metropolis/run_batch_metropolis.py | 19 ++++- .../metropolis/setup_training_env.sh | 0 7 files changed, 89 insertions(+), 4 deletions(-) rename POST_TRAINING_README.md => customers/metropolis/POST_TRAINING_README.md (100%) rename README_metropolis.md => customers/metropolis/README_metropolis.md (100%) rename extract_captions.sh => customers/metropolis/extract_captions.sh (100%) rename extract_first_frames.sh => customers/metropolis/extract_first_frames.sh (100%) create mode 100755 customers/metropolis/extract_qwen_captions.py rename setup_training_env.sh => customers/metropolis/setup_training_env.sh (100%) diff --git a/POST_TRAINING_README.md b/customers/metropolis/POST_TRAINING_README.md similarity index 100% rename from POST_TRAINING_README.md rename to customers/metropolis/POST_TRAINING_README.md diff --git a/README_metropolis.md b/customers/metropolis/README_metropolis.md similarity index 100% rename from README_metropolis.md rename to customers/metropolis/README_metropolis.md diff --git a/extract_captions.sh b/customers/metropolis/extract_captions.sh similarity index 100% rename from extract_captions.sh rename to customers/metropolis/extract_captions.sh diff --git a/extract_first_frames.sh b/customers/metropolis/extract_first_frames.sh similarity index 100% rename from extract_first_frames.sh rename to customers/metropolis/extract_first_frames.sh diff --git a/customers/metropolis/extract_qwen_captions.py b/customers/metropolis/extract_qwen_captions.py new file mode 100755 index 00000000..6e3f9e1c --- /dev/null +++ b/customers/metropolis/extract_qwen_captions.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +""" +Script to extract 'qwen_caption' from all JSON files in the source directory +and save them as text files in the destination directory. 
+""" + +import json +import os +import glob +from pathlib import Path + +def extract_qwen_captions(): + # Define source and destination directories + source_dir = "/home/arslana/codes/cosmos-predict2/amol_its_it2/metas/v0/" + dest_dir = "/home/arslana/codes/cosmos-predict2/datasets/v3/train/metas/" + + # Get all JSON files in the source directory + json_files = glob.glob(os.path.join(source_dir, "*.json")) + + print(f"Found {len(json_files)} JSON files to process") + + processed_count = 0 + error_count = 0 + + for json_file in json_files: + try: + # Get the base filename without extension + filename = os.path.basename(json_file) + base_name = os.path.splitext(filename)[0] + + # Read the JSON file + with open(json_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + # Extract qwen_caption + qwen_caption = None + + # Look for qwen_caption in the JSON structure + if 'windows' in data and isinstance(data['windows'], list): + for window in data['windows']: + if 'qwen_caption' in window: + qwen_caption = window['qwen_caption'] + break + + if qwen_caption is None: + # Try to find qwen_caption at the root level + if 'qwen_caption' in data: + qwen_caption = data['qwen_caption'] + + if qwen_caption is not None: + # Create the output text file path + output_file = os.path.join(dest_dir, f"{base_name}.txt") + + # Write the caption to the text file + with open(output_file, 'w', encoding='utf-8') as f: + f.write(qwen_caption) + + processed_count += 1 + if processed_count % 100 == 0: + print(f"Processed {processed_count} files...") + else: + print(f"Warning: No qwen_caption found in {filename}") + error_count += 1 + + except Exception as e: + print(f"Error processing {json_file}: {str(e)}") + error_count += 1 + + print(f"\nProcessing complete!") + print(f"Successfully processed: {processed_count} files") + print(f"Errors encountered: {error_count} files") + +if __name__ == "__main__": + extract_qwen_captions() \ No newline at end of file diff --git a/customers/metropolis/run_batch_metropolis.py b/customers/metropolis/run_batch_metropolis.py index 3c1df5cc..3d8fbeb2 100755 --- a/customers/metropolis/run_batch_metropolis.py +++ b/customers/metropolis/run_batch_metropolis.py @@ -23,12 +23,13 @@ import argparse from pathlib import Path import glob +import re # ------------------------------------------------------------------ # CONFIG – adjust only if your paths / flags change # ------------------------------------------------------------------ -FRAMES_DIR = Path("datasets/metropolis/benchmark/v2/frames_1280x704") -PROMPTS_DIR = Path("datasets/metropolis/benchmark/v2/prompts") +FRAMES_DIR = Path("datasets/metropolis/benchmark/v2/clean/clean_1280x704") +PROMPTS_DIR = Path("datasets/metropolis/benchmark/v2/clean/prompts_1") OUTPUT_DIR = Path("output") NUM_GPUS = int(os.getenv("NUM_GPUS", 8)) TORCHRUN = "torchrun" # or full path if needed @@ -105,6 +106,14 @@ def parse_args(): return parser.parse_args() +def extract_iteration_number(dit_path: str) -> str: + """Extract iteration number from checkpoint path.""" + # Match pattern like 'iter_000003000.pt' + match = re.search(r'iter_(\d+)\.pt', dit_path) + if match: + return match.group(1) + return "unknown" + def load_prompt(prompt_file: Path) -> str: """Load prompt text from file, returning first non-empty line.""" try: @@ -176,12 +185,14 @@ def main() -> None: # Print configuration inference_type = "LoRA" if args.lora else "Standard" + iteration_num = extract_iteration_number(config["dit_path"]) print(f"{'='*60}") print(f"BATCH METROPOLIS INFERENCE - 
{inference_type} Mode") print(f"{'='*60}") print(f"Script: {config['script']}") print(f"Model Size: {config['model_size']}") print(f"Checkpoint: {config['dit_path']}") + print(f"Iteration: {iteration_num}") if args.lora: print(f"LoRA Rank: {config['lora_rank']}") print(f"LoRA Alpha: {config['lora_alpha']}") @@ -194,9 +205,9 @@ def main() -> None: if not PROMPTS_DIR.exists(): raise FileNotFoundError(f"Prompts directory not found: {PROMPTS_DIR}") - # Create output directory with inference type suffix + # Create output directory with inference type and iteration suffix output_suffix = "_lora" if args.lora else "_standard" - output_dir = OUTPUT_DIR / f"metropolis_batch{output_suffix}" + output_dir = OUTPUT_DIR / f"metropolis_batch{output_suffix}_iter_{iteration_num}" output_dir.mkdir(parents=True, exist_ok=True) # Find all JPG files in the frames directory diff --git a/setup_training_env.sh b/customers/metropolis/setup_training_env.sh similarity index 100% rename from setup_training_env.sh rename to customers/metropolis/setup_training_env.sh From ae4e4bc6f4c952c1cc536acb476512b29ed742a7 Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Wed, 9 Jul 2025 12:28:41 +0000 Subject: [PATCH 11/12] resizing videos --- customers/metropolis/resizing_videos.sh | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 customers/metropolis/resizing_videos.sh diff --git a/customers/metropolis/resizing_videos.sh b/customers/metropolis/resizing_videos.sh new file mode 100644 index 00000000..02a72e73 --- /dev/null +++ b/customers/metropolis/resizing_videos.sh @@ -0,0 +1,4 @@ +for file in /home/arslana/codes/cosmos-predict2/datasets/dbq/raw/*.mp4; do + filename=$(basename "$file") + ffmpeg -i "$file" -vf "scale=1280:704:force_original_aspect_ratio=decrease,pad=1280:704:(ow-iw)/2:(oh-ih)/2:black" -c:a copy "/home/arslana/codes/cosmos-predict2/datasets/dbq/train/videos/$filename" +done From eb20e5a372e4266412b7fd04837f2172a6ee21af Mon Sep 17 00:00:00 2001 From: Arslan Ali Date: Tue, 7 Oct 2025 11:52:21 +0200 Subject: [PATCH 12/12] Update README_metropolis.md --- customers/metropolis/README_metropolis.md | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/customers/metropolis/README_metropolis.md b/customers/metropolis/README_metropolis.md index 4f4c5679..2e60fea4 100644 --- a/customers/metropolis/README_metropolis.md +++ b/customers/metropolis/README_metropolis.md @@ -85,11 +85,6 @@ torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos source setup_training_env.sh export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True export EXP=predict2_video2world_training_14b_metropolis -export WANDB_ENTITY=nvidia-dir -export WANDB_API_KEY=d06eb702b1a5d31496a0ab288f8c84e2ae55510d -export WANDB_PROJECT="cosmos_predict2" -export WANDB_RUN_NAME="14b_metropolis_reduced_res" -export WANDB_MODE="online" export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True && export EXP=predict2_video2world_training_14b_metropolis && torchrun --nproc_per_node=8 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment=${EXP} dataloader_train.dataset.video_size=[480,640] @@ -336,4 +331,4 @@ For issues specific to this Metropolis dataset setup: 3. Check environment variables and dependencies 4. Monitor GPU resources and NCCL communications -For general Cosmos-Predict2 issues, refer to the main documentation in `/documentations/`. \ No newline at end of file +For general Cosmos-Predict2 issues, refer to the main documentation in `/documentations/`.