[Feature] Add Vision DP for parallel ViT computation across CP ranks #609

Open

aoshen524 wants to merge 2 commits into radixark:main from aoshen524:feat/vision-dp

Conversation

@aoshen524

Summary

When using Context Parallelism (cp_size > 1), Ring Flash Attention splits text attention across CP ranks, but the VisionTransformer (ViT) still processes ALL images on every rank, making ViT activation memory the bottleneck for multi-turn VLM training with many screenshots.

Vision DP distributes whole images (not patches) across CP ranks for parallelized ViT computation:

  • Before: Each of N CP ranks processes ALL images → ViT memory = O(total_images)
  • After: Each rank processes total_images/N images → ViT memory = O(total_images/N)

Key Design

  • Image-level contiguous distribution: rank 0 gets images [0, 1, ...], rank 1 gets the next chunk, and so on; no reordering is needed after all-gather (see the sketch after this list).
  • Gradient scaling: backward pass scales gradients by cp_size to compensate for partial processing before FSDP reduction.
  • Automatic monkey patching: When context_parallel_size > 1 and model is a VLM, Vision DP patches are applied automatically before model loading.
  • Supported models: Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE
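
A minimal sketch of the contiguous assignment described above (assign_images_to_dp_ranks does exist in miles/utils/vision_dp.py, but the signature and return format shown here are illustrative assumptions, not the PR's exact API):

# Illustrative only: contiguous, image-level assignment across CP ranks. The real
# helper may use a different signature and return format.
from typing import List, Tuple


def assign_images_to_dp_ranks(num_images: int, dp_size: int) -> List[Tuple[int, int]]:
    """Return one contiguous (start, end) image-index range per rank, in order."""
    base, rem = divmod(num_images, dp_size)
    ranges, start = [], 0
    for rank in range(dp_size):
        count = base + (1 if rank < rem else 0)  # spread the remainder over the first ranks
        ranges.append((start, start + count))
        start += count
    return ranges


# 7 images over 3 CP ranks -> [(0, 3), (3, 5), (5, 7)]; concatenating per-rank
# embeddings in rank order restores the original image order, so no reshuffle is needed.
print(assign_images_to_dp_ranks(7, 3))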

Files Changed

  • miles/utils/vision_dp.py: New - Core Vision DP utilities
  • miles/backends/fsdp_utils/actor.py: Modified - Apply Vision DP patch in FSDPTrainRayActor.init()
  • tests/fast/utils/test_vision_dp.py: New - 19 CPU-only unit tests

Adaptation from verl

This is adapted from verl PR #5230. Key difference: verl uses Ulysses SP with global process group state (verl.utils.ulysses); Miles uses explicit CP group parameter passing via closure, which is cleaner and avoids global state coupling.
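
To make the contrast concrete, here is a sketch of the closure pattern only; the real create_dp_vision_forward also performs the image slicing and all-gather, which is elided here, and the placeholder class and body are not from this PR:

def create_dp_vision_forward(original_forward, cp_group, cp_size, cp_rank):
    def dp_vision_forward(self, *args, **kwargs):
        # cp_group / cp_size / cp_rank come from this enclosing scope, not from
        # module-level global process-group state (the verl.utils.ulysses approach).
        return original_forward(self, *args, **kwargs)

    return dp_vision_forward


class _FakeViT:  # stand-in for a Qwen VisionTransformer class
    def forward(self, pixel_values):
        return pixel_values


# Patch at the class level; every instance then picks up the wrapped forward.
_FakeViT.forward = create_dp_vision_forward(_FakeViT.forward, cp_group=None, cp_size=2, cp_rank=0)
print(_FakeViT().forward("pixels"))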

Test plan

  • 19 CPU-only unit tests pass (patch counts, image assignment, local input extraction, integration)
  • ruff, black, isort checks pass
  • E2E validation with VLM + CP > 1 training run

🤖 Generated with Claude Code

When using Context Parallelism (cp_size > 1), Ring Flash Attention splits
text attention across CP ranks, but the VisionTransformer (ViT) still
processes ALL images on every rank, making ViT memory the bottleneck for
multi-turn VLM training with many screenshots.

Vision DP distributes whole images (not patches) across CP ranks:
- Before: Each of N CP ranks processes ALL images -> O(total_images)
- After: Each rank processes total_images/N images -> O(total_images/N)

Key design:
- Image-level contiguous distribution (no reordering after all-gather)
- Gradient scaling by cp_size to compensate for FSDP reduction
- Supports Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE

Adapted from verl PR #5230.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @aoshen524, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces Vision Data Parallelism (Vision DP) to optimize memory usage in Vision-Language Model (VLM) training, particularly when using Context Parallelism. By distributing image processing across parallel ranks, it alleviates the ViT activation memory bottleneck, making multi-turn VLM training with numerous images more efficient. The implementation includes core utilities for image distribution and embedding gathering, along with automatic integration for supported VLM architectures.

Highlights

  • Vision Data Parallelism (Vision DP) Implementation: Introduced Vision DP to distribute whole images across Context Parallel (CP) ranks, significantly reducing Vision Transformer (ViT) activation memory during multi-turn VLM training.
  • Memory Optimization: Changed ViT memory usage from O(total_images) to O(total_images/N) where N is the number of CP ranks, by processing only a subset of images per rank.
  • Key Design Principles: Implemented image-level contiguous distribution, ensuring no reordering is needed after all-gather, and incorporated gradient scaling in the backward pass to compensate for partial processing.
  • Automatic Patching and Model Support: Enabled automatic monkey patching of Vision Transformer models (Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE) when context_parallel_size > 1 and the model is a VLM.
  • Comprehensive Testing: Added 19 CPU-only unit tests covering patch counts, image assignment, local input extraction, and overall integration of the Vision DP utilities.
Changelog
  • miles/backends/fsdp_utils/actor.py
    • Added logic to automatically apply Vision DP patch when Context Parallelism is enabled and the model is a Vision-Language Model (VLM).
  • miles/utils/vision_dp.py
    • Added functions to calculate image patch and embedding counts.
    • Implemented logic for contiguous image assignment to data parallel ranks.
    • Provided utilities to prepare local vision inputs for each rank.
    • Created a custom autograd function for all-gathering vision embeddings with proper gradient scaling (see the sketch after this list).
    • Developed a wrapper to modify VisionTransformer forward passes for Vision DP.
    • Included a function to monkey patch specific Qwen-VL VisionTransformer models for automatic Vision DP application.
  • tests/fast/utils/test_vision_dp.py
    • Added unit tests for get_image_patch_counts.
    • Added unit tests for get_image_embedding_counts.
    • Added unit tests for assign_images_to_dp_ranks.
    • Added unit tests for prepare_local_vision_inputs.
    • Added integration tests to validate the full Vision DP workflow.
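
For reviewers unfamiliar with the pattern, a self-contained sketch of an all-gather autograd function with backward gradient scaling, in the spirit of the custom autograd function listed above; it assumes every rank holds the same number of embedding rows, which the real implementation does not require:

import torch
import torch.distributed as dist


class GatherEmbeddingsSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, local_embeds, cp_group, cp_size, cp_rank):
        ctx.cp_size, ctx.cp_rank, ctx.local_len = cp_size, cp_rank, local_embeds.shape[0]
        gathered = [torch.empty_like(local_embeds) for _ in range(cp_size)]
        dist.all_gather(gathered, local_embeds.contiguous(), group=cp_group)
        # Contiguous image assignment means plain concatenation restores global order.
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Take this rank's slice of the incoming gradient and scale it by cp_size;
        # the averaging gradient reduction downstream then cancels the factor.
        start = ctx.cp_rank * ctx.local_len
        return grad_output[start : start + ctx.local_len] * ctx.cp_size, None, None, None
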
Activity
  • The author, aoshen524, implemented a new feature: Vision Data Parallelism.
  • 19 CPU-only unit tests for patch counts, image assignment, local input extraction, and integration have passed.
  • Code quality checks (ruff, black, isort) have passed.
  • E2E validation with VLM + CP > 1 training run is pending.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces Vision Data Parallel (Vision DP), a strategy to parallelize Vision Transformer (ViT) computation across Context Parallel (CP) ranks by distributing whole images rather than patches. This is a significant optimization for multi-turn VLM training where ViT activation memory can become a bottleneck, effectively reducing memory usage from $O(\text{total\_images})$ to $O(\text{total\_images}/N)$ per rank. The implementation includes robust image assignment logic, a custom autograd function for all-gathering embeddings with gradient scaling, and monkey patching for supported Qwen-VL models. While the core logic for image distribution and gathering is sound, there is a critical concern regarding gradient synchronization across CP ranks when using 1D FSDP, which could lead to weight divergence. There are also opportunities to improve logging and the robustness of the hidden dimension detection for empty ranks.

Comment on lines +261 to +265
# Scale gradients to compensate for partial processing
# Each rank only processes 1/dp_size of the images, so gradients need to be
# scaled up by dp_size before FSDP gradient reduction (which averages them)
if grad_scaler:
    grad_output = grad_output * dp_size
Contributor

critical

There is a critical issue regarding gradient synchronization across CP ranks. When cp_size > 1, the ViT weights are replicated across the CP group. Since Vision DP distributes different images to each rank in the CP group, the gradients computed on each rank will differ. However, the current FSDP configuration (as seen in actor.py) only reduces gradients across the dp_mesh. This means gradients are not synchronized across the CP dimension, which will cause the replicated ViT weights to diverge during training. You likely need an all_reduce of gradients across the cp_group after the backward pass, or to include the CP dimension in the FSDP mesh (e.g., using HSDP). Additionally, the grad_scaler logic here assumes a specific normalization that might lead to incorrect magnitudes if not paired with a proper averaging reduction across the CP group.

Author

Fixed in 3ab9660. Added sync_vision_grads_across_cp() which all-reduces ViT parameter gradients (AVG) across the CP group after backward, before clip_grad_norm_ / optimizer.step().

Gradient math:

  • GatherVisionEmbeddings.backward scales output grads by cp_size, so ViT param grads = cp_size * partial_grad_k
  • After AVG across CP: mean(cp_size * partial_k) = cp_size * total_grad / cp_size = total_grad

For FSDP2 DTensor compatibility, the sync accesses ._local_tensor when present to all-reduce the sharded gradient data directly.
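
An approximate sketch of such a sync; the parameter filter ("visual" in the name) and the DTensor unwrapping below are assumptions, not necessarily what 3ab9660 does:

import torch
import torch.distributed as dist


@torch.no_grad()
def sync_vision_grads_across_cp(model, cp_group):
    """All-reduce (AVG) vision-tower parameter grads across the CP group.

    Call after backward, before clip_grad_norm_ / optimizer.step().
    """
    for name, param in model.named_parameters():
        if "visual" not in name or param.grad is None:  # assumed vision-tower naming
            continue
        grad = param.grad
        # FSDP2 may wrap grads as DTensor; all-reduce the local shard directly.
        local = grad._local_tensor if hasattr(grad, "_local_tensor") else grad
        dist.all_reduce(local, op=dist.ReduceOp.AVG, group=cp_group)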

Comment on lines +360 to +377
if hasattr(self, "merger") and hasattr(self.merger, "ln_q"):
    ln_q = self.merger.ln_q
    if hasattr(ln_q, "normalized_shape"):
        hidden_size = ln_q.normalized_shape[0]
    elif hasattr(ln_q, "weight"):
        hidden_size = ln_q.weight.shape[0]
    else:
        raise RuntimeError(f"Cannot determine hidden_size from ln_q. Type: {type(ln_q).__name__}")
elif hasattr(self, "out_hidden_size"):
    hidden_size = self.out_hidden_size
elif hasattr(self, "config") and hasattr(self.config, "hidden_size"):
    hidden_size = self.config.hidden_size
else:
    raise RuntimeError(
        f"Cannot determine hidden_size for VisionTransformer. "
        f"Model type: {type(self).__name__}. "
        f"Available attributes: {[a for a in dir(self) if not a.startswith('_')]}"
    )
Contributor

medium

The logic for determining hidden_size when a rank has no images is quite fragile and highly specific to the Qwen model architecture (checking for merger.ln_q, etc.). While this works for the currently supported models, it makes the create_dp_vision_forward wrapper less 'model-agnostic' than described. Consider allowing the hidden_size to be passed as an argument or using a more standardized way to query the output dimension from the vision tower's configuration.
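
A small sketch of that alternative (names are illustrative, not this PR's API): resolve the output dimension from an explicit override or the vision config before any model-specific probing.

def resolve_vision_hidden_size(vision_tower, override=None):
    """Prefer an explicit override, then the vision config, before model probing."""
    if override is not None:
        return override
    config = getattr(vision_tower, "config", None)
    for attr in ("out_hidden_size", "hidden_size"):
        value = getattr(config, attr, None)
        if value is not None:
            return value
    raise RuntimeError(f"Cannot determine hidden_size for {type(vision_tower).__name__}")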

Author

Acknowledged. The hidden_size detection logic is intentionally kept consistent with the upstream verl implementation (PR #5230) for maintainability. Since Miles currently only supports Qwen-VL model families for VLM training, these attribute paths cover all supported architectures. If new model families are added, we can extend the detection or switch to an explicit parameter approach.


original = Qwen2VisionTransformerPretrainedModel.forward
Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original, cp_group, cp_size, cp_rank)
print(f"[Vision DP] Patched Qwen2VisionTransformerPretrainedModel.forward (cp_size={cp_size})")
Contributor

medium

Avoid using print statements for status updates in library code. It is better to use the standard logging module, which allows users to control the verbosity and format of the output. This applies to all patching confirmation messages in this function (lines 420, 432, 442, 452).

Suggested change
- print(f"[Vision DP] Patched Qwen2VisionTransformerPretrainedModel.forward (cp_size={cp_size})")
+ logger.info(f"[Vision DP] Patched Qwen2VisionTransformerPretrainedModel.forward (cp_size={cp_size})")
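
Note that the suggestion assumes a module-level logger is already defined in vision_dp.py, e.g.:

import logging

logger = logging.getLogger(__name__)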

Author

Fixed in 3ab9660. Replaced all print() with logger.info() using the standard logging module.

Address review comments from Gemini:

1. (Critical) Add sync_vision_grads_across_cp() to all-reduce (AVG) the
   vision tower parameter gradients across CP ranks after backward.
   Without this, FSDP only reduces across dp_mesh, causing ViT weights
   to diverge when Vision DP produces different gradients per CP rank.

2. (Medium) Replace print() with logger.info() in apply_vision_dp_patch.

Gradient math: GatherVisionEmbeddings backward scales output grads by
cp_size, so ViT param grads = cp_size * partial_grad. After AVG across
CP: mean(cp_size * partial_k) = total_grad. Correct.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>