@@ -459,16 +459,14 @@ def prefill_forward(self, *args, **kwargs):
459459 tokens [i ][prompt_lens [i ] :] = pad_token_id
460460 pixel_values = None
461461
462- if hasattr (data [ 0 ] , "pixel_values" ):
463- # If inputs is a list of objects with pixel_values, concatenate them
464- pixel_values = torch . concat ( [im .pixel_values for im in data if hasattr (im , "pixel_values" )], dim = 0 )
462+ if any ( hasattr (d , "pixel_values" ) for d in data ):
463+ # If inputs is a list of objects with . pixel_values, concatenate them
464+ pixel_values = [im .pixel_values if hasattr (im , "pixel_values" ) else None for im in data ]
465465
466466 page_table = kwargs .get ("page_table" , None )
467467 kv_cache = kwargs .get ("kv_cache" , None )
468468 vision_images = pixel_values
469469
470- vision_images = [vision_images ] if vision_images is not None else None
471-
472470 return super ().prefill_forward_text (
473471 tokens = inputs .input_ids ,
474472 page_table = page_table ,
@@ -482,3 +480,37 @@ def allocate_kv_cache(self, *args, **kwargs):
482480
    def decode_forward(self, *args, **kwargs):
        # Thin wrapper: delegates decode-phase (one-token-per-step) inference
        # to the text-only decode path on the parent Generator class.
        # All arguments are passed through untouched.
        return super().decode_forward_text(*args, **kwargs)
483+
484+
class Gemma3ForCausalLM(Generator):
    """vLLM adapter for running Gemma-3 as a text-only causal language model.

    Thin wrapper around the shared ``Generator``: prefill and decode are
    delegated straight to the parent's ``*_text`` paths, so no vision inputs
    are handled here (unlike the multimodal wrapper earlier in this file).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def initialize_vllm_model(
        cls, hf_config, mesh_device, max_batch_size, max_seq_len=32768, n_layers=None, tt_data_parallel=1
    ):
        # Alternate constructor used by vLLM. Builds the TT text transformer
        # with bfloat8_b weights and the performance-optimized decoder
        # precision profile, then wraps it in this adapter class.
        # NOTE(review): max_seq_len defaults to 32768 — assumed to match the
        # Gemma-3 context budget configured for this device; confirm.
        tt_model, model_args = initialize_vllm_text_transformer(
            hf_config,
            tt_data_parallel,
            mesh_device,
            max_batch_size,
            max_seq_len=max_seq_len,
            n_layers=n_layers,
            dtype=ttnn.bfloat8_b,
            optimizations=DecodersPrecision.performance,
        )
        return cls(tt_model, model_args, mesh_device)

    @property
    def cache_path(self):
        # Weight-cache path taken from the first model-args entry
        # (presumably the first data-parallel replica — verify against
        # how model_args is populated by initialize_vllm_text_transformer).
        return self.model_args[0].model_cache_path

    def prefill_forward(self, *args, **kwargs):
        # Text-only model: route prefill directly to the parent's text path
        # (no pixel_values / vision plumbing needed).
        return super().prefill_forward_text(*args, **kwargs)

    def decode_forward(self, *args, **kwargs):
        # Delegate decode-phase inference to the parent's text path.
        return super().decode_forward_text(*args, **kwargs)

    def allocate_kv_cache(self, *args, **kwargs):
        # KV-cache allocation additionally needs the (data-parallel) model
        # handles and the weight-cache path on top of vLLM's own arguments.
        return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
0 commit comments