@@ -575,9 +575,10 @@ def __init__(
         max_prefill_chunk_size_div1024 = int(max_prefill_chunk_size_div1024)
         self.max_prefill_chunk_size = max_prefill_chunk_size_div1024 * 1024

-        if (self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B"] and self.device_name == "N150") or (
-            self.base_model_name in ["Qwen2.5-7B"] and self.device_name == "N300"
-        ):
+        if (
+            self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B", "gemma-3-1b-it"]
+            and self.device_name == "N150"
+        ) or (self.base_model_name in ["Qwen2.5-7B"] and self.device_name == "N300"):
             logger.info(f"Reducing prefill_len_cutoff to 512 for {self.model_name} on {self.device_name}")
             self.prefill_len_cutoff = 512

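For context, this hunk only adds gemma-3-1b-it on N150 to the model/device pairs that lower `prefill_len_cutoff` to 512. A cutoff like this is typically consumed by splitting a long prompt into bounded prefill passes; the sketch below is a rough illustration under that assumption, and the `chunk_prompt` helper is hypothetical, not code from this diff.

# Hypothetical sketch, not from this repository: split a prompt into prefill
# chunks no longer than the cutoff so each pass fits a small device's memory budget.
def chunk_prompt(prompt_tokens: list[int], prefill_len_cutoff: int = 512) -> list[list[int]]:
    return [
        prompt_tokens[i : i + prefill_len_cutoff]
        for i in range(0, len(prompt_tokens), prefill_len_cutoff)
    ]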
@@ -1396,7 +1397,9 @@ def _set_params_from_dict(self, config, is_hf=False):
         # Try to get text_config, if it doesn't exist everything is text config
         eos_token_id = config.get("eos_token_id", None)

-        self.eos_token_id = None if isinstance(eos_token_id, int) else eos_token_id
+        self.eos_token_id = (
+            None if isinstance(eos_token_id, int) else eos_token_id
+        )  # Gemma-like models can have a list of eos token ids

         self.sliding_window_pattern = config.get("sliding_window_pattern", 1)

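Hugging Face configs expose `eos_token_id` either as a single int or, for Gemma-3 style checkpoints, as a list of ids; the change above keeps only the list form and leaves `self.eos_token_id` as None otherwise. Below is a minimal sketch of how a decode loop might consume that convention; the `stop_ids` helper and the `tokenizer` fallback are assumptions, not code from this diff.

# Hypothetical sketch: collect the stop ids for generation, falling back to the
# tokenizer's single eos token when the config only provided an int
# (self.eos_token_id is None in that case, mirroring the line added above).
def stop_ids(model_args, tokenizer) -> set[int]:
    if model_args.eos_token_id is None:
        return {tokenizer.eos_token_id}
    return set(model_args.eos_token_id)  # Gemma-style list of eos / end-of-turn ids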
@@ -2187,7 +2190,7 @@ def reference_embedding(self, reference_model=None):
             model = self.reference_transformer(wrap=False)
             layer = model.model.embed_tokens
         else:
-            layer = reference_model.model.embed_tokens
+            layer = reference_model.model.model.embed_tokens

         layer._load_state_dict = layer.load_state_dict
         layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim))
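The extra `.model` hop assumes the passed-in `reference_model` wraps the Hugging Face model one level deeper than the locally built reference, so the embedding table sits at `reference_model.model.model.embed_tokens`. A defensive version of that lookup is sketched below; it is purely illustrative and not taken from this diff.

# Hypothetical sketch: reach the embedding table whichever nesting depth the
# wrapper uses (assumes at most one extra .model level, as in the fix above).
def resolve_embed_tokens(reference_model):
    inner = reference_model.model
    return getattr(inner, "embed_tokens", None) or inner.model.embed_tokens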