Commit 359c925

Gemma model files clean up
1 parent f1a0685 commit 359c925

6 files changed: +18 -7 lines changed


models/tt_transformers/tests/test_attention.py

Lines changed: 2 additions & 3 deletions
@@ -164,10 +164,9 @@ def test_attention_inference(
     # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1
     pt_attention_input = torch.randn(
         batch_size, seq_len, model_args.dim, dtype=get_ref_model_dype(reference_model, model_args.model_name)
-    ).to(
-        torch.bfloat16
     )  # Qwen2.5 0.5B sees 0.1 to 2.1
-
+    if "gemma" in os.environ.get("HF_MODEL"):
+        pt_attention_input = pt_attention_input.to(torch.bfloat16)
     tt_attention_input = pt_attention_input.clone()

     attention_input = model_args.prepare_residual_tensor_decode(
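
This change, repeated in the next three test files, replaces the unconditional torch.bfloat16 cast of the reference input with a cast applied only when the HF_MODEL environment variable names a Gemma model. A minimal standalone sketch of that gating, for illustration only (the make_reference_input helper and the "" default for an unset HF_MODEL are assumptions, not code from this commit):

import os

import torch


def make_reference_input(batch_size, seq_len, dim, ref_dtype=torch.float32):
    # Mirrors the committed pattern: build the reference-model input in the
    # reference dtype, then downcast only for Gemma runs.
    x = torch.randn(batch_size, seq_len, dim, dtype=ref_dtype)
    # The committed tests call os.environ.get("HF_MODEL") with no default;
    # falling back to "" here is an assumption so the sketch also runs when
    # the variable is unset.
    if "gemma" in os.environ.get("HF_MODEL", ""):
        x = x.to(torch.bfloat16)
    return x


# Illustrative usage (hypothetical model id and dimensions):
# os.environ["HF_MODEL"] = "google/gemma-3-4b-it"
# pt_attention_input = make_reference_input(1, 32, 2560)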

models/tt_transformers/tests/test_attention_prefill.py

Lines changed: 2 additions & 1 deletion
@@ -145,7 +145,8 @@ def test_attention_inference(
         )
         * 2
     ) - 1
-    pt_attention_input = pt_attention_input.to(torch.bfloat16)  # Qwen2.5 0.5B sees 0.1 to 2.1
+    if "gemma" in os.environ.get("HF_MODEL"):
+        pt_attention_input = pt_attention_input.to(torch.bfloat16)
     tt_attention_input = pt_attention_input.clone()
     attention_input = model_args.prepare_residual_tensor_prefill(
         tt_attention_input,

models/tt_transformers/tests/test_decoder.py

Lines changed: 2 additions & 1 deletion
@@ -168,7 +168,8 @@ def test_decoder_inference(
         )
         * 2
     ) - 1
-    pt_decode_input = pt_decode_input.to(torch.bfloat16)
+    if "gemma" in os.environ.get("HF_MODEL"):
+        pt_decode_input = pt_decode_input.to(torch.bfloat16)
     tt_decode_input = pt_decode_input.clone()

     decode_input = model_args.prepare_residual_tensor_decode(

models/tt_transformers/tests/test_decoder_prefill.py

Lines changed: 2 additions & 1 deletion
@@ -155,7 +155,8 @@ def test_decoder_inference(
         )
         * 2
     ) - 1
-    pt_decode_input = pt_decode_input.to(torch.bfloat16)
+    if "gemma" in os.environ.get("HF_MODEL"):
+        pt_decode_input = pt_decode_input.to(torch.bfloat16)
     tt_decode_input = pt_decode_input.clone()
     decode_input = model_args.prepare_residual_tensor_prefill(
         tt_decode_input,

models/tt_transformers/tests/test_lm_head.py

Lines changed: 3 additions & 1 deletion
@@ -64,7 +64,9 @@ def test_lm_head_inference(seq_len, batch_size, mesh_device, reset_seeds):
         max_columns_per_device=model_args.max_columns_per_device_lm_head,
     )

-    torch_input = torch.randn(1, 1, seq_len, model_args.dim).to(torch.bfloat16)
+    torch_input = torch.randn(1, 1, seq_len, model_args.dim)
+    if "gemma" in os.environ.get("HF_MODEL"):
+        torch_input = torch_input.to(torch.bfloat16)
     reference_output = reference_model(torch_input)
     tt_input = ttnn.from_torch(
         torch_input,

models/tt_transformers/tt/model_config.py

Lines changed: 7 additions & 0 deletions
@@ -1550,6 +1550,8 @@ def vision_chunk_ntok(self):
         """
         Returns the number of tokens per chunk, accounting for the extra class token
         """
+        if self.is_llama_vision():
+            return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1
         return (self.image_size // self.vision_patch_size) ** 2 + 1

     def _set_model_params(self, checkpoint_dir):
@@ -1683,7 +1685,12 @@ def __repr__(self):
         )"""

     # TODO: Rename to is_llama_vision
+    def is_llama_vision(self):
+        return self.vision_chunk_size > 0
+
     def is_vision(self):
+        if self.is_llama_vision():
+            return True
         return self.image_size > 0

     def get_state_dict_prefix(self, module_name, layer_num, is_vision=False):
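
The model_config.py change factors the Llama multimodal check (vision_chunk_size > 0) into a dedicated is_llama_vision() predicate, so vision_chunk_ntok() and is_vision() can tell Llama vision checkpoints apart from models that only configure image_size. A condensed sketch of the resulting dispatch on a stripped-down stand-in class (VisionConfigSketch and its default field values are illustrative, not taken from the repository):

class VisionConfigSketch:
    # Stand-in for the real model-args object; only the fields read by the
    # methods below are modeled.
    def __init__(self, vision_chunk_size=0, image_size=0, vision_patch_size=14):
        self.vision_chunk_size = vision_chunk_size
        self.image_size = image_size
        self.vision_patch_size = vision_patch_size

    def is_llama_vision(self):
        # Llama multimodal checkpoints advertise a positive vision_chunk_size.
        return self.vision_chunk_size > 0

    def is_vision(self):
        # Every Llama vision model is a vision model; otherwise fall back to
        # the image_size check used by other multimodal configs.
        if self.is_llama_vision():
            return True
        return self.image_size > 0

    def vision_chunk_ntok(self):
        # Tokens per chunk (or per image), plus one for the class token.
        if self.is_llama_vision():
            return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1
        return (self.image_size // self.vision_patch_size) ** 2 + 1


# Illustrative usage:
# VisionConfigSketch(vision_chunk_size=560).vision_chunk_ntok()  # -> 1601
# VisionConfigSketch(image_size=896).is_vision()  # -> True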

0 commit comments
