Add Spatiotemporal Area Attention (ST-A²) for V-JEPA 2 encoder #121

tarassh wants to merge 27 commits into facebookresearch:main from
Conversation
Adapt YOLOv12's Area Attention to V-JEPA 2's 3D token grid with sparse masking. Partitions visible tokens into spatiotemporal areas by grid position and runs independent attention per area, reducing cost from O(N²) to O(num_areas × (N/num_areas)²). Compatible with 3D-RoPE, SDPA, and existing checkpoints (identical weight structure to RoPEAttention).

- Add RoPEAreaAttention module with differentiable gather-pad-attend-scatter
- Integrate into Block with hybrid layer allocation (area attention in early layers, full attention in final layers for global masked prediction)
- Wire through VisionTransformer, init_video_model, and train.py config
- Add ViT-L ablation config (2×2 factored split, layers 0–18 of 24)
- Add 9-test verification suite with CPU + GPU (T4) benchmarks
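The claimed quadratic reduction can be sanity-checked with a toy cost model that counts only pairwise token interactions (ignoring the gather/scatter overhead discussed later in the PR):

```python
def attention_flops(n_tokens: int, num_areas: int = 1) -> float:
    """Rough pairwise-interaction count for (area-partitioned) self-attention.

    With A areas of ~N/A tokens each, cost is A * (N/A)^2 = N^2 / A,
    an A-fold reduction over full attention's N^2.
    """
    per_area = n_tokens / num_areas
    return num_areas * per_area ** 2

full = attention_flops(4608)               # full attention at 384px/64f
area = attention_flops(4608, num_areas=4)  # 2x2 factored split -> 4 areas
assert full / area == 4.0
```

At the 4,608-token cooldown resolution, the 2×2 split cuts per-layer attention interactions by exactly the number of areas.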
PyTorch 2.9.0+cu126 renamed CudaDeviceProperties.total_mem to total_memory. Use getattr fallback for backwards compatibility.
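The fallback pattern described can be sketched as follows; `OldProps`/`NewProps` are hypothetical stand-ins for the CUDA device-properties object across PyTorch versions, not real PyTorch classes:

```python
# Hypothetical stand-ins for the device-properties object before/after the rename.
class OldProps:
    total_mem = 16_000_000_000      # pre-2.9 attribute name

class NewProps:
    total_memory = 16_000_000_000   # 2.9.0+ attribute name

def device_total_memory(props):
    # Prefer the new attribute name, fall back to the old one.
    mem = getattr(props, "total_memory", None)
    return mem if mem is not None else getattr(props, "total_mem", None)

assert device_total_memory(OldProps()) == device_total_memory(NewProps())
```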
Synthetic-data ablation comparing baseline (full attention) vs ST-A² (area attention) on a single T4 GPU. Uses the real V-JEPA 2 encoder and predictor with random video tensors — no dataset needed. Collects: loss convergence, step time, peak memory, throughput. Produces matplotlib charts and CSV exports.
Previous run at 128px/8f had only ~64 visible tokens — too small for area attention to show gains (gather/scatter overhead dominated). Now using the real V-JEPA 2 resolution with 2048 total tokens (~512 visible after masking) where O(N²) reduction should be measurable.
Without this, Colab reuses the cached vjepa2/ directory from a previous session and never fetches updated notebook config.
Summary table now prints: model, resolution, frames, batch size, token grid, visible tokens, dtype, GPU, and ST-A² area config. CSV includes all config columns so results are self-documenting.
Monkey-patches Block.forward to time attention and MLP separately across all 24 ViT-L encoder layers. Runs 20 forward passes per config, averages timings, and produces:

- Per-layer table (attention type, attn ms, mlp ms, attn %)
- Comparison summary (total attention vs MLP, baseline vs ST-A²)
- 4-panel chart: stacked bars per layer, side-by-side attention comparison, and total-time breakdown horizontal bars

This reveals whether attention is actually the bottleneck at ~512 visible tokens, or whether the MLP/FFN dominates.
…oops

Replace the per-area Python loops in gather-pad-scatter with a fully vectorized sort-based implementation:

1. scatter_add_ for area counting (was: Python loop over num_areas)
2. argsort by area_id to sort all tokens (was: torch.where per batch×area)
3. Pad + reshape into (B*num_areas, heads, max_per_area, D)
4. Single batched SDPA call for all areas (was: num_areas separate calls)
5. Unsort via inverse permutation gather

Eliminates all Python loops over areas (4) and batch elements (B), reducing CUDA kernel launch overhead from O(num_areas × B) to O(1). All 9 tests pass, including exact equivalence with RoPEAttention when num_areas=1 (Test 5: max_diff=0.0).
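The sort–pad–unsort bookkeeping (steps 2, 3, and 5 above) can be illustrated with plain Python lists standing in for tensors; `group_by_area` is an illustrative name and `None` stands in for zero-padding:

```python
def group_by_area(tokens, area_ids, num_areas):
    """Sort tokens by area id, pad each area to equal length, and return
    (padded_groups, inverse_permutation) so the original order can be
    restored after per-area processing."""
    order = sorted(range(len(tokens)), key=lambda i: area_ids[i])  # argsort by area
    counts = [0] * num_areas
    for a in area_ids:
        counts[a] += 1                       # scatter_add_-style area counting
    max_per_area = max(counts)
    padded, taken = [], 0
    for a in range(num_areas):
        group = [tokens[order[taken + j]] for j in range(counts[a])]
        group += [None] * (max_per_area - counts[a])  # pad uneven areas
        padded.append(group)
        taken += counts[a]
    inverse = [0] * len(order)               # inverse permutation for unsort
    for pos, i in enumerate(order):
        inverse[i] = pos
    return padded, inverse

padded, inv = group_by_area(["t0", "t1", "t2", "t3"], [1, 0, 1, 0], num_areas=2)
# padded groups all tokens of each area contiguously; inv restores original order
```

In the real module these per-group lists become one `(B*num_areas, heads, max_per_area, D)` tensor fed to a single SDPA call, with an attention mask covering the padding slots.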
Multi-resolution sweep comparing baseline vs ST-A² across 4 configs:

- 256px/16f (512 visible tokens, batch=4)
- 384px/16f (1,152 visible tokens, batch=2)
- 256px/64f (2,048 visible tokens, batch=1)
- 384px/64f (4,608 visible tokens, batch=1)

Features: auto GPU/dtype detection, OOM-safe execution, per-layer profiling, speedup-ratio chart with crossover analysis, CSV export. Designed for Lambda Labs 1xH100 (80GB, BF16).
Same as the notebook but runs as a plain script: `python notebooks/ablation_h100_sweep.py`. No Jupyter required. Outputs results to stdout and CSV files.
PR_DESCRIPTION.md: technical write-up with T4/GH200 ablation results, implementation details, and key findings (18.4% loss improvement at 384px/64f with 5.5% per-step overhead).

Eval configs for frozen video classification with the ST-A² encoder:

- configs/eval/vitl/k400-area-attn.yaml (Kinetics-400)
- configs/eval/vitl/ssv2-area-attn.yaml (Something-Something v2)

Both mirror the baseline configs with area attention params added to pretrain_kwargs.encoder (use_area_attention, layers 0-17, 2x2 splits).
load_checkpoint() now uses strict=False when annealing, allowing the baseline vitl.pt to load into area-attention models (identical weight structure). Optimizer state is skipped during annealing to avoid key mismatches from architecture changes.

New config: finetune-256px-16f-area-attn.yaml

- Loads pretrained vitl.pt via the annealing flow
- 1,000 steps with linear LR decay (0.000525 → 1e-6)
- Single-GPU setup (batch=4) for Lambda GH200
- Estimated runtime: ~5 minutes
Scans extracted K400 val directory, assigns alphabetical class labels (0-399), and writes space-delimited CSV in V-JEPA 2 VideoDataset format.
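A minimal sketch of such a manifest generator, assuming the class-subdirectory layout and the space-delimited `path label` format described above (`write_manifest` is an illustrative name, not the script's actual API):

```python
import os
import tempfile

def write_manifest(val_dir: str, out_csv: str) -> int:
    """Assign alphabetical class indices and write 'path label' lines."""
    classes = sorted(d for d in os.listdir(val_dir)
                     if os.path.isdir(os.path.join(val_dir, d)))
    n = 0
    with open(out_csv, "w") as f:
        for label, cls in enumerate(classes):
            cls_dir = os.path.join(val_dir, cls)
            for name in sorted(os.listdir(cls_dir)):
                f.write(f"{os.path.join(cls_dir, name)} {label}\n")
                n += 1
    return n

# Demo on a throwaway layout with two classes (labels assigned alphabetically).
root = tempfile.mkdtemp()
for cls in ["juggling", "abseiling"]:
    os.makedirs(os.path.join(root, cls))
    open(os.path.join(root, cls, f"{cls}_0001.mp4"), "w").close()
manifest = os.path.join(root, "val.csv")
n = write_manifest(root, manifest)  # n == 2; "abseiling" gets label 0
```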
Downloads vitl.pt, K400 val set, generates CSV manifest, configures eval paths, and runs baseline vs ST-A² frozen probe evaluation. Single command: bash ~/vjepa2/scripts/setup_and_eval.sh
CVDF tars extract videos flat (no class subdirs), so the manifest script now supports --annotations flag to map filenames to labels via the K400 val.csv annotations file. Auto-detects layout.
ST-A² retains 82.4% of baseline accuracy (31.47% vs 38.20%) on K400 frozen probe without any fine-tuning — encoder was pretrained with full attention and has never seen area-partitioned patterns. Fine-tuning config ready to close the remaining gap.
Downloads K400 val, fine-tunes vitl.pt with area attention for 1,000 SSL steps, then runs frozen probe eval comparing baseline vs fine-tuned ST-A². Supports --skip-download and --eval-only flags.
Eval configs nest optimization under experiment.optimization, not top-level. Sed also mangled multi-line dataset entries in the fine-tune config. Replaced all config manipulation with a single Python script using yaml.safe_load/dump.
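The yaml round-trip approach might look like this sketch (requires PyYAML; `set_nested` is an illustrative helper, not the actual script): unlike sed, it addresses keys structurally, so nesting depth and multi-line entries are irrelevant.

```python
import yaml  # PyYAML

def set_nested(config_text: str, dotted_key: str, value) -> str:
    """Round-trip a YAML document and set one nested key, e.g.
    'experiment.optimization.lr', without touching unrelated entries."""
    cfg = yaml.safe_load(config_text)
    node = cfg
    keys = dotted_key.split(".")
    for k in keys[:-1]:
        node = node.setdefault(k, {})
    node[keys[-1]] = value
    return yaml.safe_dump(cfg)

doc = "experiment:\n  optimization:\n    lr: 0.001\n  tag: base\n"
out = yaml.safe_load(set_nested(doc, "experiment.optimization.lr", 0.0005))
# out["experiment"]["optimization"]["lr"] is updated; "tag" is untouched
```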
After 1,000 steps of SSL annealing, ST-A² outperforms baseline by 38pp (49.92% vs 11.71%) on K400 frozen probe under identical eval conditions.
Single script runs setup → fine-tune → 3 evals with identical settings:

1. Baseline (vitl.pt, full attention)
2. ST-A² no fine-tune (vitl.pt loaded into area attention)
3. ST-A² fine-tuned (1K-step SSL annealing)

All evals share batch=16, 3 HP sweeps, 3 epochs, 1 segment × 1 view. Prints a comparison table at the end.
Single script runs the complete ST-A² test plan on one A100:

- Phase 1: Setup (clone, venv, deps, K400 download, configs)
- Phase 2: 9 unit tests
- Phase 3: Multi-resolution ablation (4 res × 2 configs, 100 steps)
- Phase 4: Fine-tune vitl.pt → ST-A² (1,000 steps SSL annealing)
- Phase 5: 3-way downstream eval (baseline, no-FT, finetuned)
- Phase 6: Results summary table

Each phase writes a marker file on completion. Re-running the script after a crash resumes from the last incomplete phase. Use --reset to start fresh.
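The marker-file resume logic can be sketched as follows (illustrative names; the real script's marker format may differ):

```python
import os
import tempfile

def run_phases(phases, marker_dir):
    """Run phases in order, skipping any whose marker file already exists,
    so a re-run after a crash resumes at the first incomplete phase."""
    os.makedirs(marker_dir, exist_ok=True)
    ran = []
    for name, fn in phases:
        marker = os.path.join(marker_dir, f"{name}.done")
        if os.path.exists(marker):
            continue                 # phase already completed earlier
        fn()                         # do the work
        open(marker, "w").close()    # record completion only on success
        ran.append(name)
    return ran

markers = tempfile.mkdtemp()
phases = [("setup", lambda: None), ("tests", lambda: None), ("ablation", lambda: None)]
first = run_phases(phases, markers)   # fresh run: every phase executes
second = run_phases(phases, markers)  # re-run: every phase is skipped
```

Because the marker is written only after `fn()` returns, a crash mid-phase leaves that phase unmarked and it re-executes on the next run.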
Replaces preliminary A10/A100 numbers with a clean 3-way comparison on L40S:

- Ablation sweep: +0.4% overhead at 4,608 tokens (384px/64f)
- Downstream: ST-A² retains 97.4% of baseline (82.85% vs 85.02%)
- Fine-tune did not improve over no-FT (insufficient data)
- Reframes the contribution as a near-lossless efficiency optimization
Hi @tarassh! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Summary

Adds Spatiotemporal Area Attention (ST-A²) to the V-JEPA 2 encoder: visible tokens are partitioned into spatiotemporal areas by grid position, and attention runs independently within each area, reducing cost from O(N²) to O(A × (N/A)²). The module is weight-compatible with the existing RoPEAttention and numerically identical to it at `num_areas=1`.

Motivation
V-JEPA 2 trains with masked video modeling, where the encoder processes only visible (unmasked) tokens. At high resolutions and long temporal windows — particularly the 384px/64f cooldown phase — the visible token count reaches 4,608+, making full self-attention the dominant compute bottleneck.
Area attention offers a principled way to exploit the spatiotemporal locality inherent in video: nearby patches in space and time are more informative to each other than distant ones. By partitioning tokens into areas aligned with the 3D grid and restricting attention to within-area interactions, we reduce quadratic cost without introducing architectural asymmetry (no separate spatial/temporal heads, no window shifting logic). The approach is a drop-in replacement for standard SDPA and preserves exact numerical equivalence when disabled.
The key hypothesis is that for video SSL with masking, local attention in early layers is sufficient for feature extraction, while global attention in the final layers handles the cross-region reasoning needed for masked prediction.
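Under this hypothesis, the hybrid allocation reduces to a per-layer lookup. A sketch assuming a half-open `[start, end)` range on a 24-layer ViT-L — the boundary convention and function name are assumptions of this sketch, not the PR's exact API:

```python
def attention_type_for_layer(layer_idx: int, depth: int = 24,
                             area_layers=(0, 18)) -> str:
    """Hybrid allocation: area attention in [start, end), full attention
    elsewhere. Mirrors area_attention_layers: [0, 18] on a 24-layer ViT-L."""
    start, end = area_layers
    return "area" if start <= layer_idx < end else "full"

allocation = [attention_type_for_layer(i) for i in range(24)]
# 18 area-attention layers (75%) followed by 6 full-attention layers (25%)
assert allocation.count("area") == 18 and allocation.count("full") == 6
```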
Implementation
Core: `RoPEAreaAttention` (src/models/utils/modules.py)

The attention module assigns each of the N visible tokens to one of A = `spatial_splits × temporal_splits` areas based on its grid position (h, t). The pipeline is fully vectorized:

- `argsort` by area index to group tokens contiguously
- gather Q, K, V into sorted order `(B×A, ceil(N/A), D)` with zero-padding for uneven splits
- construct per-area attention masks to ignore padding
- single `F.scaled_dot_product_attention` call across all areas simultaneously

No Python loops over areas. The sort/unsort overhead is ~0.8ms per layer on L40S.
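The grid-position-to-area mapping might be binned as in this pure-Python sketch (the integer-binning boundary handling is illustrative, not necessarily the module's exact scheme):

```python
def area_id(h: int, t: int, grid_h: int, grid_t: int,
            spatial_splits: int = 2, temporal_splits: int = 2) -> int:
    """Map a token's grid position (h, t) to an area index in
    [0, spatial_splits * temporal_splits)."""
    s_bin = min(h * spatial_splits // grid_h, spatial_splits - 1)
    t_bin = min(t * temporal_splits // grid_t, temporal_splits - 1)
    return t_bin * spatial_splits + s_bin

# 16 spatial rows x 8 temporal positions -> exactly 4 areas with a 2x2 split
ids = {area_id(h, t, 16, 8) for h in range(16) for t in range(8)}
assert ids == {0, 1, 2, 3}
```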
Hybrid Layer Allocation
Configured via `area_attention_layers: [start, end]` (default `[0, 18]`). Layers in range use `RoPEAreaAttention`; layers outside use standard `RoPEAttention`. This gives 75% area-attention layers for local feature extraction and 25% full-attention layers for global masked prediction.

Config Propagation
Area attention parameters flow through the existing config path:
Parameters:
`use_area_attention`, `area_spatial_splits`, `area_temporal_splits`, `area_attention_layers`, `area_residual_scale`

Default Configuration
`spatial_splits=2, temporal_splits=2` → 4 areas. Each area receives ~N/4 tokens, yielding a 4× reduction in per-area attention cost.

Results
Multi-Resolution Ablation Sweep — L40S (48GB, BF16), 100 steps each
The per-step overhead decreases monotonically with token count: +11.5% at 512 tokens → +0.4% at 4,608 tokens. At the highest resolution where V-JEPA 2 spends its cooldown phase, area attention is essentially free in wall-clock time while reducing per-area attention FLOPs by 4×.
Downstream Evaluation — K400 Frozen Attentive Probe
To test whether ST-A² representations transfer to classification, we ran frozen probe evaluations on Kinetics-400 validation (19,877 videos, 400 classes). The encoder weights are frozen; only an attentive probe head (4 blocks, 16 heads) is trained. All three evaluations use identical settings for a fair comparison.
The `vitl.pt` checkpoint was pretrained with full attention. ST-A² evaluations load these same weights into area-attention layers. The "finetuned" variant additionally ran 1,000 steps of SSL annealing with area attention enabled on K400 val data.

Setup: L40S GPU (48GB), batch=16, 1 segment × 1 view, 5 HP sweeps (lr ∈ {0.005, 0.003, 0.001, 0.0003, 0.0001}, wd=0.01), 10 epochs. All three configs identical except encoder architecture and checkpoint.
Analysis:
ST-A² (no fine-tune) retains 97.4% of baseline accuracy (82.85% vs 85.02%) — a near-lossless drop-in replacement. The encoder has never seen area-partitioned attention patterns during pretraining, yet representations transfer almost fully.
Fine-tuning did not improve over no-fine-tune (82.70% vs 82.85%). The 1,000-step SSL annealing on K400 val (~19K videos) was insufficient data to meaningfully adapt the encoder. Full pretraining with area attention from scratch (or fine-tuning on the complete data mix) would be needed to close the remaining 2.2pp gap.
The gap is consistent across training: ~0.2pp at epoch 1, ~2.2pp at epoch 10. Baseline pulls ahead slightly with more probe training, but ST-A² tracks closely throughout.
Key Findings
Near-zero overhead at high token counts: Per-step overhead decreases from +11.5% at 512 tokens to +0.4% at 4,608 tokens on L40S. At the resolution where V-JEPA 2 spends its cooldown phase, area attention is essentially free.
Near-lossless downstream transfer: ST-A² retains 97.4% of baseline K400 accuracy without any fine-tuning, confirming that area-partitioned attention preserves nearly all learned representations. The 2.2pp gap is expected to close with area-attention-native pretraining.
Scaling trend: The time overhead inversely correlates with token count — sort/unsort is O(N log N) and becomes negligible relative to the O(N²/A) attention cost at high N. This makes ST-A² most efficient exactly where V-JEPA 2 needs it most.
Drop-in compatibility:
`RoPEAreaAttention` has identical weight structure to `RoPEAttention` (same qkv, proj, RoPE dims), enabling direct checkpoint loading with `strict=False`. No retraining required for evaluation.

Next Steps
- Sweep `spatial_splits` and `temporal_splits` independently (e.g., 3×1 for spatially-dominant partitioning) to find optimal area configurations per resolution

Test Plan
- `notebooks/test_area_attention.py` — covers numerical equivalence at `num_areas=1`, gradient flow, variable sequence lengths, mask correctness, and hybrid layer wiring

Reproduction