Add generalist-queried cross-attention fusion for MoE CLIP experts #26
This PR adds a new fusion method, `fusion_method="cross_attn"`, for the MoE image modalities (`MOEImageModality` and `MOEImageModalityPEP`), based on generalist-queried cross-attention implemented in a new `CrossAttention` module. It is exposed via the config flags `fusion_method="cross_attn"` and `cross_attn_heads` in both MoE configs.
**What changed**

- New `CrossAttention` module
  - `_shape` helper to handle `[B, T, C] → [B, H, T, D]`.
- `MOEImageModality`
  - Added a `fusion_method="cross_attn"` path (sketched below, after this list):
    - Stacks the expert outputs to `[B, E, P, C]`.
    - Treats the generalist as the last expert (`g_idx = -1`).
    - Uses the generalist tokens as the query: `q = stacked[:, g_idx, :, :]  # [B, P, C]`.
    - Uses `_gating_to_expert_perm` to select the specialists and softmaxes over them.
    - Fuses via `CrossAttention`, keeps output shape `[B, P, C]`, then projects with the existing `MLPProjector`.
- `MOEImageModalityPEP`
  - Same `cross_attn` fusion logic, but operating in `hidden_size` space.
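For orientation, here is a minimal, self-contained PyTorch sketch of the fusion path described above. The names `CrossAttention`, `_shape`, and `g_idx` come from this PR, but the bodies are illustrative assumptions, not the actual implementation; in particular, the `_gating_to_expert_perm` / softmax gating over specialists is omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossAttention(nn.Module):
    """Generalist-queried multi-head cross-attention over specialist tokens (sketch)."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def _shape(self, x: torch.Tensor) -> torch.Tensor:
        # [B, T, C] -> [B, H, T, D]
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # query:   [B, P, C]  generalist patch tokens
        # context: [B, S, C]  specialist patch tokens (e.g. S = (E - 1) * P)
        B, P, C = query.shape
        q = self._shape(self.q_proj(query))    # [B, H, P, D]
        k = self._shape(self.k_proj(context))  # [B, H, S, D]
        v = self._shape(self.v_proj(context))  # [B, H, S, D]
        out = F.scaled_dot_product_attention(q, k, v)  # [B, H, P, D]
        out = out.transpose(1, 2).reshape(B, P, C)     # back to [B, P, C]
        return self.out_proj(out)


def cross_attn_fuse(expert_feats: list[torch.Tensor], fuser: CrossAttention) -> torch.Tensor:
    """Fuse per-expert CLIP features with the generalist as the query (sketch).

    `expert_feats` holds E tensors of shape [B, P, C]; the generalist is assumed
    to be the last entry, matching `g_idx = -1` in the PR.
    """
    stacked = torch.stack(expert_feats, dim=1)    # [B, E, P, C]
    g_idx = -1
    q = stacked[:, g_idx, :, :]                   # [B, P, C]  generalist query
    specialists = stacked[:, :g_idx, :, :]        # [B, E-1, P, C]
    B, Em1, P, C = specialists.shape
    context = specialists.reshape(B, Em1 * P, C)  # one key/value sequence over all specialists
    fused = fuser(q, context)                     # [B, P, C]  same token count as a single CLIP
    # In the PR, this [B, P, C] output is then passed to the existing MLPProjector.
    return fused
```

The key property is that only the generalist’s `P` tokens survive fusion, so the downstream token count matches the single-CLIP case regardless of how many experts are attached.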
**Comparison to existing fusion strategies:**

- vs. `sequence_append`: `sequence_append` linearly increases sequence length with the number of experts, which is expensive for the LLM (quadratic attention cost, more memory). `cross_attn` keeps the same number of tokens as a single CLIP, so it’s much more scalable while still leveraging multiple experts (see the rough token-count comparison after this list).
- vs. `weighted_average`: `cross_attn` lets the generalist CLIP decide per patch which specialists to attend to, using multi-head attention instead of a single scalar weight. This is strictly more expressive and less likely to wash out useful specialist signals.
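To make the scaling argument concrete, here is a rough back-of-the-envelope comparison; the expert and patch counts below are illustrative, not values from this repo:

```python
# Illustrative numbers only.
E = 4    # CLIP experts (generalist + 3 specialists)
P = 576  # patch tokens per expert (e.g. a 24x24 grid)

seq_append_tokens = E * P  # sequence_append: all experts' tokens go to the LLM -> 2304
cross_attn_tokens = P      # cross_attn: same count as a single CLIP          -> 576

# LLM self-attention cost grows quadratically in the visual token count,
# so sequence_append is roughly E**2 = 16x more expensive at this setting.
ratio = (seq_append_tokens / cross_attn_tokens) ** 2
print(seq_append_tokens, cross_attn_tokens, ratio)  # 2304 576 16.0
```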
**Why “generalist CLIP as query” is a good inductive bias:**

**In short:**
**Notes**

- The generalist is assumed to be the last expert (`g_idx = -1`).
- Debug `print`s can be removed or guarded behind a debug flag once the method is fully validated.
- `sequence_append`, `weighted_average`, and `cross_attn` (generalist query) should be compared across benchmarks.