Commit fcdfd74

vllm support for mistral 24b
1 parent 017907c commit fcdfd74

File tree

1 file changed: +111 -0 lines changed


models/tt_transformers/tt/generator_vllm.py

Lines changed: 111 additions & 0 deletions
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from types import SimpleNamespace
 from typing import List, Union
 
 import PIL
@@ -191,6 +192,116 @@ def input_processor_for_llama_text(ctx: InputContext, inputs: Union[DecoderOnlyI
     return inputs
 
 
+def input_processor_for_mistral_24b(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
+    input_processor = ctx.get_hf_processor()
+    if "prompt" in inputs:
+        prompt_text = inputs["prompt"]
+    else:
+        # [INFO] with current version of vLLM, in server mode, inputs["prompt"] gives KeyError; only inputs['prompt_token_ids'] is available
+        assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode"
+        prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False)
+    if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]:
+        images = inputs["multi_modal_data"]["image"]
+    else:
+        images = None
+
+    processed_inputs = input_processor(
+        text=prompt_text,
+        images=images,
+        videos=None,
+        return_tensors="pt",
+    )
+
+    assert processed_inputs.input_ids.shape[0] == 1, "Only one image is processed at a time by vLLM"
+    return {
+        "type": inputs["type"],
+        "prompt_token_ids": processed_inputs.input_ids[0].tolist(),
+        "prompt": prompt_text,
+        "multi_modal_data": {"image": processed_inputs},  # [INFO] add processed_inputs
+    }
+
+
+class CustomNamespace(SimpleNamespace):
+    def __contains__(self, key):
+        return key in self.__dict__
+
+
+@INPUT_REGISTRY.register_input_processor(input_processor_for_mistral_24b)
+class Mistral3ForConditionalGeneration(Generator, SupportsMultiModal):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.MISTRAL_IMAGE_TOKEN_ID = 151655
+        self.max_gen_len = self.model_args[0].max_seq_len - 1
+
+    @classmethod
+    def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, tt_data_parallel=1):
+        max_seq_len = 1024 * 128
+
+        submesh_devices = create_submeshes(mesh_device, tt_data_parallel)
+
+        model_args = []
+        model = []
+        state_dict = None
+
+        for submesh in submesh_devices:
+            model_args_i, model_i, state_dict = create_multimodal_model(
+                mesh_device=submesh,
+                max_batch_size=max_batch_size // tt_data_parallel,
+                max_seq_len=max_seq_len,
+                use_paged_kv_cache=True,
+                checkpoint=state_dict,
+            )
+            model_args.append(model_args_i)
+            model.append(model_i)
+
+        return cls(model, model_args, mesh_device)
+
+    @property
+    def cache_path(self):
+        return self.model_args[0].model_cache_path
+
+    @property
+    def max_cross_attn_tokens(self):
+        return self.model_args[0].vision_max_num_chunks * nearest_32(self.model_args[0].vision_chunk_ntok)
+
+    def prefill_forward(self, *args, **kwargs):
+        self.tokenizer = self.model_args[0].tokenizer
+        pad_token_id = self.tokenizer.pad_token_id
+
+        tokens = kwargs["tokens"]
+        prompt_lens = kwargs["prompt_lens"]
+        inputs = CustomNamespace()
+        inputs.input_ids = tokens
+        data = kwargs.get("images", None)  # This contains the entire Data list, not just the pixel values
+        for i in range(tokens.shape[0]):  # for each user, fix their padding
+            tokens[i][prompt_lens[i] :] = pad_token_id
+        pixel_values, image_sizes = None, None
+
+        if hasattr(data[0], "pixel_values"):
+            # If inputs is a list of objects with .pixel_values, concatenate them
+            pixel_values = [im.pixel_values for im in data if hasattr(im, "pixel_values")]
+            image_sizes = [im.image_sizes for im in data if hasattr(im, "image_sizes")]
+
+        page_table = kwargs.get("page_table", None)
+        kv_cache = kwargs.get("kv_cache", None)
+
+        return super().prefill_forward_text(
+            tokens=inputs.input_ids,
+            page_table=page_table,
+            kv_cache=kv_cache,
+            prompt_lens=prompt_lens,
+            pixel_values=pixel_values if pixel_values else None,
+            image_sizes=image_sizes if image_sizes else None,
+        )
+
+    def decode_forward(self, *args, **kwargs):
+        return super().decode_forward_text(*args, **kwargs)
+
+    def allocate_kv_cache(self, *args, **kwargs):
+        return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
+
+
 # @MULTIMODAL_REGISTRY.register_image_input_mapper()  # TODO: Add once model can accept inputs from multi_modal_input_mapper (raw pixel values)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
 class MllamaForConditionalGeneration(Generator, SupportsMultiModal):
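For reference, below is a minimal smoke-test sketch of the contract the new input processor implements. It is not part of the commit: the Hugging Face model id, the prompt string, and the stubbed context object are illustrative assumptions; the only requirement the processor places on ctx is that ctx.get_hf_processor() returns the checkpoint's Hugging Face processor, mirroring vLLM's InputContext.

# [Sketch, not part of the commit] Exercising input_processor_for_mistral_24b
# outside of vLLM. The model id and prompt format below are assumptions; the
# stubbed ctx only needs to expose get_hf_processor().
from types import SimpleNamespace

from PIL import Image
from transformers import AutoProcessor

hf_processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
ctx = SimpleNamespace(get_hf_processor=lambda: hf_processor)

inputs = {
    "type": "token",
    "prompt": "<s>[INST][IMG]Describe this image.[/INST]",
    "multi_modal_data": {"image": Image.new("RGB", (448, 448))},
}

out = input_processor_for_mistral_24b(ctx, inputs)
# The HF processor tokenizes the prompt (expanding its image placeholder) and
# preprocesses the image; the full processor output rides along under
# multi_modal_data["image"], which is where prefill_forward later looks for
# pixel_values and image_sizes.
print(len(out["prompt_token_ids"]))
print(out["multi_modal_data"]["image"].pixel_values.shape)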

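A small side note on CustomNamespace: prefill_forward wraps the token tensor in it rather than in a plain SimpleNamespace so that dict-style "key in obj" membership checks keep working on the wrapped inputs. A self-contained illustration (the values are made up):

# Why the commit subclasses SimpleNamespace: a plain SimpleNamespace defines
# neither __contains__ nor __iter__, so "key in ns" raises TypeError.
from types import SimpleNamespace

class CustomNamespace(SimpleNamespace):
    def __contains__(self, key):
        return key in self.__dict__

inputs = CustomNamespace()
inputs.input_ids = [[1, 2, 3, 0, 0]]  # stand-in for the padded token tensor
print("input_ids" in inputs)       # True
print("pixel_values" in inputs)    # False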