Skip to content

Commit ead678d

Browse files
Add Performance and Accuracy metrics for Gemma-3-4b-it
1 parent 1994cf9 commit ead678d

File tree

7 files changed

+50
-5
lines changed

7 files changed

+50
-5
lines changed

.github/workflows/single-card-demo-tests-impl.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ jobs:
9292
# # Moved to t3k tests until OOM on single card runners resolved
9393
# { name: "qwen7b", runner-label: "N300", performance: false, cmd: run_qwen7b_func, owner_id: U03PUAKE719}, # Mark O'Connor
9494
{ name: "qwen25_vl", runner-label: "N300", performance: true, cmd: run_qwen25_vl_func, owner_id: U07RY6B5FLJ}, #Gongyu Wang
95+
# { name: "gemma3_4b", runner-label: "N300", performance: true, cmd: run_gemma3_4b_func, owner_id: }, # TODO Owner ID needs to be updated
96+
9597
]
9698
name: ${{ matrix.test-group.name }}-${{ matrix.test-group.runner-label }}-${{ (matrix.test-group.performance && 'perf') || 'func' }}
9799
env:

models/tt_transformers/PERF.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ This configuration uses bfp4 MLP and bfp8 attention weights for all models excep
4545
| Mistral-7B | N150 | 95 | 99 | 29.75 | 100.24 |
4646
| Mistral-7B | N300 | 95 | 99 | 47.01 | 65.95 |
4747
| Mistral-7B | T3K | 95 | 99 | 67.82 | 53.93 |
48+
| gemma-3-4b | N150 | 67.0 | 80 | 28.00 | 81.00 |
49+
| gemma-3-4b | N300 | 52.0 | 72.0 | 23.00 | 152 |
4850

4951

5052
## Accuracy
@@ -82,6 +84,8 @@ Llama 3 models test as insensitive to attention precision and so we use bfp8 att
8284
| Mistral-7B | N150 | 95 | 99 | 29.75 | 100.24 |
8385
| Mistral-7B | N300 | 95 | 99 | 47.01 | 65.95 |
8486
| Mistral-7B | T3K | 95 | 99 | 67.82 | 53.93 |
87+
| gemma-3-4b | N150 | 67.0 | 80 | 28.00 | 81.00 |
88+
| gemma-3-4b | N300 | 52.0 | 72.0 | 23.00 | 152 |
8589

8690
## Long-context (64K Tokens)
8791

models/tt_transformers/demo/simple_text_demo.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -945,7 +945,15 @@ def test_demo_text(
945945
)
946946

947947
# Benchmark targets
948-
supported_models = ["Llama-3.2-1B", "Llama-3.2-3B", "Llama-3.1-8B", "Llama-3.2-11B", "Llama-3.1-70B", "Mistral-7B"]
948+
supported_models = [
949+
"Llama-3.2-1B",
950+
"Llama-3.2-3B",
951+
"Llama-3.1-8B",
952+
"Llama-3.2-11B",
953+
"Llama-3.1-70B",
954+
"Mistral-7B",
955+
"gemma-3-4b",
956+
]
949957
supported_devices = ["N150", "P100", "P150", "P300", "N300", "P150x4", "T3K", "TG"]
950958

951959
tt_device_name = determine_device_name(mesh_device) # submesh device should not decide performance target
@@ -994,6 +1002,9 @@ def test_demo_text(
9941002
"N300_Mistral-7B": 38, # TODO Update target
9951003
"T3K_Mistral-7B": 45, # TODO Update target
9961004
"TG_Mistral-7B": 45, # TODO Update target
1005+
#
1006+
"N150_gemma-3-4b": 23,
1007+
"N300_gemma-3-4b": 38, # TODO Update target
9971008
}
9981009
if model_device_key in dict_target_decode_tok_s_u:
9991010
target_decode_tok_s_u = dict_target_decode_tok_s_u[model_device_key]
@@ -1075,15 +1086,18 @@ def test_demo_text(
10751086
# "T3K_Qwen2.5-Coder-32B": 180, # too much variability in CI (https://github.com/tenstorrent/tt-metal/issues/24754)
10761087
# "T3K_Qwen2.5-72B": 211, # too much variability in CI (https://github.com/tenstorrent/tt-metal/issues/24754)
10771088
# "T3K_Qwen3-32B": 250, # too much variability in CI (https://github.com/tenstorrent/tt-metal/issues/24754)
1089+
"N150_gemma-3-4b": 100, # TODO Update target
10781090
}
10791091
ci_target_decode_tok_s_u = {
10801092
# N150 targets - higher is better
10811093
"N150_Llama-3.2-1B": 66,
10821094
"N150_Llama-3.2-3B": 35,
10831095
"N150_Llama-3.1-8B": 21,
10841096
"N150_Mistral-7B": 23,
1097+
"N150_gemma-3-4b": 23, # TODO Update target
10851098
# N300 targets
10861099
"N300_Qwen2.5-7B": 20,
1100+
"N300_gemma-3-4b": 20, # TODO Update target
10871101
# T3K targets
10881102
# "T3K_Llama-3.1-70B": 16, # too much variability in CI (https://github.com/tenstorrent/tt-metal/issues/24303)
10891103
# "T3K_Qwen2.5-72B": 13, # too much variability in CI (https://github.com/tenstorrent/tt-metal/issues/24303)

models/tt_transformers/demo/simple_vision_demo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,12 +480,14 @@ def test_multimodal_demo_text(
480480
"N300_Llama-3.2-11B": 23.5,
481481
"T3K_Llama-3.2-11B": 21.5,
482482
"T3K_Llama-3.2-90B": 3,
483+
"N300_gemma-3-4b": 390,
483484
}[f"{tt_device_name}_{base_model_name}"]
484485

485486
target_decode_tok_s_u = {
486487
"N300_Llama-3.2-11B": 21.5,
487488
"T3K_Llama-3.2-11B": 37,
488489
"T3K_Llama-3.2-90B": 6,
490+
"N300_gemma-3-4b": 24,
489491
}[f"{tt_device_name}_{base_model_name}"]
490492

491493
target_decode_tok_s = target_decode_tok_s_u * max_batch_size
Binary file not shown.

models/tt_transformers/tests/test_accuracy.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,27 @@ def test_tt_model_acc(
245245
theta=model_args.rope_theta,
246246
rope_scaling=model_args.rope_scaling,
247247
)
248+
249+
if model_args.rope_local_theta is not None:
250+
# If local theta is set, use it to compute the local rope matrices
251+
rot_mats_local = get_rot_mats(
252+
head_dim=model_args.head_dim,
253+
device=mesh_device,
254+
seq_len=prefill_lens[0],
255+
theta=model_args.rope_local_theta,
256+
rope_scaling=None,
257+
)
258+
else:
259+
rot_mats_local = None
260+
248261
prefill_input = model_args.prepare_residual_tensor_prefill(
249262
pt_prefill_input[batch_id],
250263
)
251264

252265
tt_out = tt_model(
253266
prefill_input,
254267
current_pos=None,
255-
rot_mats=rot_mats_prefill,
268+
rot_mats=[rot_mats_prefill, rot_mats_local],
256269
user_id=batch_id,
257270
mode="prefill",
258271
page_table=page_table_tt,
@@ -280,7 +293,7 @@ def test_tt_model_acc(
280293

281294
# Get cos/sin matrices for the current position of each user
282295
rot_mats = tt_model.rope_setup.get_rot_mats(current_pos)
283-
296+
rot_mats_local = None if tt_model.rope_setup_local is None else tt_model.rope_setup.get_rot_mats(current_pos)
284297
# Print table header
285298
if use_reference_file:
286299
logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Actual':<15}{'Top 5 Predictions':<75}")
@@ -310,7 +323,7 @@ def test_tt_model_acc(
310323
tt_out = tt_model(
311324
decode_input,
312325
current_pos_tensor,
313-
rot_mats=rot_mats,
326+
rot_mats=[rot_mats, rot_mats_local],
314327
mode="decode",
315328
page_table=page_table_tt,
316329
)
@@ -351,7 +364,9 @@ def test_tt_model_acc(
351364
# Update rot_mats for next iteration
352365
current_pos += 1
353366
rot_mats = tt_model.rope_setup.get_rot_mats(current_pos)
354-
367+
rot_mats_local = (
368+
tt_model.rope_setup_local.get_rot_mats(current_pos) if tt_model.rope_setup_local is not None else None
369+
)
355370
# Modify the accuracy checking section when using reference text
356371
if not use_reference_file:
357372
# Get probabilities from model output

tests/scripts/single_card/run_single_card_demo_tests.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ run_qwen7b_func() {
2121

2222
}
2323

24+
25+
run_gemma3_4b_func() {
26+
27+
HF_MODEL=google/gemma-3-1b-it MESH_DEVICE=N300 pytest -n auto models/tt_transformers/demo/simple_text_demo.py -k performance-ci-1 --timeout 1800
28+
29+
}
30+
31+
2432
run_qwen25_vl_func() {
2533
fail=0
2634

0 commit comments

Comments
 (0)