
Commit 0e17632

Migrate Gemma-3-1b-it to TT-Transformers Library
1 parent ade214f commit 0e17632

File tree: 12 files changed, +191 -58 lines

models/common/rmsnorm.py

Lines changed: 7 additions & 2 deletions
@@ -82,7 +82,7 @@ def __init__(
             torch_weight,
             device=device,
             dtype=weight_dtype,
-            layout=ttnn.ROW_MAJOR_LAYOUT,
+            layout=ttnn.TILE_LAYOUT,
             memory_config=weight_memory_config,
             cache_file_name=cache_name,
             mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None,
@@ -93,7 +93,7 @@ def __init__(
             torch_weight,
             device=device,
             dtype=weight_dtype,
-            layout=ttnn.ROW_MAJOR_LAYOUT,
+            layout=ttnn.TILE_LAYOUT,
             memory_config=weight_memory_config,
             cache_file_name=cache_name,
             mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape))
@@ -125,6 +125,11 @@ def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) ->
         else:
             assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor"

+        if x.shape[-1] % weight.shape[-1] == 0:
+            # Reshape weight only if x's last dimension is divisible by weight's last dimension,
+            # to avoid padding errors in RMSNorm when dimensions are not aligned
+            weight = ttnn.reshape(weight, [1, 1, 1, -1])
+
         x = norm(
             x,
             epsilon=self.eps,
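For context on the reshape guard above: the gamma weight has to broadcast over the last (hidden) dimension of the activation. A minimal PyTorch reference of the same math (an illustrative sketch, not the ttnn kernel), including the optional unit offset this commit wires up elsewhere via add_unit_offset:

    import torch

    def rms_norm_reference(x, weight, eps=1e-6, add_unit_offset=False):
        # x: [batch, 1, seq_len, dim]; weight: [dim]
        # Normalize over the hidden dimension, then scale by the broadcast weight;
        # the [1, 1, 1, -1] reshape above serves the same broadcasting purpose on device.
        x_normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        gamma = weight + 1.0 if add_unit_offset else weight  # Gemma-style "1 + w" scaling
        return x_normed * gamma.view(1, 1, 1, -1)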

models/tt_transformers/tests/test_decoder.py

Lines changed: 2 additions & 2 deletions
@@ -177,7 +177,7 @@ def test_decoder_inference(
     tt_out = tt_model(
         decode_input,
         current_pos_tensor,
-        rot_mats=rot_mats,
+        rot_mats=[rot_mats, rot_mats],
         mode="decode",
         page_table=page_table_tt,
     )
@@ -191,7 +191,7 @@ def test_decoder_inference(
     freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0)

     # Reference model
-    ref_output = reference_model(pt_decode_input, current_pos[0], freqs_cis_i, mask=None)
+    ref_output = reference_model(pt_decode_input.to(dtype=torch.bfloat16), current_pos[0], freqs_cis_i, mask=None)

     passing, pcc_message = comp_pcc(ref_output, tt_output_torch)

models/tt_transformers/tests/test_embedding.py

Lines changed: 2 additions & 1 deletion
@@ -58,6 +58,7 @@ def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc)

     prompts = ["Joy"] * 32
     pt_input = torch.tensor([model_args.encode_prompt(prompt, instruct=False) for prompt in prompts])
+    embed_scale = model_args.embed_scale
     reference_output = reference_emb(pt_input)
     logger.info(f"reference_output: {reference_output.shape}")

@@ -68,7 +69,7 @@ def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc)
         dtype=ttnn.uint32,
         layout=ttnn.ROW_MAJOR_LAYOUT,
     )
-    tt_output = tt_emb(tt_input)
+    tt_output = tt_emb(tt_input, embed_scale)
     tt_output_torch = ttnn.to_torch(
         tt_output,
         mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(0, -1), mesh_shape=model_args.cluster_shape),

models/tt_transformers/tt/attention.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ def __init__(
         use_paged_kv_cache=False,
     ):
         super().__init__()
+        self.is_sliding = configuration.is_sliding[layer_num]

         self.state_dict = state_dict
         self.mesh_device = mesh_device
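How `configuration.is_sliding` is populated is not shown in this diff. As a hypothetical sketch of the Gemma-3-style convention it presumably follows (names and values below are illustrative, not taken from this commit), every Nth layer uses global attention and the rest use a sliding window:

    # Hypothetical per-layer sliding-attention flags (illustrative only).
    def build_is_sliding(n_layers, sliding_window_pattern):
        # Layers whose (index + 1) is a multiple of the pattern are global; the others slide.
        return [bool((i + 1) % sliding_window_pattern) for i in range(n_layers)]

    is_sliding = build_is_sliding(n_layers=26, sliding_window_pattern=6)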

models/tt_transformers/tt/common.py

Lines changed: 13 additions & 7 deletions
@@ -238,25 +238,31 @@ def compute_llama3_parameters(freqs: torch.Tensor, scale_factor: float, orig_con
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


+def compute_linear_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int):
+    """Linear scaling for rotary embeddings."""
+    freqs /= scale_factor
+    return freqs
+
+
 def compute_default_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int):
     """Default scaling for rotary embeddings."""
     return freqs


-def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int):
+def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int, rope_type="llama3"):
     # FIXME: Llama-3.x specific scaling - we need to support yarn for Qwen2.5 models

-    hf_model_env = os.getenv("HF_MODEL")
-
-    if hf_model_env == "google/gemma-3-1b-it":
+    if rope_type == "default":
         freqs = compute_default_parameters(freqs, scale_factor, orig_context_len)
-    elif "LLAMA_DIR" in os.environ or (hf_model_env and "llama" in hf_model_env.lower()):
+    elif rope_type == "linear":
+        freqs = compute_linear_parameters(freqs, scale_factor, orig_context_len)
+    elif rope_type == "llama3":
         freqs = compute_llama3_parameters(freqs, scale_factor, orig_context_len)

     return freqs


-def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len):
+def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len, rope_type="llama3"):
     """
     Precompute the frequency tensor for sine and cosine values with given dimensions.

@@ -271,7 +277,7 @@ def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end)
     if scale_factor is not None:
-        freqs = apply_scaling(freqs, scale_factor, orig_context_len)
+        freqs = apply_scaling(freqs, scale_factor, orig_context_len, rope_type=rope_type)
     freqs = torch.outer(t, freqs).float()
     return torch.cos(freqs), torch.sin(freqs)
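A usage sketch of the new rope_type dispatch (the theta, scale and length values below are illustrative assumptions, not read from this diff): a Gemma-3-style global RoPE path can request linear scaling while the local sliding-window path stays unscaled.

    from models.tt_transformers.tt.common import precompute_freqs

    head_dim, max_seq_len = 256, 8192  # illustrative values

    # Global RoPE: large theta with linear scaling.
    cos_g, sin_g = precompute_freqs(
        head_dim, max_seq_len, theta=1_000_000.0, scale_factor=8.0,
        orig_context_len=8192, rope_type="linear",
    )

    # Local (sliding-window) RoPE: small theta, no scaling applied.
    cos_l, sin_l = precompute_freqs(
        head_dim, max_seq_len, theta=10_000.0, scale_factor=None,
        orig_context_len=8192, rope_type="default",
    )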

models/tt_transformers/tt/decoder.py

Lines changed: 76 additions & 11 deletions
@@ -102,6 +102,53 @@ def __init__(
             args,
             TG=args.is_galaxy,
         )
+        if f"layers.{layer_num}.pre_feedforward_layernorm.weight" in self.state_dict:
+            self.pre_ff_norm = DistributedNorm(  # pre_feedforward_layernorm
+                RMSNorm(
+                    device=mesh_device,
+                    dim=args.dim,
+                    eps=args.norm_eps,
+                    state_dict=state_dict,
+                    add_unit_offset=self.args.rms_norm_add_unit_offset,
+                    state_dict_prefix=args.get_state_dict_prefix("", layer_num),
+                    weight_cache_path=None if args.dummy_weights else weight_cache_path,
+                    weight_dtype=ttnn.bfloat16,
+                    weight_key="pre_feedforward_layernorm",
+                    is_distributed=self.args.is_distributed_norm,
+                    sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"],
+                    sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"],
+                    ccl_topology=self.args.ccl_topology(),
+                ),
+                args,
+                TG=args.is_galaxy,
+            )
+        else:
+            # If pre_feedforward_layernorm is not in state_dict, we do not use it
+            self.pre_ff_norm = None
+
+        if f"layers.{layer_num}.post_feedforward_layernorm.weight" in self.state_dict:
+            self.post_ff_norm = DistributedNorm(  # post_feedforward_layernorm
+                RMSNorm(
+                    device=mesh_device,
+                    dim=args.dim,
+                    eps=args.norm_eps,
+                    add_unit_offset=self.args.rms_norm_add_unit_offset,
+                    state_dict=state_dict,
+                    state_dict_prefix=args.get_state_dict_prefix("", layer_num),
+                    weight_cache_path=None if args.dummy_weights else weight_cache_path,
+                    weight_dtype=ttnn.bfloat16,
+                    weight_key="post_feedforward_layernorm",
+                    is_distributed=self.args.is_distributed_norm,
+                    sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"],
+                    sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"],
+                    ccl_topology=self.args.ccl_topology(),
+                ),
+                args,
+                TG=args.is_galaxy,
+            )
+        else:
+            # If post_feedforward_layernorm is not in state_dict, we do not use it
+            self.post_ff_norm = None

     def forward(
         self,
@@ -116,6 +163,7 @@ def forward(
         kv_cache=None,
     ) -> ttnn.Tensor:
         TG = self.args.is_galaxy
+        residual = x
         # x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode)
         skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
         assert (
@@ -124,36 +172,53 @@
         # Norms take fractured inputs and output replicated across devices
         attn_in = self.attention_norm(x, mode)
         # Attention takes replicated inputs and produces fractured outputs
+        if self.attention.is_sliding:
+            position_embeddings = rot_mats[1]
+        else:
+            position_embeddings = rot_mats[0]
+
         attn_out = self.attention.forward(
             attn_in,
             current_pos,
-            rot_mats,
+            position_embeddings,
             user_id,
             mode,
             page_table=page_table,
             chunk_page_table=chunk_page_table,
             chunk_start_idx=chunk_start_idx,
             kv_cache=kv_cache,
         )
-        # Here x and attn_out are both fractured across devices
-        h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None)
-        ttnn.deallocate(attn_out)
+        if self.pre_ff_norm == None:
+            attn_out = ttnn.add(x, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None)
+
+            residual = attn_out
+
+        hidden_states = self.ff_norm(attn_out, mode)
+        if self.pre_ff_norm is not None:
+            hidden_states = ttnn.add(hidden_states, residual, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16)
+
+            residual = hidden_states
+
+            hidden_states = self.pre_ff_norm(hidden_states, mode)
+
         if mode == "prefill":
             x.deallocate(True)

-        # Norms take fractured inputs and output replicated across devices
-        ff_in = self.ff_norm(h, mode)
+        # ttnn.deallocate(attn_out)
+
         if TG and mode == "decode":
-            ff_in = ttnn.to_memory_config(ff_in, memory_config=self.model_config["MLP_ACT_MEMCFG"])
+            hidden_states = ttnn.to_memory_config(hidden_states, memory_config=self.model_config["MLP_ACT_MEMCFG"])
         # MLP takes replicated inputs and produces fractured outputs
-        ff_out = self.feed_forward.forward(ff_in, mode)
-        # ff_out and h are both fractured across devices
+        hidden_states = self.feed_forward.forward(hidden_states, mode)
         activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype(
             decoder_id=self.layer_num, tensor=TensorGroup.ACTIVATION
         )
+        if self.post_ff_norm is not None:
+            hidden_states = self.post_ff_norm(hidden_states, mode)
+
         out = ttnn.add(
-            h,
-            ff_out,
+            residual,
+            hidden_states,
             memory_config=skip_mem_cfg,
             dtype=self.args.ccl_dtype
             if TG and not self.args.is_distributed_norm(mode)
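The restructured forward above covers both the classic two-norm block and the Gemma-3 four-norm block. A plain-PyTorch sketch of the control flow it implements (illustrative pseudocode of the data flow, not the ttnn implementation):

    # Illustrative host-side equivalent of the decoder forward above.
    def decoder_block(x, attn, mlp, attention_norm, ff_norm, pre_ff_norm=None, post_ff_norm=None):
        attn_out = attn(attention_norm(x))
        if pre_ff_norm is None:
            # Llama-style: residual add first, then ff_norm acts as the pre-MLP norm.
            residual = x + attn_out
            h = ff_norm(residual)
        else:
            # Gemma-3-style: ff_norm acts as the post-attention norm, the residual is
            # added afterwards, and an extra pre_ff_norm runs before the MLP.
            residual = x + ff_norm(attn_out)
            h = pre_ff_norm(residual)
        h = mlp(h)
        if post_ff_norm is not None:
            h = post_ff_norm(h)
        return residual + h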

models/tt_transformers/tt/embedding.py

Lines changed: 2 additions & 1 deletion
@@ -33,6 +33,7 @@ def __init__(
             cache_file_name=cache_name,
         )

-    def forward(self, x: ttnn.Tensor) -> ttnn.Tensor:
+    def forward(self, x: ttnn.Tensor, embed_scale: int = 1.0) -> ttnn.Tensor:
         x = ttnn.embedding(x, self.weights, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+        x = ttnn.multiply(x, embed_scale)
         return x
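Because embed_scale defaults to 1.0, existing models are unaffected; Gemma-family checkpoints scale token embeddings by sqrt(hidden_size). A host-side reference of the scaled lookup (the concrete sizes below are assumptions for illustration, not read from this diff):

    import math
    import torch

    hidden_size, vocab_size = 1152, 262_144  # illustrative Gemma-3-1b-like sizes
    embed_scale = math.sqrt(hidden_size)     # Gemma-style embedding normalizer

    emb = torch.nn.Embedding(vocab_size, hidden_size)
    tokens = torch.randint(0, vocab_size, (1, 8))
    scaled = emb(tokens) * embed_scale       # mirrors ttnn.multiply(x, embed_scale) above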

models/tt_transformers/tt/lm_head.py

Lines changed: 6 additions & 3 deletions
@@ -31,6 +31,7 @@ def __init__(
         self.num_devices = args.num_devices

         size_per_device = self.vocab_size // self.num_devices
+        self.model_config = args.get_model_config()

         if args.is_galaxy:
             size_per_device = self.padded_vocab_size // self.num_devices
@@ -138,12 +139,14 @@ def forward(self, x: ttnn.Tensor):
                 compute_kernel_config=self.compute_kernel_config,
                 program_config=pc,
                 memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
-                dtype=ttnn.bfloat8_b,
+                dtype=self.args.lm_head_dtype or ttnn.bfloat8_b,
+            )
+            outputs.append(
+                ttnn.sharded_to_interleaved(output, memory_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"])
             )
-            outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.L1_MEMORY_CONFIG))

         # Concatenate the outputs
-        output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG)
+        output = ttnn.concat(outputs, dim=-1, memory_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"])

         output = tt_all_reduce(
             output,

models/tt_transformers/tt/mlp.py

Lines changed: 3 additions & 1 deletion
@@ -72,7 +72,9 @@ def __init__(
         self.w3 = as_sharded_tensor("w3_sharded", ff1_3_dtype, dims=w1_dims)

         # Default activation is SILU
-        self.activation_type = self.args.mlp_activation_type
+        self.activation_type = (
+            args.mlp_activation_type if hasattr(args, "mlp_activation_type") else ttnn.UnaryOpType.SILU
+        )

     def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
         """

models/tt_transformers/tt/model.py

Lines changed: 43 additions & 12 deletions
@@ -49,15 +49,31 @@ def __init__(
             dtype=ttnn.bfloat16,  # Row major layout requires bfloat16
         )

-        ActualRopeSetupClass = rope_setup_class if rope_setup_class is not None else RotarySetup
-        self.rope_setup = ActualRopeSetupClass(
-            device=mesh_device,
-            batch_size=args.max_batch_size,
-            head_dim=args.head_dim,
-            max_seq_len=args.max_seq_len,
-            rope_theta=args.rope_theta,
-            rope_scaling=args.rope_scaling,
+        self.rope_setup = RotarySetup(
+            mesh_device,
+            args.max_batch_size,
+            args.head_dim,
+            args.max_seq_len,
+            args.rope_theta,
+            args.rope_scaling_factor,
+            args.orig_context_len,
+            args.rope_type,
         )
+
+        if args.rope_local_theta is not None:
+            self.rope_setup_local = RotarySetup(
+                mesh_device,
+                args.max_batch_size,
+                args.head_dim,
+                args.max_seq_len,
+                args.rope_local_theta,
+                args.rope_scaling_factor,
+                args.orig_context_len,
+                "default",
+            )
+        else:
+            self.rope_setup_local = None
+
         self.trans_mats_dict = self.rope_setup.get_both_trans_mats()

         self.layers = [
@@ -105,6 +121,8 @@ def __init__(
             max_columns_per_device=self.args.max_columns_per_device_lm_head,
         )

+        self.embed_scale = args.embed_scale
+
     def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None):
         """
         Inputs are torch tensors or python types. This function returns ttnn
@@ -122,7 +140,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
             layout=ttnn.ROW_MAJOR_LAYOUT,
             mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
         )
-        tokens_embd = self.embd(tokens)
+        tokens_embd = self.embd(tokens, self.embed_scale)
+
         tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd)

         # Slice the rot mats to the prefill seqlen
@@ -133,6 +152,13 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
             self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :],
             self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :],
         ]
+        if self.rope_setup_local is not None:
+            tt_rot_mats_prefill_local = [
+                self.rope_setup_local.cos_matrix[:, :, start_pos : start_pos + S, :],
+                self.rope_setup_local.sin_matrix[:, :, start_pos : start_pos + S, :],
+            ]
+        else:
+            tt_rot_mats_prefill_local = None

         if page_table is not None:
             tt_page_table = ttnn.from_torch(
@@ -156,7 +182,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
         else:
             tt_chunk_page_table = None

-        return tokens_embd, tt_rot_mats_prefill, tt_page_table, tt_chunk_page_table
+        return tokens_embd, [tt_rot_mats_prefill, tt_rot_mats_prefill_local], tt_page_table, tt_chunk_page_table

     def prepare_inputs_decode(self, *inputs):
         """
@@ -228,13 +254,18 @@ def transform_decode_inputs_device(self, tokens, current_pos, rope_idxs, page_ta
         Embed tokens
         """
         tt_rot_mats = self.rope_setup.get_rot_mats(rope_idxs)
-        tt_tokens = self.embd(tokens)
+        if self.rope_setup_local is not None:
+            tt_rot_mats_local = self.rope_setup_local.get_rot_mats(rope_idxs)
+        else:
+            tt_rot_mats_local = None
+        tt_tokens = self.embd(tokens, self.embed_scale)
+
         tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens)
         tt_tokens = ttnn.to_memory_config(
             tt_tokens,
             self.args.model_config["DECODE_RESIDUAL_MEMCFG"],
         )
-        return tt_tokens, current_pos, tt_rot_mats, page_table
+        return tt_tokens, current_pos, [tt_rot_mats, tt_rot_mats_local], page_table

     def concat_device_output(self, tt_out):
         """
