
Commit d114559

Refactor model_config
1 parent f6a92c1 commit d114559

File tree

1 file changed: +8 −5 lines changed


models/tt_transformers/tt/model_config.py

Lines changed: 8 additions & 5 deletions
@@ -575,9 +575,10 @@ def __init__(
             max_prefill_chunk_size_div1024 = int(max_prefill_chunk_size_div1024)
             self.max_prefill_chunk_size = max_prefill_chunk_size_div1024 * 1024
 
-        if (self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B"] and self.device_name == "N150") or (
-            self.base_model_name in ["Qwen2.5-7B"] and self.device_name == "N300"
-        ):
+        if (
+            self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B", "gemma-3-1b-it"]
+            and self.device_name == "N150"
+        ) or (self.base_model_name in ["Qwen2.5-7B"] and self.device_name == "N300"):
             logger.info(f"Reducing prefill_len_cutoff to 512 for {self.model_name} on {self.device_name}")
             self.prefill_len_cutoff = 512
 
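The new form only reflows the original condition while adding gemma-3-1b-it to the N150 group. As a reading aid, here is a minimal Python sketch of the same gating as a standalone helper; the set names and the default cutoff of 1024 are assumptions for illustration, not taken from the diff:

PREFILL_512_ON_N150 = {"Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B", "gemma-3-1b-it"}
PREFILL_512_ON_N300 = {"Qwen2.5-7B"}

def prefill_len_cutoff(base_model_name: str, device_name: str, default: int = 1024) -> int:
    # These model/device pairs get a reduced prefill cutoff; the default
    # for all other pairs is an assumption here, not shown in the diff.
    if (base_model_name in PREFILL_512_ON_N150 and device_name == "N150") or (
        base_model_name in PREFILL_512_ON_N300 and device_name == "N300"
    ):
        return 512
    return default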

@@ -1396,7 +1397,9 @@ def _set_params_from_dict(self, config, is_hf=False):
         # Try to get text_config, if it doesn't exist everything is text config
         eos_token_id = config.get("eos_token_id", None)
 
-        self.eos_token_id = None if isinstance(eos_token_id, int) else eos_token_id
+        self.eos_token_id = (
+            None if isinstance(eos_token_id, int) else eos_token_id
+        )  # Gemma like models can have a list of eos token ids
 
         self.sliding_window_pattern = config.get("sliding_window_pattern", 1)
 
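The added parentheses are cosmetic; the substantive point is the comment: Gemma-like Hugging Face configs can supply eos_token_id as a list of ids rather than a single int. A small self-contained sketch of the normalization (the example values are illustrative):

def normalize_eos_token_id(eos_token_id):
    # A plain int is mapped to None, presumably so a downstream default
    # (e.g. the tokenizer's eos) is used; a list, as Gemma-like models
    # provide, is passed through unchanged.
    return None if isinstance(eos_token_id, int) else eos_token_id

assert normalize_eos_token_id(2) is None             # single-id config
assert normalize_eos_token_id([1, 106]) == [1, 106]  # Gemma-style list
assert normalize_eos_token_id(None) is None          # key absent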

@@ -2187,7 +2190,7 @@ def reference_embedding(self, reference_model=None):
             model = self.reference_transformer(wrap=False)
             layer = model.model.embed_tokens
         else:
-            layer = reference_model.model.embed_tokens
+            layer = reference_model.model.model.embed_tokens
 
         layer._load_state_dict = layer.load_state_dict
         layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim))
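The extra .model matches the usual Hugging Face nesting: a *ForCausalLM object keeps its decoder under .model, so a reference model that itself wraps a causal-LM instance exposes the embedding one attribute deeper. A toy reconstruction of that path (class names are stand-ins, not the repository's types):

import torch.nn as nn

class Decoder(nn.Module):            # plays the role of e.g. LlamaModel
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(128, 16)

class CausalLM(nn.Module):           # plays the role of e.g. LlamaForCausalLM
    def __init__(self):
        super().__init__()
        self.model = Decoder()

class ReferenceModel(nn.Module):     # stand-in for the reference_model argument
    def __init__(self):
        super().__init__()
        self.model = CausalLM()

layer = ReferenceModel().model.model.embed_tokens  # the fixed attribute path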
