
Commit 20c08ce

Rebase Gemma-3-4b-it
1 parent 793b0fc commit 20c08ce

10 files changed (+49, -35 lines)


models/tt_transformers/demo/simple_vision_demo.py

Lines changed: 9 additions & 2 deletions
@@ -314,6 +314,7 @@ def test_multimodal_demo_text(
     total_lens = prefill_lens + max_gen_len
 
     # Create padded tokens tensor for batch
+    stop_tokens = model_args[0].tokenizer.stop_tokens
     pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id
     bsz = len(prompt_tokens)
     tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long)
@@ -394,8 +395,14 @@ def test_multimodal_demo_text(
     profiler.end(f"compile_decode", iteration=batch_idx)
 
     # Disable checking for eot until I have more robust code for batch > 1
-    # if text in ["<|eot_id|>", "<|eom_id|>"]:
-    #     break
+    if HF_MODEL:
+        if next_tokens in stop_tokens:
+            break
+    else:
+        # Disable checking for eot until I have more robust code for batch > 1
+        pass
+        # if text in ["<|eot_id|>", "<|eom_id|>"]:
+        #     break
     _num_decode_tokens += (
         gen_idx * max_batch_size
     )  # gen_idx is (num_tokens - 1) to avoid counting compile iter
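Note: the new `if next_tokens in stop_tokens: break` only handles the single-user case cleanly; the surviving comment says eot checking for batch > 1 still needs more robust code. A minimal sketch of what per-user tracking could look like, in plain PyTorch (the helper name, tensor shapes, and call site are assumptions, not part of this commit):

import torch

def update_finished(next_tokens, stop_tokens, finished):
    # next_tokens: [batch] tensor of newly sampled token ids
    # stop_tokens: iterable of stop token ids (e.g. tokenizer.stop_tokens)
    # finished:    [batch] bool tensor, True once a user has already hit a stop token
    stop_ids = torch.tensor(sorted(stop_tokens), dtype=next_tokens.dtype)
    finished = finished | torch.isin(next_tokens, stop_ids)
    # The decode loop can break only once every user in the batch is finished.
    return finished, bool(finished.all())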

models/tt_transformers/tests/test_decoder.py

Lines changed: 15 additions & 2 deletions
@@ -87,6 +87,19 @@ def test_decoder_inference(
         model_args.rope_theta,
         model_args.rope_scaling,
     )
+
+    if model_args.rope_local_theta is not None:
+        rope_setup_local = RotarySetup(
+            mesh_device,
+            model_args.max_batch_size,
+            model_args.head_dim,
+            model_args.max_seq_len,
+            model_args.rope_local_theta,
+            None,
+        )
+    else:
+        rope_setup_local = None
+
     transformation_mats = rope_setup.get_both_trans_mats()
 
     # Prepare page table for paged attention
@@ -172,12 +185,12 @@ def test_decoder_inference(
 
     # Get cos/sin matrices for the current position of each user
     rot_mats = rope_setup.get_rot_mats(current_pos)
-
+    rot_mats_local = None if rope_setup_local is None else rope_setup_local.get_rot_mats(current_pos)
     # Run TT model
     tt_out = tt_model(
         decode_input,
         current_pos_tensor,
-        rot_mats=rot_mats,
+        rot_mats=[rot_mats, rot_mats_local],
         mode="decode",
         page_table=page_table_tt,
     )

models/tt_transformers/tests/test_decoder_prefill.py

Lines changed: 13 additions & 1 deletion
@@ -93,6 +93,16 @@ def test_decoder_inference(
         theta=model_args.rope_theta,
         rope_scaling=model_args.rope_scaling,
     )
+    if model_args.rope_local_theta is not None:
+        rot_mats_local = get_rot_mats(
+            head_dim=model_args.head_dim,
+            device=mesh_device,
+            seq_len=max_seq_len,
+            theta=model_args.rope_local_theta,
+            rope_scaling=None,
+        )
+    else:
+        rot_mats_local = None
     transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim)
     transformation_mats_prefill = ttnn.as_tensor(
         transformation_mat_torch,
@@ -168,7 +178,9 @@ def test_decoder_inference(
     attn_mask_torch = torch.triu(attn_mask, diagonal=1)
     ref_output = reference_model(pt_decode_input, positions[0], freqs_cis_i, mask=attn_mask_torch)
     # Run TT model
-    tt_out = tt_model(decode_input, None, rot_mats, user_id=0, mode="prefill", page_table=page_table_tt)
+    tt_out = tt_model(
+        decode_input, None, [rot_mats, rot_mats_local], user_id=0, mode="prefill", page_table=page_table_tt
+    )
     tt_out = ttnn.to_torch(
         tt_out,
         mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape),
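For context, the local rotary matrices built here differ from the global ones only in the base frequency passed as `theta`. A host-side sketch of how cos/sin tables are conventionally derived from a theta (plain PyTorch, not the ttnn `get_rot_mats` implementation; the dimensions and theta values below are examples only):

import torch

def rope_cos_sin(head_dim, seq_len, theta):
    # Inverse frequencies for each pair of head dimensions.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    freqs = torch.outer(positions, inv_freq)  # [seq_len, head_dim // 2]
    return torch.cos(freqs), torch.sin(freqs)

# Global and local tables are built the same way, only from different thetas.
cos_global, sin_global = rope_cos_sin(head_dim=128, seq_len=4096, theta=1_000_000.0)
cos_local, sin_local = rope_cos_sin(head_dim=128, seq_len=4096, theta=10_000.0)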

models/tt_transformers/tests/test_embedding.py

Lines changed: 3 additions & 2 deletions
@@ -42,7 +42,7 @@ def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc)
     tokenizer = model_args.tokenizer
 
     reference_emb = model_args.reference_embedding()
-    if model_args.is_vision():
+    if model_args.is_vision() and not model_args.base_model_name.startswith("gemma-3"):
         layer_name = "text_model.tok_embeddings.weight"
     else:
         layer_name = "tok_embeddings.weight"
@@ -68,7 +68,8 @@ def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc)
         dtype=ttnn.uint32,
         layout=ttnn.ROW_MAJOR_LAYOUT,
     )
-    tt_output = tt_emb(tt_input)
+    embed_scale = model_args.embed_scale
+    tt_output = tt_emb(tt_input, embed_scale)
     tt_output_torch = ttnn.to_torch(
         tt_output,
         mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(0, -1), mesh_shape=model_args.cluster_shape),
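The embedding test now passes `model_args.embed_scale` through to the TT embedding. Gemma-style models scale token embeddings by sqrt(hidden_size) before the first decoder layer, which is presumably what `embed_scale` carries; a reference-side sketch of that behaviour (function and argument names here are illustrative, not the ttnn API):

import math
import torch
import torch.nn.functional as F

def scaled_embedding(tokens, weight, embed_scale=None):
    # Fall back to the Gemma-style sqrt(hidden_size) scale when none is provided.
    scale = embed_scale if embed_scale is not None else math.sqrt(weight.shape[-1])
    return F.embedding(tokens, weight) * scale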

models/tt_transformers/tt/common.py

Lines changed: 2 additions & 0 deletions
@@ -4,7 +4,9 @@
 
 import math
 import re
+from enum import Enum
 from types import SimpleNamespace
+from typing import Optional
 
 import torch
 from llama_models.llama3.api.datatypes import ImageMedia

models/tt_transformers/tt/generator.py

Lines changed: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=Non
     def prefill_forward_text(
         self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None, empty_slots=None, **kwargs
     ):
-        print("prefill generator ", kwargs["processed_inputs"])
         if page_table is not None:
             assert isinstance(page_table, torch.Tensor), "page_table mush be torch.Tensor"
 
models/tt_transformers/tt/model.py

Lines changed: 6 additions & 8 deletions
@@ -61,14 +61,12 @@ def __init__(
 
         if args.rope_local_theta is not None:
             self.rope_setup_local = ActualRopeSetupClass(
-                mesh_device,
-                args.max_batch_size,
-                args.head_dim,
-                args.max_seq_len,
-                args.rope_local_theta,
-                args.rope_scaling_factor,
-                args.orig_context_len,
-                "default",
+                device=mesh_device,
+                batch_size=args.max_batch_size,
+                head_dim=args.head_dim,
+                max_seq_len=args.max_seq_len,
+                rope_theta=args.rope_local_theta,
+                rope_scaling=None,
             )
         else:
             self.rope_setup_local = None
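With this change `rope_setup_local` mirrors the global `rope_setup` but is built from `rope_local_theta` with no scaling; Gemma-3 uses a different rotary base for its local (sliding-window) attention layers than for its global ones. A rough sketch of how a model holding both setups might route them per layer (the `is_local_layer` predicate is an assumption, not code from this commit):

def rot_mats_for_layer(layer_idx, current_pos, rope_setup, rope_setup_local, is_local_layer):
    # Sliding-window (local) layers use the tables built from the local theta;
    # everything else keeps the standard global rope tables.
    if rope_setup_local is not None and is_local_layer(layer_idx):
        return rope_setup_local.get_rot_mats(current_pos)
    return rope_setup.get_rot_mats(current_pos)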

models/tt_transformers/tt/model_config.py

Lines changed: 0 additions & 17 deletions
@@ -1746,20 +1746,6 @@ def merge_vision_config(base_config):
                 self._set_vision_params(merged_vision_config)
             else:
                 self._set_params_from_dict(self.hf_config, is_hf=True)
-
-            if "text_config" in config or "vision_config" in config:
-                merged_text_config = merge_text_config(config)
-                self._set_params_from_dict(merged_text_config, is_hf=True)
-
-                if "gemma-3-4b-it" in self.base_model_name:
-                    self._set_vision_params(config["vision_config"])
-                else:
-                    if "vision_config" in config:
-                        merged_vision_config = merge_vision_config(config)
-                        self._set_vision_params(merged_vision_config)
-            else:
-                self._set_params_from_dict(config, is_hf=True)
-
         else:
             config_file = os.path.join(checkpoint_dir, "config.json")
             assert os.path.exists(config_file), f"config.json file not found at {config_file}"
@@ -2343,9 +2329,6 @@ def reference_transformer(self, wrap=True, load_checkpoint=False):
         # Special case Qwen2.5-VL models until they are fully integrated into a HF release
         if "Qwen/Qwen2.5-VL" in self.model_name:
             from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig as AutoConfig
-            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-                Qwen2_5_VLForConditionalGeneration as AutoModelForCausalLM,
-            )
         else:
             from transformers import AutoConfig, AutoModel
models/tt_transformers/tt/multimodal/gemma/gemma_image_attention.py

Lines changed: 1 addition & 1 deletion
@@ -273,7 +273,7 @@ def pad_head_dim_bias(bias):
                 memory_config=ttnn.DRAM_MEMORY_CONFIG,
                 dtype=self.dtype,
                 layout=ttnn.TILE_LAYOUT,
-                # cache_file_name=cache_name("bo_sharded"),
+                cache_file_name=cache_name("bo_sharded"),
             )
         else:
             self.bo = None

models/tt_transformers/tt/rope.py

Lines changed: 0 additions & 1 deletion
@@ -335,7 +335,6 @@ def __init__(
         self.batch_size = batch_size
         self.head_dim = head_dim
         self.device = device
-        self.rope_type = rope_type
         self.is_mesh_device = isinstance(device, ttnn._ttnn.multi_device.MeshDevice)
        self.num_devices = device.get_num_devices() if self.is_mesh_device else 1
         if self.num_devices == 32:
