Commit a2848ee

Migrate Gemma-3-4B-IT to TT-Transformers
1 parent: 3fcf34b

49 files changed (+776, -5274 lines)

models/common/rmsnorm.py

Lines changed: 7 additions & 2 deletions
@@ -85,7 +85,7 @@ def __init__(
                 torch_weight,
                 device=device,
                 dtype=weight_dtype,
-                layout=ttnn.ROW_MAJOR_LAYOUT,
+                layout=ttnn.TILE_LAYOUT,
                 memory_config=weight_memory_config,
                 cache_file_name=cache_name,
                 mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None,
@@ -96,7 +96,7 @@ def __init__(
                 torch_weight,
                 device=device,
                 dtype=weight_dtype,
-                layout=ttnn.ROW_MAJOR_LAYOUT,
+                layout=ttnn.TILE_LAYOUT,
                 memory_config=weight_memory_config,
                 cache_file_name=cache_name,
                 mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape))
@@ -128,6 +128,11 @@ def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) ->
         else:
             assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor"

+        if x.shape[-1] % weight.shape[-1] == 0:
+            # Reshape weight only if x's last dimension is divisible by weight's last dimension,
+            # to avoid padding errors in RMSNorm when dimensions are not aligned
+            weight = ttnn.reshape(weight, [1, 1, 1, -1])
+
         x = norm(
             x,
             epsilon=self.eps,
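
The forward() change above conditionally flattens the RMSNorm gamma weight so it broadcasts along the activation's last dimension. As a rough point of reference (a minimal torch sketch, not the ttnn implementation; rmsnorm_reference and the 4D activation shape are illustrative assumptions), the same broadcast looks like this:

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Illustrative only: x is assumed to be 4D, e.g. (1, 1, seq_len, hidden_dim),
    # with one scale per hidden channel in weight.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Flattening weight to [1, 1, 1, -1] mirrors the conditional reshape in the
    # diff above: the scale then broadcasts over x's last dimension directly.
    return x * rms * weight.reshape(1, 1, 1, -1)

# Example usage with hypothetical dimensions:
# y = rmsnorm_reference(torch.randn(1, 1, 32, 2560), torch.ones(2560))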

models/experimental/gemma3_4b/tests/test_attention.py

Lines changed: 0 additions & 279 deletions
This file was deleted.
