[Feature] Add Vision DP for parallel ViT computation across CP ranks#609
aoshen524 wants to merge 2 commits into radixark:main
Conversation
When using Context Parallelism (cp_size > 1), Ring Flash Attention splits text attention across CP ranks, but the VisionTransformer (ViT) still processes ALL images on every rank, making ViT memory the bottleneck for multi-turn VLM training with many screenshots.

Vision DP distributes whole images (not patches) across CP ranks:

- Before: each of N CP ranks processes ALL images → O(total_images)
- After: each rank processes total_images/N images → O(total_images/N)

Key design:

- Image-level contiguous distribution (no reordering after all-gather)
- Gradient scaling by cp_size to compensate for FSDP reduction
- Supports Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE

Adapted from verl PR #5230.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
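To make the distribution scheme concrete, here is a minimal sketch of image-level contiguous splitting, assuming Qwen-VL-style flattened patch inputs. The helper name split_images_for_rank and its signature are hypothetical illustrations, not the PR's actual API (which lives in miles/utils/vision_dp.py).

```python
import torch

def split_images_for_rank(pixel_values: torch.Tensor,
                          grid_thw: torch.Tensor,
                          cp_rank: int,
                          cp_size: int):
    """Assign a contiguous block of whole images to one CP rank.

    pixel_values: (total_patches, patch_dim) flattened patches of all images
    grid_thw:     (num_images, 3) temporal/height/width patch grid per image
    """
    num_images = grid_thw.shape[0]
    # Contiguous block boundaries: rank r gets images [start, end).
    per_rank = (num_images + cp_size - 1) // cp_size
    start = min(cp_rank * per_rank, num_images)
    end = min(start + per_rank, num_images)

    # Per-image patch counts map image indices to patch offsets.
    patches_per_image = grid_thw.prod(dim=-1)
    offsets = torch.cumsum(patches_per_image, dim=0)
    patch_start = 0 if start == 0 else int(offsets[start - 1])
    patch_end = int(offsets[end - 1]) if end > 0 else 0

    return pixel_values[patch_start:patch_end], grid_thw[start:end]
```

Because each rank owns a contiguous run of images, the all-gathered embeddings are already in global image order, which is why no post-gather reordering is needed.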
Summary of Changes

Hello @aoshen524, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces Vision Data Parallelism (Vision DP) to optimize memory usage in Vision-Language Model (VLM) training, particularly when using Context Parallelism. By distributing image processing across parallel ranks, it alleviates the ViT activation memory bottleneck, making multi-turn VLM training with numerous images more efficient. The implementation includes core utilities for image distribution and embedding gathering, along with automatic integration for supported VLM architectures.

Highlights
Changelog
Activity
Code Review
This pull request introduces Vision Data Parallel (Vision DP), a strategy to parallelize Vision Transformer (ViT) computation across Context Parallel (CP) ranks by distributing whole images rather than patches. This is a significant optimization for multi-turn VLM training where ViT activation memory can become a bottleneck, effectively reducing per-rank ViT memory from O(total_images) to O(total_images/N) across N CP ranks.
```python
# Scale gradients to compensate for partial processing
# Each rank only processes 1/dp_size of the images, so gradients need to be
# scaled up by dp_size before FSDP gradient reduction (which averages them)
if grad_scaler:
    grad_output = grad_output * dp_size
```
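For context, here is a minimal sketch of the kind of all-gather autograd function this backward scaling sits in. Names and signatures are assumptions based on this thread (dp_size is the vision DP size, i.e. cp_size), and for simplicity the sketch assumes every rank holds the same number of patches; the real code must handle uneven image splits.

```python
import torch
import torch.distributed as dist

class GatherVisionEmbeddings(torch.autograd.Function):
    """All-gather per-rank ViT embeddings; rescale grads in backward (sketch)."""

    @staticmethod
    def forward(ctx, local_embeds, cp_group, dp_size, grad_scaler=True):
        ctx.cp_group = cp_group
        ctx.dp_size = dp_size
        ctx.grad_scaler = grad_scaler
        ctx.local_len = local_embeds.shape[0]
        # Assumes equal per-rank shapes; real code must gather sizes first.
        gathered = [torch.empty_like(local_embeds) for _ in range(dp_size)]
        dist.all_gather(gathered, local_embeds.contiguous(), group=cp_group)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scaler:
            # Compensate for FSDP's gradient averaging over partial work.
            grad_output = grad_output * ctx.dp_size
        # Keep only the gradient slice for this rank's local images.
        rank = dist.get_rank(ctx.cp_group)
        start = rank * ctx.local_len
        return grad_output[start:start + ctx.local_len], None, None, None
```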
There is a critical issue regarding gradient synchronization across CP ranks. When cp_size > 1, the ViT weights are replicated across the CP group. Since Vision DP distributes different images to each rank in the CP group, the gradients computed on each rank will differ. However, the current FSDP configuration (as seen in actor.py) only reduces gradients across the dp_mesh. This means gradients are not synchronized across the CP dimension, which will cause the replicated ViT weights to diverge during training. You likely need an all_reduce of gradients across the cp_group after the backward pass, or to include the CP dimension in the FSDP mesh (e.g., using HSDP). Additionally, the grad_scaler logic here assumes a specific normalization that might lead to incorrect magnitudes if not paired with a proper averaging reduction across the CP group.
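To illustrate the mesh-based alternative the reviewer suggests: with FSDP2, a 2-D device mesh makes gradient reduction span the CP dimension as well (replicate over the first dim, shard over the second). The sizes below are placeholders, and the fully_shard call is left as a comment since its import path varies by torch version.

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical HSDP layout: all-reduce grads across "cp", reduce-scatter
# across "dp". Assumes world_size == cp_size * dp_size.
cp_size, dp_size = 2, 4
mesh = init_device_mesh("cuda", (cp_size, dp_size), mesh_dim_names=("cp", "dp"))
# fully_shard(vision_tower, mesh=mesh)  # FSDP2 composable API
```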
Fixed in 3ab9660. Added sync_vision_grads_across_cp() which all-reduces ViT parameter gradients (AVG) across the CP group after backward, before clip_grad_norm_ / optimizer.step().
Gradient math:
- GatherVisionEmbeddings.backward scales output grads by cp_size, so ViT param grads on rank k = cp_size * partial_grad_k
- After AVG across CP: mean_k(cp_size * partial_grad_k) = cp_size * total_grad / cp_size = total_grad ✓
For FSDP2 DTensor compatibility, the sync accesses ._local_tensor when present to all-reduce the sharded gradient data directly.
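For reference, a minimal sketch of what such a sync could look like. The function name comes from the comment above, but the body is an assumption rather than the exact commit: it identifies vision parameters by a "visual" name prefix (matching Qwen-VL naming), and ReduceOp.AVG requires the NCCL backend.

```python
import torch
import torch.distributed as dist

def sync_vision_grads_across_cp(model: torch.nn.Module, cp_group) -> None:
    """All-reduce (AVG) vision-tower parameter grads across the CP group.

    Call after backward, before clip_grad_norm_ / optimizer.step().
    """
    for name, param in model.named_parameters():
        if "visual" not in name or param.grad is None:
            continue
        grad = param.grad
        # FSDP2 wraps grads in DTensor; all-reduce the local shard directly.
        local = grad._local_tensor if hasattr(grad, "_local_tensor") else grad
        dist.all_reduce(local, op=dist.ReduceOp.AVG, group=cp_group)
```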
| if hasattr(self, "merger") and hasattr(self.merger, "ln_q"): | ||
| ln_q = self.merger.ln_q | ||
| if hasattr(ln_q, "normalized_shape"): | ||
| hidden_size = ln_q.normalized_shape[0] | ||
| elif hasattr(ln_q, "weight"): | ||
| hidden_size = ln_q.weight.shape[0] | ||
| else: | ||
| raise RuntimeError(f"Cannot determine hidden_size from ln_q. Type: {type(ln_q).__name__}") | ||
| elif hasattr(self, "out_hidden_size"): | ||
| hidden_size = self.out_hidden_size | ||
| elif hasattr(self, "config") and hasattr(self.config, "hidden_size"): | ||
| hidden_size = self.config.hidden_size | ||
| else: | ||
| raise RuntimeError( | ||
| f"Cannot determine hidden_size for VisionTransformer. " | ||
| f"Model type: {type(self).__name__}. " | ||
| f"Available attributes: {[a for a in dir(self) if not a.startswith('_')]}" | ||
| ) |
The logic for determining hidden_size when a rank has no images is quite fragile and highly specific to the Qwen model architecture (checking for merger.ln_q, etc.). While this works for the currently supported models, it makes the create_dp_vision_forward wrapper less 'model-agnostic' than described. Consider allowing the hidden_size to be passed as an argument or using a more standardized way to query the output dimension from the vision tower's configuration.
Acknowledged. The hidden_size detection logic is intentionally kept consistent with the upstream verl implementation (PR #5230) for maintainability. Since Miles currently only supports Qwen-VL model families for VLM training, these attribute paths cover all supported architectures. If new model families are added, we can extend the detection or switch to an explicit parameter approach.
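For reference, the explicit-parameter approach mentioned above could look like the following (hypothetical signature, reusing the illustrative split_images_for_rank and GatherVisionEmbeddings helpers sketched earlier in this thread; not the PR's actual code):

```python
from typing import Callable, Optional

def create_dp_vision_forward(original_forward: Callable,
                             cp_group,
                             cp_size: int,
                             cp_rank: int,
                             hidden_size: Optional[int] = None) -> Callable:
    """Wrap a ViT forward so each CP rank processes only its image slice."""

    def dp_forward(self, pixel_values, grid_thw, **kwargs):
        local_pixels, local_grid = split_images_for_rank(
            pixel_values, grid_thw, cp_rank, cp_size)
        if local_pixels.shape[0] == 0 and hidden_size is not None:
            # No attribute probing: the caller supplied the output width.
            local_embeds = pixel_values.new_empty((0, hidden_size))
        else:
            local_embeds = original_forward(self, local_pixels, local_grid, **kwargs)
        return GatherVisionEmbeddings.apply(local_embeds, cp_group, cp_size)

    return dp_forward
```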
miles/utils/vision_dp.py (Outdated)
```python
original = Qwen2VisionTransformerPretrainedModel.forward
Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original, cp_group, cp_size, cp_rank)
print(f"[Vision DP] Patched Qwen2VisionTransformerPretrainedModel.forward (cp_size={cp_size})")
```
Avoid using print statements for status updates in library code. It is better to use the standard logging module, which allows users to control the verbosity and format of the output. This applies to all patching confirmation messages in this function (lines 420, 432, 442, 452).
| print(f"[Vision DP] Patched Qwen2VisionTransformerPretrainedModel.forward (cp_size={cp_size})") | |
| logger.info(f"[Vision DP] Patched Qwen2VisionTransformerPretrainedModel.forward (cp_size={cp_size})") |
Fixed in 3ab9660. Replaced all print() with logger.info() using the standard logging module.
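The standard module-level logger pattern, for reference:

```python
import logging

# Verbosity and formatting are then controlled by the application's
# logging configuration rather than hardcoded print() calls.
logger = logging.getLogger(__name__)
```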
Address review comments from Gemini:

1. (Critical) Add sync_vision_grads_across_cp() to all-reduce (AVG) the vision tower parameter gradients across CP ranks after backward. Without this, FSDP only reduces across dp_mesh, causing ViT weights to diverge when Vision DP produces different gradients per CP rank.
2. (Medium) Replace print() with logger.info() in apply_vision_dp_patch.

Gradient math: GatherVisionEmbeddings backward scales output grads by cp_size, so ViT param grads = cp_size * partial_grad. After AVG across CP: mean(cp_size * partial_k) = total_grad. Correct.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
When using Context Parallelism (cp_size > 1), Ring Flash Attention splits text attention across CP ranks, but the VisionTransformer (ViT) still processes ALL images on every rank, making ViT activation memory the bottleneck for multi-turn VLM training with many screenshots.

Vision DP distributes whole images (not patches) across CP ranks for parallelized ViT computation:

- Before: each of N CP ranks processes ALL images → ViT memory = O(total_images)
- After: each rank processes total_images/N images → ViT memory = O(total_images/N)

Key Design

- Image-level contiguous distribution (no reordering after all-gather).
- Gradient scaling by cp_size to compensate for partial processing before FSDP reduction.
- When context_parallel_size > 1 and the model is a VLM, Vision DP patches are applied automatically before model loading.

Files Changed

- miles/utils/vision_dp.py
- miles/backends/fsdp_utils/actor.py (FSDPTrainRayActor.init())
- tests/fast/utils/test_vision_dp.py

Adaptation from verl

This is adapted from verl PR #5230. Key difference: verl uses Ulysses SP with global process group state (verl.utils.ulysses); Miles uses explicit CP group parameter passing via closure, which is cleaner and avoids global state coupling.

Test plan
🤖 Generated with Claude Code