|
# SPDX-License-Identifier: Apache-2.0

import os
+from types import SimpleNamespace
from typing import List, Union

import PIL
@@ -191,6 +192,116 @@ def input_processor_for_llama_text(ctx: InputContext, inputs: Union[DecoderOnlyI |
    return inputs


+def input_processor_for_mistral_24b(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
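+    """Preprocess one vLLM request for the Mistral 24B VLM: recover the prompt text,
+    run the HF processor over the text and any images, and return the resulting token
+    IDs alongside the processed image tensors."""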
+    input_processor = ctx.get_hf_processor()
+    if "prompt" in inputs:
+        prompt_text = inputs["prompt"]
+    else:
+        # [INFO] with the current vLLM version, inputs["prompt"] raises a KeyError in server mode;
+        # only inputs["prompt_token_ids"] is available, so recover the text by decoding it.
+        assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode"
+        prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False)
+    if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]:
+        images = inputs["multi_modal_data"]["image"]
+    else:
+        images = None
+
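+    # Run the HF processor over the prompt text and raw images in a single pass; for this
+    # model family it is expected to expand the image placeholder tokens and compute
+    # pixel_values / image_sizes (assumption based on the HF multimodal processor API).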
+    processed_inputs = input_processor(
+        text=prompt_text,
+        images=images,
+        videos=None,
+        return_tensors="pt",
+    )
+
+    assert processed_inputs.input_ids.shape[0] == 1, "vLLM passes a single prompt at a time to the input processor"
+    return {
+        "type": inputs["type"],
+        "prompt_token_ids": processed_inputs.input_ids[0].tolist(),
+        "prompt": prompt_text,
+        "multi_modal_data": {"image": processed_inputs},  # [INFO] pass the full processed inputs, not just the pixel values
+    }
+
+
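+# SimpleNamespace does not implement `__contains__`, so `"key" in ns` raises a TypeError;
+# this subclass adds membership checks so the namespace can be probed with `in`, like the
+# dict-style inputs vLLM normally passes around.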
+class CustomNamespace(SimpleNamespace):
+    def __contains__(self, key):
+        return key in self.__dict__
+
+
+@INPUT_REGISTRY.register_input_processor(input_processor_for_mistral_24b)
+class Mistral3ForConditionalGeneration(Generator, SupportsMultiModal):
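+    """vLLM-facing wrapper for the multimodal Mistral 24B model: exposes the initialization,
+    prefill, decode, and KV-cache allocation hooks expected by the Generator-based integration."""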
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.MISTRAL_IMAGE_TOKEN_ID = 151655
+        self.max_gen_len = self.model_args[0].max_seq_len - 1
+
+    @classmethod
+    def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, tt_data_parallel=1):
+        max_seq_len = 1024 * 128
+
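+        # Split the device mesh into `tt_data_parallel` submeshes; each gets its own model
+        # replica and a share of the batch (max_batch_size // tt_data_parallel).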
+        submesh_devices = create_submeshes(mesh_device, tt_data_parallel)
+
+        model_args = []
+        model = []
+        state_dict = None
+
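+        # state_dict is None on the first iteration, so the first submesh loads the checkpoint;
+        # later iterations reuse the returned state_dict to avoid re-loading the weights.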
+        for submesh in submesh_devices:
+            model_args_i, model_i, state_dict = create_multimodal_model(
+                mesh_device=submesh,
+                max_batch_size=max_batch_size // tt_data_parallel,
+                max_seq_len=max_seq_len,
+                use_paged_kv_cache=True,
+                checkpoint=state_dict,
+            )
+            model_args.append(model_args_i)
+            model.append(model_i)
+
+        return cls(model, model_args, mesh_device)
+
+    @property
+    def cache_path(self):
+        return self.model_args[0].model_cache_path
+
+    @property
+    def max_cross_attn_tokens(self):
+        return self.model_args[0].vision_max_num_chunks * nearest_32(self.model_args[0].vision_chunk_ntok)
+
+    def prefill_forward(self, *args, **kwargs):
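+        """Reset everything past each user's prompt length to the pad token, pull the
+        HF-processed image tensors out of the `images` kwarg, and dispatch to the text
+        prefill path with pixel_values/image_sizes attached."""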
+        self.tokenizer = self.model_args[0].tokenizer
+        pad_token_id = self.tokenizer.pad_token_id
+
+        tokens = kwargs["tokens"]
+        prompt_lens = kwargs["prompt_lens"]
+        inputs = CustomNamespace()
+        inputs.input_ids = tokens
+        data = kwargs.get("images", None)  # This holds the entire per-user data list, not just the pixel values
+        for i in range(tokens.shape[0]):  # for each user, pad everything past their prompt length
+            tokens[i][prompt_lens[i] :] = pad_token_id
+        pixel_values, image_sizes = None, None
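+        # These stay None for text-only requests and are filled in below when images are present.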
+
+        if data and hasattr(data[0], "pixel_values"):
+            # If data is a list of processed inputs with .pixel_values, collect them into per-user lists
+            pixel_values = [im.pixel_values for im in data if hasattr(im, "pixel_values")]
+            image_sizes = [im.image_sizes for im in data if hasattr(im, "image_sizes")]
+
+        page_table = kwargs.get("page_table", None)
+        kv_cache = kwargs.get("kv_cache", None)
+
+        return super().prefill_forward_text(
+            tokens=inputs.input_ids,
+            page_table=page_table,
+            kv_cache=kv_cache,
+            prompt_lens=prompt_lens,
+            pixel_values=pixel_values if pixel_values else None,
+            image_sizes=image_sizes if image_sizes else None,
+        )
+
+    def decode_forward(self, *args, **kwargs):
+        return super().decode_forward_text(*args, **kwargs)
+
+    def allocate_kv_cache(self, *args, **kwargs):
+        return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
+
+
# @MULTIMODAL_REGISTRY.register_image_input_mapper()  # TODO: Add once model can accept inputs from multi_modal_input_mapper (raw pixel values)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(Generator, SupportsMultiModal):
|