
@theoschiff
Contributor

This PR adds a new fusion method, fusion_method="cross_attn", for the MoE image modalities (MOEImageModality and MOEImageModalityPEP) based on generalist-queried cross-attention:

  • Introduces a reusable CrossAttention module.
  • Uses the generalist CLIP (defined as the last expert in the configs) as the query.
  • Uses specialist CLIPs as key–value context, weighted by the gating network.
  • Keeps sequence length constant (same number of patches as a single CLIP).

This is exposed via the config options fusion_method="cross_attn" and cross_attn_heads in both MoE configs.
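For concreteness, a hypothetical configuration snippet (only fusion_method="cross_attn" and cross_attn_heads come from this PR; the surrounding schema and values are assumptions):

```python
# Hypothetical config — field placement and values are illustrative, not the repo's actual schema.
moe_image_modality_config = dict(
    fusion_method="cross_attn",  # new fusion path added by this PR
    cross_attn_heads=8,          # number of cross-attention heads (illustrative value)
)
```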

What changed

  1. New CrossAttention module

    • Standard multi-head cross-attention with Q/K/V projections, dropout, and an output projection.
    • Supports masking and can be reused by both MoE variants.
    • Shape-safe helper _shape to handle [B, T, C] → [B, H, T, D] (a minimal sketch follows this list).
  2. MOEImageModality

    • Added fusion_method="cross_attn" path:

      • Stack expert outputs → [B, E, P, C].
      • Treat last expert as generalist (g_idx = -1).
      • Use generalist patch tokens as queries: q = stacked[:, g_idx, :, :] # [B, P, C].
      • Use all non-generalist experts as specialist context.
      • Align gating weights via _gating_to_expert_perm, select specialists, and softmax over them.
      • Scale each specialist’s tokens by its gating weight before passing them as KV to CrossAttention.
    • Keeps output shape [B, P, C], then projects with the existing MLPProjector (see the fusion sketch after this list).

  3. MOEImageModalityPEP

    • Same cross_attn fusion logic, but:

      • Projects per expert first (PEP), so cross-attention operates in the shared hidden_size space.
      • Reuses the same generalist-as-query, specialists-as-context pattern.
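
A minimal sketch of what the CrossAttention module might look like, assuming standard scaled dot-product attention. Only the module name, the _shape helper, and the features listed in item 1 come from this PR; everything else (argument names, defaults, mask convention) is an assumption:

```python
import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries attend over a separate key/value context."""

    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def _shape(self, x: torch.Tensor) -> torch.Tensor:
        # [B, T, C] -> [B, H, T, D]
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, context, mask=None):
        # query:   [B, Tq, C]  (e.g. generalist patch tokens)
        # context: [B, Tk, C]  (e.g. gated specialist tokens)
        # mask:    [B, Tq, Tk], True/1 at positions that may be attended to
        q = self._shape(self.q_proj(query))    # [B, H, Tq, D]
        k = self._shape(self.k_proj(context))  # [B, H, Tk, D]
        v = self._shape(self.v_proj(context))  # [B, H, Tk, D]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, H, Tq, Tk]
        if mask is not None:
            attn = attn.masked_fill(~mask.unsqueeze(1).bool(), float("-inf"))
        attn = self.dropout(attn.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(query.shape)  # back to [B, Tq, C]
        return self.out_proj(out)
```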

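Likewise, a sketch of the cross_attn fusion path from items 2–3, written as a standalone function for readability. The alignment performed by _gating_to_expert_perm is assumed to have already happened, and in the PEP variant each expert output would be projected to hidden_size before this step; tensor names are illustrative:

```python
import torch


def cross_attn_fuse(expert_outputs, gating_weights, cross_attn):
    """
    expert_outputs: list of E tensors, each [B, P, C]; the last expert is the generalist.
    gating_weights: [B, E] gating scores, assumed already aligned to expert order.
    cross_attn:     a CrossAttention module with dim == C.
    Returns fused patch tokens of shape [B, P, C].
    """
    stacked = torch.stack(expert_outputs, dim=1)           # [B, E, P, C]
    g_idx = -1                                             # last expert = generalist
    query = stacked[:, g_idx, :, :]                        # [B, P, C] generalist queries

    specialists = stacked[:, :-1, :, :]                    # [B, E-1, P, C]
    spec_w = gating_weights[:, :-1].softmax(dim=-1)        # renormalize over specialists
    specialists = specialists * spec_w[:, :, None, None]   # scale each specialist's tokens

    B, E_spec, P, C = specialists.shape
    context = specialists.reshape(B, E_spec * P, C)        # specialist tokens as KV context

    return cross_attn(query, context)                      # [B, P, C], same length as one CLIP
```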
Comparison to existing fusion strategies:

  • vs. sequence_append:

    • sequence_append grows the sequence length linearly with the number of experts, which is expensive for the LLM (quadratic attention cost, more memory).
    • cross_attn keeps the same number of tokens as a single CLIP, so it’s much more scalable while still leveraging multiple experts.
  • vs. weighted_average:

    • Simple averaging is destructive: it merges all expert features per patch into a single vector, making it hard to preserve complementary information.
    • cross_attn lets the generalist CLIP decide per patch which specialists to attend to, using multi-head attention instead of a single scalar weight. This is strictly more expressive and less likely to wash out useful specialist signals.
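
As a rough illustration (numbers are hypothetical): with 4 experts and 256 patches per CLIP, sequence_append hands the LLM 1,024 visual tokens, roughly 16× the self-attention cost of the 256 tokens cross_attn produces; weighted_average also yields 256 tokens, but collapses the 4 expert vectors per patch into one before the LLM ever sees them.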

Why “generalist CLIP as query” is a good inductive bias:

  • The generalist CLIP is trained to be robust across many domains, so using it as the query anchor keeps the final representation aligned with a strong, general embedding space.
  • Specialists contribute contextual refinements via keys/values, modulated by the gating network. This naturally matches the intuition: “generalist defines what we’re looking for, specialists provide how to refine it.”

In short:

We keep the robustness and global semantics of the generalist CLIP while letting gated specialists refine each patch via cross-attention, with no sequence length blow-up and strictly more expressive fusion than a weighted average.


Notes

  • Assumes the last expert is the generalist; this is now baked into the cross-attention path (g_idx = -1).
  • Cross-attention debug prints can be removed or guarded behind a debug flag once the method is fully validated.
  • TODO: add ablation results comparing the following across benchmarks:
    • sequence_append
    • weighted_average
    • cross_attn (generalist query)

@MichelDucartier MichelDucartier merged commit 46b5772 into master Dec 10, 2025
1 check failed
@MichelDucartier MichelDucartier deleted the add-cross-attention branch December 10, 2025 15:39