@@ -373,3 +373,112 @@ def decode_forward(self, *args, **kwargs):
 
     def allocate_kv_cache(self, *args, **kwargs):
         return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
+
+
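+# vLLM input processor for Gemma 3: re-runs the HF processor over the prompt text
+# (and any images) so the returned prompt_token_ids include the expanded image
+# placeholder tokens; registered on the model class below via INPUT_REGISTRY.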
+def input_processor_for_gemma(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
+    input_processor = ctx.get_hf_processor()
+    if "prompt" in inputs:
+        prompt_text = inputs["prompt"]
+    else:
+        assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode"
+        prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False)
+
+    if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]:
+        images = inputs["multi_modal_data"]["image"]
+    else:
+        images = None
+
+    processed_inputs = input_processor(
+        text=prompt_text,
+        images=images,
+        return_tensors="pt",
+    )
+
+    assert processed_inputs.input_ids.shape[0] == 1, "Only one image is processed at a time by vLLM"
+    return {
+        "type": inputs["type"],
+        "prompt_token_ids": processed_inputs.input_ids[0].tolist(),
+        "prompt": prompt_text,
+        "multi_modal_data": {"image": processed_inputs},  # [INFO] add processed_inputs
+    }
+
+
+from types import SimpleNamespace
+
+
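+# SimpleNamespace does not implement __contains__; this subclass adds it so the
+# namespace can be queried with `in`, like a dict (e.g. "input_ids" in inputs).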
+class CustomNamespace(SimpleNamespace):
+    def __contains__(self, key):
+        return key in self.__dict__
+
+
+@INPUT_REGISTRY.register_input_processor(input_processor_for_gemma)
+class Gemma3ForConditionalGeneration(Generator, SupportsMultiModal):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
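+        # Token id of Gemma 3's image placeholder (image_token_index in the HF config).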
+        self.GEMMA_IMAGE_TOKEN_ID = 262144
+        self.max_gen_len = self.model_args[0].max_seq_len - 1  # TODO: double check what this should be
+
+    @classmethod
+    def initialize_vllm_model(
+        cls, hf_config, mesh_device, max_batch_size, max_seq_len=131072, n_layers=None, tt_data_parallel=1
+    ):
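+        # Split the mesh into tt_data_parallel submeshes; each submesh hosts one model replica.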
+        submesh_devices = create_submeshes(mesh_device, tt_data_parallel)
+
+        model_args = []
+        model = []
+        state_dict = None
+
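+        # The first create_multimodal_model call loads the checkpoint; the returned
+        # state_dict is passed back in so later replicas reuse the already-loaded weights.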
+        for submesh in submesh_devices:
+            model_args_i, model_i, state_dict = create_multimodal_model(
+                mesh_device=submesh,
+                max_batch_size=max_batch_size // tt_data_parallel,
+                max_seq_len=max_seq_len,
+                use_paged_kv_cache=True,
+                checkpoint=state_dict,
+            )
+            model_args.append(model_args_i)
+            model.append(model_i)
+
+        return cls(model, model_args, mesh_device)
+
+    @property
+    def cache_path(self):
+        return self.model_args[0].model_cache_path
+
+    def prefill_forward(self, *args, **kwargs):
+        self.tokenizer = self.model_args[0].tokenizer
+        pad_token_id = self.tokenizer.pad_token_id
+
+        tokens = kwargs["tokens"]
+        prompt_lens = kwargs["prompt_lens"]
+        inputs = CustomNamespace()
+        inputs.input_ids = tokens
+        data = kwargs.get("images", None)  # This contains the entire data list, not just the pixel values
+        for i in range(tokens.shape[0]):  # for each user, fix their padding
+            tokens[i][prompt_lens[i] :] = pad_token_id
+        pixel_values = None
+
+        if data and hasattr(data[0], "pixel_values"):
+            # If data is a list of objects with .pixel_values, concatenate them into one batch
+            # (guard on data so text-only requests, where no images are passed, don't crash)
+            pixel_values = torch.concat([im.pixel_values for im in data if hasattr(im, "pixel_values")], dim=0)
+
+        page_table = kwargs.get("page_table", None)
+        kv_cache = kwargs.get("kv_cache", None)
+        vision_images = pixel_values
+
+        vision_images = [vision_images] if vision_images is not None else None
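+        # Note (assumption): prefill_forward_text appears to take a list of image batches,
+        # hence the extra list wrapper around the concatenated pixel values.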
+
+        return super().prefill_forward_text(
+            tokens=inputs.input_ids,
+            page_table=page_table,
+            kv_cache=kv_cache,
+            prompt_lens=prompt_lens,
+            pixel_values=vision_images,
+        )
+
+    def allocate_kv_cache(self, *args, **kwargs):
+        return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
+
+    def decode_forward(self, *args, **kwargs):
+        return super().decode_forward_text(*args, **kwargs)