Commit c079a67

Rebase Gemma-3-4b-it
1 parent 304ad10 commit c079a67

21 files changed: +138 -71 lines changed

models/experimental/gemma3_4b/tests/test_attention.py

Lines changed: 3 additions & 4 deletions
@@ -92,8 +92,7 @@ def test_attention_inference(
         model_args.head_dim,
         model_args.max_seq_len,
         model_args.rope_theta,
-        model_args.rope_scaling_factor,
-        model_args.orig_context_len,
+        model_args.rope_scaling,
     )

     transformation_mats = rope_setup.get_both_trans_mats()
@@ -141,8 +140,8 @@ def test_attention_inference(
         model_args.head_dim,
         model_args.max_seq_len * 2,
         model_args.rope_theta,
-        model_args.rope_scaling_factor,
-        model_args.orig_context_len,
+        model_args.rope_scaling.factor if model_args.rope_scaling else None,
+        model_args.rope_scaling.original_max_position_embeddings if model_args.rope_scaling else None,
     )
     freqs_cis = torch.complex(cos, sin)

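For reference, the rope_scaling argument threaded through the two hunks above replaces the former rope_scaling_factor / orig_context_len pair with a single config object. A minimal sketch of what such an object could look like, assuming a simple dataclass exposing the two attributes the test reads (factor and original_max_position_embeddings); this is illustrative, not the repository's actual definition:

from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-in for the rope_scaling object passed above; the real
# type comes from the model args in tt_transformers.
@dataclass
class RopeScaling:
    factor: float
    original_max_position_embeddings: int


def rope_scaling_args(rope_scaling: Optional[RopeScaling]):
    # Mirrors the call-site pattern in the second hunk: fall back to None
    # when no RoPE scaling is configured.
    factor = rope_scaling.factor if rope_scaling else None
    orig_context_len = rope_scaling.original_max_position_embeddings if rope_scaling else None
    return factor, orig_context_len


# Example: a scaled setup with an 8192-token original context window.
print(rope_scaling_args(RopeScaling(factor=8.0, original_max_position_embeddings=8192)))  # (8.0, 8192)
print(rope_scaling_args(None))  # (None, None)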
models/experimental/gemma3_4b/tests/test_decoder.py

Lines changed: 1 addition & 2 deletions
@@ -82,8 +82,7 @@ def test_decoder_inference(
         model_args.head_dim,
         model_args.max_seq_len,
         model_args.rope_theta,
-        model_args.rope_scaling_factor,
-        model_args.orig_context_len,
+        model_args.rope_scaling,
     )
     transformation_mats = rope_setup.get_both_trans_mats()

models/experimental/gemma3_4b/tests/test_mlp.py

Lines changed: 4 additions & 1 deletion
@@ -47,10 +47,13 @@ def test_mlp_inference(seq_len, batch_size, reset_seeds, device):
     state_dict = tt_model_args.load_state_dict()

     # # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
-    first_layer_prefix = "layers.0.feed_forward"
+    # first_layer_prefix = "layers.0.feed_forward"
+    first_layer_prefix = tt_model_args.get_state_dict_prefix("MLP", 0)
+
     partial_state_dict = {
         k[len(first_layer_prefix) + 1 :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
     }
+
     reference_model = tt_model_args.reference_mlp()  # Gemma3 MLP
     reference_model.load_state_dict(partial_state_dict)

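The test above strips a per-layer prefix from the full checkpoint so the reference MLP can load only its own weights. A minimal standalone sketch of that prefix-stripping pattern; the keys here are made up for illustration, and the real prefix comes from tt_model_args.get_state_dict_prefix("MLP", 0):

# Toy state dict; keys mimic the full-checkpoint naming the comment above refers to.
state_dict = {
    "layers.0.feed_forward.w1.weight": "w1_layer0",
    "layers.0.feed_forward.w2.weight": "w2_layer0",
    "layers.1.feed_forward.w1.weight": "w1_layer1",  # different layer, filtered out
}
first_layer_prefix = "layers.0.feed_forward"  # placeholder for get_state_dict_prefix("MLP", 0)

# Keep only keys under the prefix and drop "<prefix>." so the reference module
# can load them with its own (unprefixed) parameter names.
partial_state_dict = {
    k[len(first_layer_prefix) + 1 :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix)
}
assert partial_state_dict == {"w1.weight": "w1_layer0", "w2.weight": "w2_layer0"}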
models/experimental/gemma3_4b/tests/vision_tests/test_end2end.py

Lines changed: 13 additions & 13 deletions
@@ -169,18 +169,18 @@ def setup_vision_prompts_and_tokenizer(model_args, instruct):
         }
     ]

-    # messages = [
-    #     {
-    #         "role": "user",
-    #         "content": [
-    #             {
-    #                 "type": "image",
-    #                 "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
-    #             },
-    #             {"type": "text", "text": "Describe this image in detail."},
-    #         ],
-    #     }
-    # ]
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+                },
+                {"type": "text", "text": "Describe this image in detail."},
+            ],
+        }
+    ]

     tokenizer = model_args.tokenizer
     return messages, tokenizer
@@ -211,7 +211,7 @@ def process_real_vision_inputs(messages, model_args):
     ).to(dtype=torch.bfloat16)

     input_ids = encoded["input_ids"]
-    pixel_values = None
+    pixel_values = encoded["pixel_values"]
     attention_mask = encoded["attention_mask"]

     # logger.info(f"Processed vision inputs - input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}")

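The second hunk above now takes pixel_values from the processor output instead of hard-coding None. A minimal sketch of that flow, assuming a Hugging Face AutoProcessor for a Gemma 3 checkpoint (the checkpoint id and the exact template arguments are assumptions for illustration; the test builds its processor from model_args):

from transformers import AutoProcessor

# Assumed checkpoint id, used only for this sketch.
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image in detail."},
        ],
    }
]

encoded = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
)

input_ids = encoded["input_ids"]          # token ids, including image placeholder tokens
pixel_values = encoded["pixel_values"]    # preprocessed image tensor (previously hard-coded to None)
attention_mask = encoded["attention_mask"]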
models/experimental/gemma3_4b/tt/attention.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 """
+source: models/tt_transformers/tt/attention.py
+
 This is the attention implementation of the Gemma-3-4b-it

 We have re-used the Attention implementation of the TT-Transformers with few modifications.

models/experimental/gemma3_4b/tt/decoder.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 """
+source: models/tt_transformers/tt/decoder.py
+
 This is the Decoder block for the gemma 3-4b-it model
 We couldn't use the existing implementation in TT-Transformers because the usage of submodules is different

models/experimental/gemma3_4b/tt/gemma3_generator.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 """
+source: models/tt_transformers/tt/generator.py
+
 This is the Replica version of the Generator class for the Gemma Model.
 This adds support for kwargs that contains the procesed inputs and the vision submodule of the model.

models/experimental/gemma3_4b/tt/gemma_conv2d_patch.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 """
+source: models/tt_transformers/tt/multimodal/llama_conv2d_patch.py
 This is the Conv2dPath of Gemma-3-4b-it
 We have reused the exisiting Conv2dPath of TtLlamaConv2dPath with few modifications.
 We have added a check for weight to convert 4D to 2D

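The added check mentioned in the Conv2dPatch docstring (converting a 4D weight to 2D) corresponds to running the patch embedding as a matmul over flattened patches. A minimal sketch of the idea; shapes and names are illustrative, not the TT implementation:

import torch

# A ViT-style patch-embedding conv weight: (out_channels, in_channels, patch_h, patch_w).
conv_weight = torch.randn(1152, 3, 14, 14)

# If the weight arrives in 4D, flatten it to 2D so the patch embedding can be
# computed as a single linear projection of flattened patches.
if conv_weight.ndim == 4:
    linear_weight = conv_weight.reshape(conv_weight.shape[0], -1)  # (1152, 3 * 14 * 14)

patches = torch.randn(4096, 3 * 14 * 14)        # flattened image patches
patch_embeddings = patches @ linear_weight.T    # (4096, 1152)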
models/experimental/gemma3_4b/tt/gemma_image_attention.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 """
+source: models/tt_transformers/tt/multimodal/llama_image_attention.py
+
 This is the ImageAttention block for Gemma-3-4b-it
 We have reused the TTLlamaImageAttention with some modification.
 We have made the linears (Q,K,V) to be executed separately and added bias support for O_projection, along with few

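The ImageAttention note above mentions splitting the fused QKV projection into separate Q, K, V linears and adding bias to the output projection. A PyTorch-flavoured sketch of that structural difference; dimensions and module names are illustrative, not the TT implementation:

import torch
import torch.nn as nn

hidden_dim = 1152  # illustrative vision hidden size

# Fused variant: a single matmul produces Q, K and V together.
wqkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=True)

# Split variant, as described above: separate Q/K/V linears plus a biased O projection.
wq = nn.Linear(hidden_dim, hidden_dim, bias=True)
wk = nn.Linear(hidden_dim, hidden_dim, bias=True)
wv = nn.Linear(hidden_dim, hidden_dim, bias=True)
wo = nn.Linear(hidden_dim, hidden_dim, bias=True)  # output projection with bias support

x = torch.randn(1, 4096, hidden_dim)
q, k, v = wq(x), wk(x), wv(x)  # executed separately instead of chunking wqkv(x)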