Commit b952f4d

[v1] Add PrefixLM support to FlexAttention backend (#27938)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent 541a2ef commit b952f4d

File tree: 16 files changed (+173, -25 lines)

docs/models/supported_models.md

Lines changed: 0 additions & 20 deletions
@@ -740,23 +740,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
     <sup>E</sup> Pre-computed embeddings can be inputted for this modality.
     <sup>+</sup> Multiple items can be inputted per text prompt for this modality.

-!!! warning
-    Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
-    However, there are differences in how they handle text + image inputs:
-
-    V0 correctly implements the model's attention pattern:
-    - Uses bidirectional attention between the image tokens corresponding to the same image
-    - Uses causal attention for other tokens
-    - Implemented via (naive) PyTorch SDPA with masking tensors
-    - Note: May use significant memory for long prompts with image
-
-    V1 currently uses a simplified attention pattern:
-    - Uses causal attention for all tokens, including image tokens
-    - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}`
-    - Will be updated in the future to support the correct behavior
-
-    This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
-
 !!! note
     `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
     MobileNet-v5 vision backbone.
@@ -776,9 +759,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
     The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
     For more details, please see: <https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630>

-!!! warning
-    Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
-
 !!! note
     For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.

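The warning removed above documented the pattern that was previously unsupported on V1: bidirectional attention among the tokens of the same image, causal attention everywhere else. As a rough illustration of the mask this commit's FlexAttention path can now express, here is a minimal sketch using `torch.nn.attention.flex_attention`; the span layout, tensor names, and wiring are assumptions for the example, not the code added by this commit.

```python
# Illustrative sketch only: a prefix-LM style mask_mod for FlexAttention,
# i.e. causal attention everywhere, plus bidirectional attention between
# tokens that belong to the same image. The image spans below are made up.
import torch
from torch.nn.attention.flex_attention import create_block_mask

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_len = 64

# -1 marks ordinary text tokens; non-negative values mark which image a
# token belongs to (two hypothetical image spans).
mm_span_id = torch.full((seq_len,), -1, dtype=torch.long, device=device)
mm_span_id[4:10] = 0
mm_span_id[20:32] = 1

def mm_prefix_mask_mod(b, h, q_idx, kv_idx):
    causal = kv_idx <= q_idx
    # Bidirectional only when query and key sit inside the same image span.
    same_image = (mm_span_id[q_idx] >= 0) & (mm_span_id[q_idx] == mm_span_id[kv_idx])
    return causal | same_image

# The resulting BlockMask can be passed to flex_attention(..., block_mask=...).
block_mask = create_block_mask(
    mm_prefix_mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device
)
```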

tests/models/multimodal/generation/test_common.py

Lines changed: 0 additions & 1 deletion
@@ -382,7 +382,6 @@
         auto_cls=AutoModelForImageTextToText,
         vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
         patch_hf_runner=model_utils.gemma3_patch_hf_runner,
-        num_logprobs=10,
     ),
     "glm4v": VLMTestInfo(
         models=["zai-org/glm-4v-9b"],

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py

Lines changed: 1 addition & 0 deletions
@@ -30,5 +30,6 @@ def get_attn_backend_cls(
         use_mla,
         has_sink,
         use_sparse,
+        use_mm_prefix,
     ):
         return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501

vllm/attention/backends/abstract.py

Lines changed: 9 additions & 0 deletions
@@ -166,6 +166,10 @@ def is_mla(cls) -> bool:
     def supports_sink(cls) -> bool:
         return False

+    @classmethod
+    def supports_mm_prefix(cls) -> bool:
+        return False
+
     @classmethod
     def is_sparse(cls) -> bool:
         return False
@@ -207,6 +211,7 @@ def validate_configuration(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        use_mm_prefix: bool,
         device_capability: "DeviceCapability",
         attn_type: str,
     ) -> list[str]:
@@ -219,6 +224,10 @@ def validate_configuration(
             invalid_reasons.append("kv_cache_dtype not supported")
         if not cls.supports_block_size(block_size):
             invalid_reasons.append("block_size not supported")
+        if use_mm_prefix and not cls.supports_mm_prefix():
+            invalid_reasons.append(
+                "partial multimodal token full attention not supported"
+            )
         if use_mla != cls.is_mla():
             if use_mla:
                 invalid_reasons.append("MLA not supported")
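`supports_mm_prefix()` defaults to `False`, so `validate_configuration` now rejects any backend that cannot build the mixed multimodal mask when `use_mm_prefix` is set. A backend opts in by overriding the classmethod; the sketch below is a hypothetical illustration of the opt-in (the FlexAttention backend's own change is among the 16 files but is not reproduced in this excerpt).

```python
# Hypothetical opt-in sketch; the class name is made up and this is not the
# FlexAttention backend's actual diff from this commit.
from vllm.attention.backends.abstract import AttentionBackend

class MmPrefixAwareBackend(AttentionBackend):
    @classmethod
    def supports_mm_prefix(cls) -> bool:
        # Advertise support so validate_configuration() does not report
        # "partial multimodal token full attention not supported".
        return True
```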

vllm/attention/layer.py

Lines changed: 5 additions & 0 deletions
@@ -230,6 +230,10 @@ def __init__(
         self.sliding_window = sliding_window
         self.has_sink = extra_impl_args.get("sinks") is not None

+        # NOTE: model_config may be None during certain tests
+        model_config = vllm_config.model_config
+        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm
+
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
@@ -241,6 +245,7 @@ def __init__(
                 block_size,
                 use_mla=False,
                 has_sink=self.has_sink,
+                use_mm_prefix=self.use_mm_prefix,
                 attn_type=attn_type,
             )
         else:

vllm/attention/selector.py

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,7 @@ def get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    use_mm_prefix: bool = False,
     attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
@@ -52,6 +53,7 @@
         use_mla=use_mla,
         has_sink=has_sink,
         use_sparse=use_sparse,
+        use_mm_prefix=use_mm_prefix,
         attn_type=attn_type,
     )

@@ -66,6 +68,7 @@ def _cached_get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    use_mm_prefix: bool = False,
     attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     from vllm.platforms import current_platform
@@ -87,6 +90,7 @@
            use_mla,
            has_sink,
            use_sparse,
+           use_mm_prefix,
            attn_type,
        )
    else:
@@ -99,6 +103,7 @@
            use_mla,
            has_sink,
            use_sparse,
+           use_mm_prefix,
            attn_type,
        )
    if not attention_cls:

vllm/config/model.py

Lines changed: 14 additions & 0 deletions
@@ -4,6 +4,7 @@
 import warnings
 from collections.abc import Callable
 from dataclasses import InitVar, field
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, Literal, cast, get_args

 import torch
@@ -1217,6 +1218,19 @@ def is_deepseek_mla(self) -> bool:
         )
         return False

+    @cached_property
+    def is_mm_prefix_lm(self) -> bool:
+        """Whether to use bidirectional attention for mm positions."""
+        MM_PREFIX_LM_MODELS = (
+            "gemma3",
+            # TODO(Isotr0py): Disable paligemma for now before
+            # we support soft cap attention for FlexAttention
+            # "paligemma",
+        )
+        if not hasattr(self.hf_config, "model_type"):
+            return False
+        return self.hf_config.model_type in MM_PREFIX_LM_MODELS
+
     def get_head_size(self) -> int:
         # TODO remove hard code
         if self.is_deepseek_mla:
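`is_mm_prefix_lm` keys purely off `hf_config.model_type`, so any Gemma 3 multimodal checkpoint enables the prefix-LM path. A quick way to see which value is being matched, assuming a recent `transformers` release that ships `Gemma3Config`:

```python
# Minimal sketch: the property returns True exactly when model_type is in
# MM_PREFIX_LM_MODELS (currently just "gemma3").
from transformers import Gemma3Config

cfg = Gemma3Config()  # multimodal Gemma 3 config
assert cfg.model_type == "gemma3"
```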

vllm/multimodal/inputs.py

Lines changed: 25 additions & 0 deletions
@@ -175,6 +175,31 @@ def get_num_embeds(self) -> int:

         return int(self.is_embed.sum().item())

+    def extract_embeds_range(self) -> list[tuple[int, int]]:
+        """Extract the start and end indices of the embedded regions in the prompt.
+
+        For example, given `PlaceholderRange(offset=2, length=5)` and
+        `is_embed = [False, True, False, True, True]`, the output is
+        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.
+
+        Returns:
+            A list of `(start, end)` tuples giving the start and end
+            indices (inclusive) of each embedded region.
+            Returns the full placeholder range if `is_embed` is `None`.
+        """
+        if self.is_embed is None:
+            return [(self.offset, self.offset + self.length)]
+
+        mask_i = self.is_embed.int()
+        starts = torch.nonzero(
+            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
+        ).flatten()
+        ends = torch.nonzero(
+            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
+        ).flatten()
+        ranges = torch.stack((starts, ends), dim=1) + self.offset
+        return [tuple(x) for x in ranges.tolist()]
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
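For reference, a small usage sketch of the new helper, reproducing the docstring's example (assumes a vLLM build that includes this change):

```python
# Worked example of PlaceholderRange.extract_embeds_range(), mirroring the
# docstring above: two embedded runs inside a placeholder starting at offset 2.
import torch
from vllm.multimodal.inputs import PlaceholderRange

pr = PlaceholderRange(
    offset=2,
    length=5,
    is_embed=torch.tensor([False, True, False, True, True]),
)
print(pr.extract_embeds_range())  # [(3, 3), (5, 6)]  ==  [(1+2, 1+2), (3+2, 4+2)]
```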

vllm/platforms/cpu.py

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        use_mm_prefix: bool,
         attn_type: str | None = None,
     ) -> str:
         if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:

vllm/platforms/cuda.py

Lines changed: 19 additions & 0 deletions
@@ -233,6 +233,20 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 "Forcing kv cache block size to 64 for FlashMLASparse backend."
             )

+        scheduler_config = vllm_config.scheduler_config
+        # Note: model_config may be None during testing
+        if (
+            model_config is not None
+            and model_config.is_mm_prefix_lm
+            and scheduler_config.is_multimodal_model
+            and not scheduler_config.disable_chunked_mm_input
+        ):
+            logger.warning(
+                "Forcing --disable_chunked_mm_input for models "
+                "with multimodal-bidirectional attention."
+            )
+            scheduler_config.disable_chunked_mm_input = True
+
     @classmethod
     def get_current_memory_usage(
         cls, device: torch.types.Device | None = None
@@ -268,6 +282,7 @@ def get_valid_backends(
         use_mla,
         has_sink,
         use_sparse,
+        use_mm_prefix,
         device_capability,
         attn_type,
     ) -> tuple[
@@ -289,6 +304,7 @@
                 use_mla,
                 has_sink,
                 use_sparse,
+                use_mm_prefix,
                 device_capability,
                 attn_type,
             )
@@ -312,6 +328,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        use_mm_prefix: bool,
         attn_type: str | None = None,
     ) -> str:
         if attn_type is None:
@@ -332,6 +349,7 @@
             use_mla,
             has_sink,
             use_sparse,
+            use_mm_prefix,
             device_capability,
             attn_type,
         )
@@ -356,6 +374,7 @@
             use_mla,
             has_sink,
             use_sparse,
+            use_mm_prefix,
             device_capability,
             attn_type,
         )
