From ee72d2685bbcac556f1af24b66ff20d2661fa6be Mon Sep 17 00:00:00 2001 From: Mohammed Taher Rasheed Date: Mon, 4 Aug 2025 15:18:39 +0000 Subject: [PATCH 1/4] Add base commit for gemma 1b --- models/experimental/gemma3_1b/tt/attention.py | 893 ++++++++++++++++++ models/experimental/gemma3_1b/tt/lm_head.py | 161 ++++ models/experimental/gemma3_1b/tt/mlp.py | 254 +++++ 3 files changed, 1308 insertions(+) create mode 100644 models/experimental/gemma3_1b/tt/attention.py create mode 100644 models/experimental/gemma3_1b/tt/lm_head.py create mode 100644 models/experimental/gemma3_1b/tt/mlp.py diff --git a/models/experimental/gemma3_1b/tt/attention.py b/models/experimental/gemma3_1b/tt/attention.py new file mode 100644 index 000000000000..47ba6a7d95fd --- /dev/null +++ b/models/experimental/gemma3_1b/tt/attention.py @@ -0,0 +1,893 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.common.rmsnorm import RMSNorm +from models.tt_transformers.tt.ccl import tt_all_gather, tt_all_reduce +from models.tt_transformers.tt.model_config import OpGroup, TensorGroup + + +class Attention(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + weight_cache_path, + layer_num, + dtype, + transformation_mats, + configuration, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.num_devices = configuration.num_devices + self.TG = self.num_devices == 32 + self.hidden_size = configuration.dim + self.n_heads = configuration.n_heads + self.head_dim = configuration.head_dim + self.max_seq_len = configuration.max_seq_len + self.max_batch_size = configuration.max_batch_size + self.n_kv_heads = configuration.n_kv_heads + self.paged_attention_config = paged_attention_config + self.min_kv_prefill_shard_seqlen = configuration.min_kv_prefill_shard_seqlen + self.ccl_dtype = configuration.ccl_dtype + self.num_reduce_scatter_links = configuration.num_reduce_scatter_links + self.num_all_gather_links = configuration.num_all_gather_links + self.MAX_QKV_MM_SEQ_LEN = configuration.MAX_QKV_MM_SEQ_LEN + self.tile_size = configuration.tile_size + self.rms_norm_add_unit_offset = configuration.rms_norm_add_unit_offset + self.num_device_groups = self.num_devices // self.n_kv_heads + self.num_devices_per_group = self.n_kv_heads if self.TG else self.num_devices + self.batch_size_per_device_group = ( + max(self.max_batch_size // self.num_device_groups, 1) if self.TG else self.max_batch_size + ) + + self.n_local_heads = self.n_heads // self.num_devices_per_group + self.n_local_kv_heads = self.n_kv_heads // self.num_devices_per_group + + self.arch_name = configuration.arch_name + # TODO: Fix this once all-gather supports < tile_size + if self.TG: + weight = torch.zeros(1, 32, 8, 32) + for i in range(32): + col = i % 4 # This determines which group of 8 to select + weight[:, i, :, col * 8 : (col + 1) * 8] = torch.eye(8) + + self.slice_mat = ttnn.from_torch( + weight, + dtype=ttnn.bfloat4_b, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), + ) + user_selection_matrix = torch.eye(8, 8) + user_selection_matrix = torch.nn.functional.pad(user_selection_matrix, (0, 24), "constant", 0) # (8, 32) + user_selection_matrix = [user_selection_matrix] * 4 + user_selection_matrix = 
torch.block_diag(*user_selection_matrix) # (32, 128) + self.user_selection_matrix = ttnn.from_torch( + user_selection_matrix, + dtype=ttnn.bfloat4_b, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + self.dtype = dtype + + self.max_seq_len = configuration.max_seq_len + self.grid_size = configuration.max_grid_size + + self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 + self.compute_kernel_config_hifi2_fp16 = configuration.compute_kernel_config_hifi2_fp16 + + self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + + self.transformation_mats = transformation_mats + + self.model_config = configuration.get_model_config() + self.ccl_topology = configuration.ccl_topology() + self.is_multichip = configuration.is_multichip + self.activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( + decoder_id=layer_num, tensor=TensorGroup.ACTIVATION + ) + self.wqkv_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( + decoder_id=layer_num, tensor=TensorGroup.WQKV + ) + self.wo_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( + decoder_id=layer_num, tensor=TensorGroup.WO + ) + self.kv_cache_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( + decoder_id=layer_num, tensor=TensorGroup.KV_CACHE + ) + self.li_qkv_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( + decoder_id=layer_num, op=OpGroup.LI_QKV_DECODE, configuration=configuration + ) + self.sdpa_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( + decoder_id=layer_num, op=OpGroup.SDPA_DECODE, configuration=configuration + ) + self.li_o_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( + decoder_id=layer_num, op=OpGroup.LI_O_DECODE, configuration=configuration + ) + self.sdpa_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( + decoder_id=layer_num, op=OpGroup.SDPA_PREFILL, configuration=configuration + ) + self.li_qkv_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( + decoder_id=layer_num, op=OpGroup.LI_QKV_PREFILL, configuration=configuration + ) + self.li_o_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( + decoder_id=layer_num, op=OpGroup.LI_O_PREFILL, configuration=configuration + ) + + layer_name = configuration.get_state_dict_prefix(self.__class__.__name__, layer_num) + if configuration.dummy_weights or (weight_cache_path is None): + cache_name = lambda _: None + else: + cache_name = lambda name: weight_cache_path / (f"{layer_name}.{name}") + + wq_str = f"{layer_name}.wq" + wk_str = f"{layer_name}.wk" + wv_str = f"{layer_name}.wv" + wo_str = f"{layer_name}.wo" + q_norm_str = f"{layer_name}.q_norm" + k_norm_str = f"{layer_name}.k_norm" + + # Initialize bias tensors as None + self.wqkv_bias_decode = None + self.wqkv_bias_prefill = None + + # Create combined QKV bias if present in state dict + if f"{wq_str}.bias" in self.state_dict: + qkv_bias = torch.concat( + [ + torch.concat( + [ + torch.chunk(self.state_dict[f"{wq_str}.bias"], configuration.num_devices)[i], + torch.chunk(self.state_dict[f"{wk_str}.bias"], configuration.num_devices)[i], + torch.chunk(self.state_dict[f"{wv_str}.bias"], configuration.num_devices)[i], + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ) + # 
Prefill can use broadcasting on the bias add so wants a 1d tensor + self.wqkv_bias_prefill = ttnn.as_tensor( + qkv_bias, + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("wqkv_bias_prefill_sharded"), + ) + # as_tensor returns (32, dim) which is incorrect, this reshape updates the padded size to the correct size + self.wqkv_bias_prefill = ttnn.reshape( + self.wqkv_bias_prefill, + (1, 1, 1, self.wqkv_bias_prefill.shape[-1]), + (1, 1, self.wqkv_bias_prefill.shape[-2], self.wqkv_bias_prefill.shape[-1]), + ) + + # Broadcasting does not seem to be supported inside execute_trace so expand to the whole batch size + # Create a list of bias tensors for each multiple of tile_size up to max_batch_size + self.wqkv_bias_decode = [] + for batch_size in range( + configuration.tile_size, + configuration.tile_padded_batch_rows + configuration.tile_size, + configuration.tile_size, + ): + qkv_bias_decode = qkv_bias.unsqueeze(0).expand(batch_size, -1) + bias_tensor = ttnn.as_tensor( + qkv_bias_decode, + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name(f"wqkv_bias_decode_sharded_{batch_size}"), + ) + self.wqkv_bias_decode.append(bias_tensor) + + # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices + assert self.n_heads % self.num_devices_per_group == 0 + assert self.n_kv_heads % self.num_devices_per_group == 0 + assert configuration.qkv_size % self.num_devices_per_group == 0 + assert configuration.dim % self.num_devices_per_group == 0 + + # wqkv: 4096 x 3072 (2 devices): width-sharded on 12 banks, 3072 over 12 banks. 
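The combined QKV bias built above interleaves per-device [q_i, k_i, v_i] chunks so that each device's bias shard lines up column-for-column with its fused wqkv shard; ShardTensorToMesh(dim=-1) then hands each device exactly one contiguous shard. A minimal CPU-only sketch of that interleaving, using toy dimensions rather than the model's real head counts:

    import torch

    num_devices = 2
    # Toy sizes; the real model uses n_heads * head_dim (q) and n_kv_heads * head_dim (k, v).
    bq = torch.arange(8, dtype=torch.float32)
    bk = torch.arange(4, dtype=torch.float32) + 100
    bv = torch.arange(4, dtype=torch.float32) + 200

    fused_bias = torch.cat(
        [
            torch.cat(
                [
                    torch.chunk(bq, num_devices)[i],
                    torch.chunk(bk, num_devices)[i],
                    torch.chunk(bv, num_devices)[i],
                ],
                dim=-1,
            )
            for i in range(num_devices)
        ],
        dim=-1,
    )
    # Device 0's shard carries [q_0, k_0, v_0], matching the column order of its fused wqkv shard.
    print(torch.chunk(fused_bias, num_devices)[0])
    # tensor([  0.,   1.,   2.,   3., 100., 101., 200., 201.])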
+ wqkv_mem_config = configuration.create_dram_sharded_mem_config( + configuration.dim, configuration.qkv_size // configuration.num_devices + ) + + qkv_list = [] + for i in range(self.num_devices_per_group): + # Chunk weights + wq_selected = torch.chunk(self.state_dict[f"{wq_str}.weight"], self.num_devices_per_group, dim=0)[i] + wk_selected = torch.chunk(self.state_dict[f"{wk_str}.weight"], self.num_devices_per_group, dim=0)[i] + wv_selected = torch.chunk(self.state_dict[f"{wv_str}.weight"], self.num_devices_per_group, dim=0)[i] + + # Transpose the selected chunks + wq = torch.transpose(wq_selected, -2, -1) + wk = torch.transpose(wk_selected, -2, -1) + wv = torch.transpose(wv_selected, -2, -1) + + qkv = torch.cat([wq, wk, wv], dim=-1) + qkv_list.append(qkv) + + qkv_cat = torch.cat(qkv_list, dim=-1).unsqueeze(0).unsqueeze(0) + + self.wqkv = ttnn.as_tensor( + qkv_cat, + dtype=self.wqkv_dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG if self.TG else wqkv_mem_config, + mesh_mapper=ttnn.ShardTensor2dMesh( + self.mesh_device, dims=(3, 2) if self.TG else (2, 3), mesh_shape=configuration.cluster_shape + ), + cache_file_name=cache_name("wqkv_sharded_2d"), + ) + + def norm_reshard(x, norm, mode): + """Hack until RMSNorm supports height-sharded output config""" + if mode == "decode": + mem_cfg = x.memory_config() + x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG, dtype=x.dtype) + x = norm(x, mode) + if mode == "decode": + x = ttnn.to_memory_config(x, mem_cfg, dtype=x.dtype) + return x + + if f"{q_norm_str}.weight" in self.state_dict: + fn_q_norm = RMSNorm( + device=self.mesh_device, + dim=self.head_dim, + eps=configuration.norm_eps, + state_dict=self.state_dict, + state_dict_prefix=None, # we already prefix q_norm_str + weight_cache_path=None if configuration.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key=q_norm_str, + add_unit_offset=self.rms_norm_add_unit_offset, + is_distributed=False, + sharded_program_config=None, # FIXME: add height-sharded support. self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=None, # FIXME: add height-sharded support. self.model_config["CREATE_QKV_DECODE_SHARD"] + ) + self.q_norm = lambda x, mode: norm_reshard(x, fn_q_norm, mode) + else: + self.q_norm = lambda x, mode: x + + if f"{k_norm_str}.weight" in self.state_dict: + fn_k_norm = RMSNorm( + device=self.mesh_device, + dim=self.head_dim, + eps=configuration.norm_eps, + state_dict=self.state_dict, + state_dict_prefix=None, # we already prefix k_norm_str + weight_cache_path=None if configuration.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key=k_norm_str, + add_unit_offset=self.rms_norm_add_unit_offset, + is_distributed=False, + sharded_program_config=None, # FIXME: add height-sharded support. self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=None, # FIXME: add height-sharded support. 
self.model_config["CREATE_QKV_DECODE_SHARD"], + ) + self.k_norm = lambda x, mode: norm_reshard(x, fn_k_norm, mode) + else: + self.k_norm = lambda x, mode: x + + # For ring topology we can use all gather matmul for wo + self.use_fused_all_gather_matmul = self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] + pt_wo = self.state_dict[f"{wo_str}.weight"].transpose(-1, -2).unsqueeze(0).unsqueeze(0) + + wo_mem_config = configuration.create_dram_sharded_mem_config( + (configuration.n_heads * configuration.head_dim) // configuration.num_devices, configuration.dim + ) + + self.wo = ttnn.as_tensor( + pt_wo, + dtype=self.wo_dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG if (self.use_fused_all_gather_matmul or self.TG) else wo_mem_config, + mesh_mapper=ttnn.ShardTensor2dMesh( + self.mesh_device, + dims=(2, 3) if (self.use_fused_all_gather_matmul or self.TG) else (3, 2), + mesh_shape=configuration.cluster_shape, + ), + cache_file_name=( + cache_name("wo_width_sharded_2d") if (self.use_fused_all_gather_matmul or self.TG) else cache_name("wo") + ), + ) + if not use_paged_kv_cache: + # vLLM provides its own kv cache + self.init_kv_cache(configuration, weight_cache_path) + + if configuration.query_pre_attn_scalar is not None: + self.scale = configuration.query_pre_attn_scalar**-0.5 + else: + self.scale = self.head_dim**-0.5 + + def init_kv_cache(self, configuration, weight_cache_path): + """ + Generates empty KV cache and pushed to device memory + """ + + if self.paged_attention_config: + cache_k = torch.zeros( + ( + self.paged_attention_config.max_num_blocks, + self.n_local_kv_heads, + self.paged_attention_config.block_size, + self.head_dim, + ) + ) + cache_v = torch.zeros( + ( + self.paged_attention_config.max_num_blocks, + self.n_local_kv_heads, + self.paged_attention_config.block_size, + self.head_dim, + ) + ) + else: + cache_k = torch.zeros( + ( + self.batch_size_per_device_group, + self.n_local_kv_heads, + self.max_seq_len, + self.head_dim, + ) + ) + cache_v = torch.zeros( + ( + self.batch_size_per_device_group, + self.n_local_kv_heads, + self.max_seq_len, + self.head_dim, + ) + ) + + self.layer_past = [ + ttnn.as_tensor( + k_or_v, + dtype=self.kv_cache_dtype, + layout=self.model_config["ATTN_W_LAYOUT_TILE"], + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + cache_file_name=( + f"{weight_cache_path}/kvcache_{k_or_v.shape}" + if weight_cache_path and not configuration.dummy_weights + else None + ), + ) + for k_or_v in [cache_k, cache_v] + ] + + def forward_decode( + self, + x: ttnn.Tensor, + current_pos, + rot_mats=None, + page_table=None, + kv_cache=None, + ) -> ttnn.Tensor: + """ + x: (seq_len, 1, batch, dim) + current_pos: (batch_size), current token position in the sequence for each user + """ + + ### + # QKV matmuls + # Use HiFi2 for DRAM-sharded matmuls as they are otherwise flop-bound. Loses 1 bit of activation precision. 
+ ### + + xqkv_fused_sharded = ttnn.linear( + x, + self.wqkv, + # bias=self.wqkv_bias, + memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, + program_config=self.model_config["XQKV_DECODE_PROGCFG"], + compute_kernel_config=self.li_qkv_decode_compute_kernel_cfg, + dtype=self.ccl_dtype if self.TG else self.activation_dtype or ttnn.bfloat16, + ) + # FIXME: File bug against dram-sharded matmuls with bias + if self.wqkv_bias_decode: + # select the bias tensor based on the number of tiles in the rows + # WARNING: must not change the batch size between compiling and executing a trace + num_tiles = int(math.ceil(xqkv_fused_sharded.shape[-2] / self.tile_size)) + xqkv_fused_sharded = xqkv_fused_sharded + self.wqkv_bias_decode[num_tiles - 1] + + ttnn.deallocate(x) + xqkv_fused = tt_all_reduce( + xqkv_fused_sharded, + self.mesh_device, + cluster_axis=1, + num_reduce_scatter_links=self.num_reduce_scatter_links, + num_all_gather_links=self.num_all_gather_links, + memory_config=self.model_config["QKV_OUT_GATHERED_MEMCFG"](list(self.mesh_device.shape)[1]), + sharded=True, + dtype=self.ccl_dtype, + topology=self.ccl_topology, + ) + + if self.TG: + # TODO: Slice the fused_query_key_value tensor get batch=8 + xqkv_fused = ttnn.matmul( + self.slice_mat, + xqkv_fused, + dtype=ttnn.bfloat16, + memory_config=self.model_config["CREATE_HEAD_INPUT_MEMCFG"], + ) + else: + # bfloat16 is required by nlp_create_qkv_heads_decode + xqkv_fused = ttnn.sharded_to_interleaved(xqkv_fused_sharded, ttnn.L1_MEMORY_CONFIG, ttnn.bfloat16) + + ttnn.deallocate(xqkv_fused_sharded) + + # Reshape such that true unpadded batch is tracked in shape + fqkv_shape = xqkv_fused.shape + xqkv_fused = ttnn.reshape( + xqkv_fused, (1, 1, self.batch_size_per_device_group, fqkv_shape[3]), (1, 1, 32, fqkv_shape[3]) + ) + + ### + # Reshape and rotary embeddings + ### + ( + q_heads_pre_rot_1BQD, + k_heads_pre_rot_1BKD, + v_heads_1BKD, + ) = ttnn.experimental.nlp_create_qkv_heads_decode( + xqkv_fused, + num_heads=self.n_local_heads, + num_kv_heads=self.n_local_kv_heads, + memory_config=self.model_config["CREATE_QKV_DECODE_SHARD"], + ) + + q_heads_pre_rot_1BQD = self.q_norm(q_heads_pre_rot_1BQD, mode="decode") + k_heads_pre_rot_1BKD = self.k_norm(k_heads_pre_rot_1BKD, mode="decode") + + ttnn.deallocate(xqkv_fused) + + # Q Rotary Embeddings + q_heads_1BQD = ttnn.experimental.rotary_embedding_llama( + q_heads_pre_rot_1BQD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True + ) + + # K Rotary Embeddings + k_heads_1BKD = ttnn.experimental.rotary_embedding_llama( + k_heads_pre_rot_1BKD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True + ) + + ttnn.deallocate(q_heads_pre_rot_1BQD) + ttnn.deallocate(k_heads_pre_rot_1BKD) + + ### + # KV update + ### + if kv_cache: + keys = kv_cache[0] + values = kv_cache[1] + else: + keys = self.layer_past[0] + values = self.layer_past[1] + # k_heads, [seqlen, n_kv_heads, bsz, head_dim] + # v_heads [seqlen, n_kv_heads, bsz, head_dim] + # keys, [max_batch_size, n_kv_heads // configuration.num_devices, max_seq_len, head_dim] + ttnn.experimental.paged_update_cache(keys, k_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table) + ttnn.experimental.paged_update_cache( + values, v_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table + ) + + ttnn.deallocate(k_heads_1BKD) + ttnn.deallocate(v_heads_1BKD) + + # NOTE: Varying the batch size will result in slightly different outputs. 
+ # For example, a prompt w/ 1 user vs, the same prompt repeated N times for N users, will produce different outputs + # This is because the SDPA op in decode mode has different number of reductions depending on batch size + # Which leads to slightly different outputs from attention (due to accumulated errors) + if page_table: + attn_output_1G4D = ttnn.transformer.paged_scaled_dot_product_attention_decode( + q_heads_1BQD, + keys, + values, + cur_pos_tensor=current_pos, + page_table_tensor=page_table, + scale=self.scale, + program_config=self.model_config["SDPA_DECODE_PROGCFG"], + compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + else: + attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode( + q_heads_1BQD, + keys, + values, + cur_pos_tensor=current_pos, + scale=self.scale, + program_config=self.model_config["SDPA_DECODE_PROGCFG"], + compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, + memory_config=ttnn.DRAM_MEMORY_CONFIG, # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG? + ) + + ttnn.deallocate(q_heads_1BQD) + + attn_output_11BH = ttnn.to_memory_config( + attn_output_1G4D, + memory_config=self.model_config["SCORES_BATCHED_MM_OUTPUT_MEMCFG"](self.batch_size_per_device_group), + ) + attn_output_cat = ttnn.experimental.nlp_concat_heads_decode( + attn_output_11BH, + num_heads=self.n_local_heads, + ) + ttnn.deallocate(attn_output_11BH) + ttnn.deallocate(attn_output_1G4D) + + if self.use_fused_all_gather_matmul: + attn_output_cat = ttnn.to_memory_config( + attn_output_cat, self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"] + ) + _, dense_out_sharded, _ = ttnn.experimental.all_gather_matmul( + attn_output_cat, + self.wo, + dim=3, + all_gather_core_grid_offset=(0, 4), + num_links=1, + program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"], + compute_kernel_config=self.li_o_decode_compute_kernel_cfg, + memory_config_ag=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"], + memory_config_mm=self.model_config["DECODE_RESIDUAL_MEMCFG"], + ) + ttnn.deallocate(attn_output_cat) + dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"]) + return dense_out_sharded + + else: + attn_output = tt_all_gather( + attn_output_cat, + self.mesh_device, + dim=2, + cluster_axis=1, + num_links=2, + memory_config=self.model_config["GATHER_USERS_MEMCFG"](list(self.mesh_device.shape)[1]), + sharded=True, + # dtype=self.ccl_dtype, # Running bf16 until we have SDPA output bfp8 df; otherwise we have two sharded to interleaved/interleaved to sharded conversions + ) + if self.TG: + attn_output = ttnn.to_memory_config(attn_output, ttnn.L1_MEMORY_CONFIG) + # user_selection_matrix = [1, 1, 32, 128] + # user_selection_matrix @ activation -> [1, 1, 32, 128] * [1, 1, 128, 2048] -> [1, 1, 32, 2048] + attn_output = ttnn.matmul( + self.user_selection_matrix, + attn_output, + core_grid=ttnn.CoreGrid(y=4, x=8), + dtype=ttnn.bfloat16, + memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, + ) + + # TODO: Fix this once self.TG supports dram-sharded matmuls + dense_out_sharded = ttnn.matmul( + attn_output, + self.wo, + core_grid=ttnn.CoreGrid(y=4, x=8) if self.TG else None, + program_config=self.model_config["ATTN_OUTPUT_PROGCFG"] if not self.TG else None, + memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b if self.TG else None, + compute_kernel_config=self.li_o_decode_compute_kernel_cfg, + ) + + ttnn.deallocate(attn_output_cat) + + # All 
reduce + dense_out_reduced = tt_all_reduce( + dense_out_sharded, + self.mesh_device, + cluster_axis=0, + num_reduce_scatter_links=self.num_reduce_scatter_links, + num_all_gather_links=self.num_all_gather_links, + dim=0 if (self.TG and self.hidden_size < 8192) else 3, + topology=self.ccl_topology, + memory_config=( + ( + self.model_config["SELF_OUT_REDUCE_SCATTER_MEMCFG"] + if self.hidden_size == 8192 + else self.model_config["SELF_OUT_GATHERED_MEMCFG"](list(self.mesh_device.shape)[0]) + ) + if self.TG + else self.model_config["DECODE_RESIDUAL_MEMCFG"] + ), + sharded=True, + dtype=self.ccl_dtype, + use_composite=True if self.hidden_size == 8192 else False, + ) + + if not self.TG: + dense_out_reduced = ttnn.to_memory_config( + dense_out_reduced, self.model_config["DECODE_RESIDUAL_MEMCFG"] + ) + + return dense_out_reduced + + def forward_prefill( + self, + x_11SH, + rot_mats, + user_id: int = 0, + page_table=None, + chunk_page_table=None, + chunk_start_idx=None, + kv_cache=None, + ): + seq_len = x_11SH.shape[-2] + assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" + ### + # QKV matmuls + ### + + # reshaping long sequence to matmul fit on device + if seq_len > self.MAX_QKV_MM_SEQ_LEN: + if seq_len % self.MAX_QKV_MM_SEQ_LEN != 0: + raise ValueError(f"seq_len {seq_len} must be divisible by {self.MAX_QKV_MM_SEQ_LEN}") + x_11SH = ttnn.reshape(x_11SH, [1, seq_len // self.MAX_QKV_MM_SEQ_LEN, self.MAX_QKV_MM_SEQ_LEN, -1]) + + xqkv_fused = ttnn.linear( + x_11SH, + self.wqkv, + dtype=self.ccl_dtype if self.TG else self.activation_dtype or ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.li_qkv_prefill_compute_kernel_cfg, + program_config=self.model_config["XQKV_PREFILL_PROGCFG"](seq_len), + ) + + # FIXME: surely ttnn.linear bias should work? 
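The explicit add that follows works because the prefill bias was reshaped to (1, 1, 1, qkv_size) in __init__, so it broadcasts over the sequence dimension; decode, by contrast, pre-expands the bias to each tile-padded batch size because, per the comment in __init__, broadcasting does not appear to be supported inside an executed trace. The shape arithmetic in plain torch (toy sizes only):

    import torch

    seq_len, qkv_size = 128, 96                # toy sizes
    xqkv = torch.randn(1, 1, seq_len, qkv_size)
    bias = torch.randn(1, 1, 1, qkv_size)      # single row, as wqkv_bias_prefill is stored

    out = xqkv + bias                          # broadcasts across seq_len, no expand() needed
    assert out.shape == (1, 1, seq_len, qkv_size)
    assert torch.allclose(out[0, 0, 5], xqkv[0, 0, 5] + bias[0, 0, 0])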
+ if self.wqkv_bias_prefill is not None: + xqkv_fused = xqkv_fused + self.wqkv_bias_prefill + + xqkv_fused = tt_all_reduce( + xqkv_fused, + self.mesh_device, + cluster_axis=1, + num_reduce_scatter_links=self.num_reduce_scatter_links, + num_all_gather_links=self.num_all_gather_links, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.ccl_dtype, + ) + + if seq_len > self.MAX_QKV_MM_SEQ_LEN: + xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) + + ttnn.deallocate(x_11SH) + + # split qkv into heads + ( + q_heads_1QSD_pre_rot, + k_heads_1KSD_pre_rot, + v_heads_1VSD, + ) = ttnn.experimental.nlp_create_qkv_heads( + xqkv_fused, + num_heads=self.n_local_heads, + num_kv_heads=self.n_local_kv_heads, + transpose_k_heads=False, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + q_heads_1QSD_pre_rot = self.q_norm(q_heads_1QSD_pre_rot, mode="prefill") + k_heads_1KSD_pre_rot = self.k_norm(k_heads_1KSD_pre_rot, mode="prefill") + + ttnn.deallocate(xqkv_fused) + + ### + # Rotary embeddings + ### + + if q_heads_1QSD_pre_rot.dtype != ttnn.bfloat16: # Rotary embeddings require bfloat16 inputs + q_heads_1QSD_pre_rot = ttnn.typecast(q_heads_1QSD_pre_rot, dtype=ttnn.bfloat16) + + q_heads_1QSD = ttnn.experimental.rotary_embedding_llama( + q_heads_1QSD_pre_rot, + rot_mats[0], + rot_mats[1], + self.transformation_mats["prefill"], + is_decode_mode=False, + ) + ttnn.deallocate(q_heads_1QSD_pre_rot) + + if k_heads_1KSD_pre_rot.dtype != ttnn.bfloat16: # Rotary embeddings require bfloat16 inputs + k_heads_1KSD_pre_rot = ttnn.typecast(k_heads_1KSD_pre_rot, dtype=ttnn.bfloat16) + + k_heads_1KSD = ttnn.experimental.rotary_embedding_llama( + k_heads_1KSD_pre_rot, + rot_mats[0], + rot_mats[1], + self.transformation_mats["prefill"], + is_decode_mode=False, + ) + ttnn.deallocate(k_heads_1KSD_pre_rot) + + # Fill KV-Cache + if kv_cache: + keys_BKSD, values_BKSD = kv_cache[0], kv_cache[1] + else: + keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1] + k_heads_1KSD_8b = ttnn.typecast(k_heads_1KSD, dtype=keys_BKSD.dtype) + ttnn.deallocate(k_heads_1KSD) + + # sharding k_fill to deal with update_cache memory limitation + if seq_len >= self.min_kv_prefill_shard_seqlen and not self.TG and not page_table: + k_fill = ttnn.interleaved_to_sharded(k_heads_1KSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) + else: + k_fill = k_heads_1KSD_8b + + v_heads_1VSD_8b = ttnn.typecast(v_heads_1VSD, dtype=values_BKSD.dtype) + + ttnn.deallocate(v_heads_1VSD) + + # sharding v_fill to deal with update_cache memory limitation + if seq_len >= self.min_kv_prefill_shard_seqlen and not self.TG and not page_table: + v_fill = ttnn.interleaved_to_sharded(v_heads_1VSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) + else: + v_fill = v_heads_1VSD_8b + + if self.TG: + k_fill = self.prefill_prepare_tensor_for_kv_cache(k_fill, user_id) + v_fill = self.prefill_prepare_tensor_for_kv_cache(v_fill, user_id) + if page_table: + # In the case that the tokens have been padded along the seq len dimension, we need to fill the cache with the unpadded k/v values. + # Assume that the page table does not have padding, so we can use it to get the unpadded page len. + block_size = keys_BKSD.shape[2] + # If chunked prefill, use chunk_page_table if given, otherwise use page_table. 
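The few lines below compute how many real (unpadded) tokens the page table covers for this user or chunk and slice the padded K/V down to that length before filling the cache. The arithmetic in isolation, with made-up numbers:

    import torch

    block_size = 32                                  # toy value; the real one is keys_BKSD.shape[2]
    chunk_page_table = torch.tensor([[7, 3, 12]])    # hypothetical: 3 physical blocks for this chunk
    padded_seq_len = 128                             # k_fill / v_fill were padded up to a multiple of 128

    page_len = chunk_page_table.shape[1] * block_size    # 3 * 32 = 96 real tokens
    k_fill = torch.randn(1, 1, padded_seq_len, 64)       # stand-in for the padded K tensor
    k_fill_sliced = k_fill[:, :, :page_len, :] if page_len < k_fill.shape[2] else k_fill
    assert k_fill_sliced.shape[2] == 96              # only the unpadded tokens are written to the cache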
+ fill_page_table = chunk_page_table if chunk_page_table is not None else page_table + + page_len = fill_page_table.shape[1] * block_size + k_fill_sliced = k_fill[:, :, :page_len, :] if page_len < k_fill.shape[2] else k_fill + v_fill_sliced = v_fill[:, :, :page_len, :] if page_len < v_fill.shape[2] else v_fill + ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill_sliced, fill_page_table, batch_idx=user_id) + ttnn.experimental.paged_fill_cache(values_BKSD, v_fill_sliced, fill_page_table, batch_idx=user_id) + else: + ttnn.fill_cache( + keys_BKSD, + k_fill, + user_id % self.batch_size_per_device_group, + ) + ttnn.fill_cache( + values_BKSD, + v_fill, + user_id % self.batch_size_per_device_group, + ) + + if seq_len >= self.min_kv_prefill_shard_seqlen and not self.TG and not page_table: + ttnn.deallocate(k_fill) + ttnn.deallocate(v_fill) + + # SDPA + q_heads_1QSD_8b = ttnn.typecast(q_heads_1QSD, dtype=self.activation_dtype or ttnn.bfloat8_b) + ttnn.deallocate(q_heads_1QSD) + + if chunk_start_idx is not None: + attn_output_84SD = ttnn.transformer.chunked_scaled_dot_product_attention( + q_heads_1QSD_8b, + keys_BKSD, + values_BKSD, + page_table, + chunk_start_idx, + compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg, + program_config=self.model_config["SDPA_PROGCFG"](seq_len), + ) + else: + attn_output_84SD = ttnn.transformer.scaled_dot_product_attention( + q_heads_1QSD_8b, + k_heads_1KSD_8b, + v_heads_1VSD_8b, + is_causal=True, + scale=self.scale, + compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg, + program_config=self.model_config["SDPA_PROGCFG"](seq_len), + ) + + # deallocate keys and values + ttnn.deallocate(q_heads_1QSD_8b) + ttnn.deallocate(k_heads_1KSD_8b) + ttnn.deallocate(v_heads_1VSD_8b) + + attn_output_1QSD = ttnn.reshape(attn_output_84SD, [1, self.n_local_heads, -1, self.head_dim]) + + ### + # Output matmul + ### + attn_output_11SH = ttnn.experimental.nlp_concat_heads( + attn_output_1QSD, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ttnn.deallocate(attn_output_1QSD) + # reshaping long sequence to matmul fit on device + if seq_len > 1024: + attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // 1024, 1024, -1]) + + # Non fused All Gather Matmul + if self.use_fused_all_gather_matmul: # is true for Ring topology + attn_output_11SH = ttnn.all_gather( + attn_output_11SH, + dim=3, + num_links=1, + topology=self.ccl_topology, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + output_11SH = ttnn.linear( + attn_output_11SH, + self.wo, + compute_kernel_config=self.li_o_prefill_compute_kernel_cfg, + dtype=self.activation_dtype or ttnn.bfloat8_b, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + program_config=self.model_config["WO_PREFILL_PROGCFG"](seq_len), + ) + + if seq_len > 1024: + output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) + ttnn.deallocate(attn_output_11SH) + + # Reduce-scatter + if not self.use_fused_all_gather_matmul: + output_11SH = tt_all_reduce( + output_11SH, + self.mesh_device, + cluster_axis=0, + dim=0 if self.TG else 3, + num_reduce_scatter_links=self.num_reduce_scatter_links, + num_all_gather_links=self.num_all_gather_links, + topology=self.ccl_topology, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.ccl_dtype, + ) + + return output_11SH + + def forward( + self, + x, + current_pos, + rot_mats=None, + user_id=0, + mode="decode", + page_table=None, + chunk_page_table=None, + chunk_start_idx=None, + kv_cache=None, + ): + if mode == "prefill": + return self.forward_prefill( + x, + rot_mats, + user_id, + page_table=page_table, + 
chunk_page_table=chunk_page_table, + chunk_start_idx=chunk_start_idx, + kv_cache=kv_cache, + ) + else: + return self.forward_decode(x, current_pos, rot_mats, page_table=page_table, kv_cache=kv_cache) + + def prefill_prepare_tensor_for_kv_cache(self, key_or_value_layer, user_id): + tensor_copy = ttnn.clone(key_or_value_layer) + # key_or_value_layer.deallocate(True) + # Get all tensors from multi-device tensor + tensors = ttnn.get_device_tensors(tensor_copy) + # Get only tensors from specific column chips + # Get every 4th tensor starting from user_id // 8 + single_column_tensors = tensors[user_id // self.batch_size_per_device_group :: 4] + # Create multi-device tensor + multi_device_tensor = ttnn.combine_device_tensors(single_column_tensors) + + return multi_device_tensor diff --git a/models/experimental/gemma3_1b/tt/lm_head.py b/models/experimental/gemma3_1b/tt/lm_head.py new file mode 100644 index 000000000000..3be020957904 --- /dev/null +++ b/models/experimental/gemma3_1b/tt/lm_head.py @@ -0,0 +1,161 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.ccl import tt_all_reduce + + +class LMHead(LightweightModule): + def __init__( + self, + args, + mesh_device, + dtype, + state_dict, + state_dict_prefix, + weight_cache_path, + max_columns_per_device, # too many columns per device lead to L1 OOM + ): + super().__init__() + self.args = args + self.mesh_device = mesh_device + self.dtype = dtype + self.vocab_size = args.vocab_size + self.padded_vocab_size = args.padded_vocab_size + self.num_devices = args.num_devices + + size_per_device = self.vocab_size // self.num_devices + + if args.is_galaxy: + size_per_device = self.padded_vocab_size // self.num_devices + num_splits = math.ceil(size_per_device / max_columns_per_device) + + split_sizes = [min(size_per_device, max_columns_per_device)] * (num_splits - 1) + split_sizes.append(size_per_device - sum(split_sizes)) # remaining columns + + # Split the output weights + torch_output_weights = state_dict[f"{state_dict_prefix}output.weight"].permute(1, 0) + + self.output_weights = [] + if args.is_galaxy: + cache_file_name = ( + None if args.dummy_weights else weight_cache_path / f"output_lm_head_{num_splits}_split_shard_0" + ) + padded_lm_head = torch.zeros(1, 1, args.dim, self.padded_vocab_size) + padded_lm_head[:, :, :, : self.vocab_size] = torch_output_weights + + memory_config = ( + ttnn.DRAM_MEMORY_CONFIG + if args.dim == 2048 + else args.create_dram_sharded_mem_config(k=args.dim // 4, n=self.padded_vocab_size // 8) + ) + self.output_weights.append( # (2k, 16k) 128* 1024 + ttnn.as_tensor( + padded_lm_head, + device=mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(3, 2), mesh_shape=args.cluster_shape), + layout=ttnn.TILE_LAYOUT, + dtype=dtype, + memory_config=memory_config, + cache_file_name=cache_file_name, + ) + ) + else: + for i, split_size in enumerate(split_sizes): + # Create a list to store the split tensors for each device + device_splits = [] + for device in range(self.num_devices): + start = device * size_per_device + sum(split_sizes[:i]) + end = start + split_size + device_splits.append(torch_output_weights[:, start:end]) + + # Concatenate the splits from all devices + combined_split = torch.cat(device_splits, dim=-1) + + cache_file_name = ( + None + if args.dummy_weights + else weight_cache_path / 
f"output_lm_head_{num_splits}_split_shard_{i}_{combined_split.shape[-1]}" + ) + memory_config = args.create_dram_sharded_mem_config( + k=args.dim, n=math.ceil(combined_split.shape[-1] / self.num_devices) + ) + self.output_weights.append( + ttnn.as_tensor( + combined_split, + device=mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + layout=ttnn.TILE_LAYOUT, + dtype=dtype, + memory_config=memory_config, + cache_file_name=cache_file_name, + ) + ) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=False, + packer_l1_acc=True, + ) + if args.is_galaxy: + self.program_configs = [ + ( + None + if args.dim == 2048 + else args.dram_matmul_config( + args.tile_padded_batch_rows, # (8k, 128k) -> (2k, 16k) + args.dim // 4, + 16 * 1024, + args.lm_head_core_grid.num_cores, + ) + ) + ] + + else: + self.program_configs = [ + args.dram_matmul_config( + args.tile_padded_batch_rows, + args.dim, + split_size, + args.lm_head_core_grid.num_cores, + ) + for split_size in split_sizes + ] + + def forward(self, x: ttnn.Tensor): + outputs = [] + for weight, pc in zip(self.output_weights, self.program_configs): + output = ttnn.linear( + x, + weight, + compute_kernel_config=self.compute_kernel_config, + program_config=pc, + memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + ) + outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.L1_MEMORY_CONFIG)) + + # Concatenate the outputs + output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG) + + output = tt_all_reduce( + output, + mesh_device=self.mesh_device, + cluster_axis=1, + dim=3 if self.args.is_galaxy else 0, + num_reduce_scatter_links=self.args.num_reduce_scatter_links, + num_all_gather_links=self.args.num_all_gather_links, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=self.args.ccl_dtype, + sharded=False, + use_composite=True, + ) + + return output diff --git a/models/experimental/gemma3_1b/tt/mlp.py b/models/experimental/gemma3_1b/tt/mlp.py new file mode 100644 index 000000000000..9893ec2440e4 --- /dev/null +++ b/models/experimental/gemma3_1b/tt/mlp.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.ccl import tt_all_reduce +from models.tt_transformers.tt.common import pad_to_size +from models.tt_transformers.tt.model_config import OpGroup, TensorGroup + + +class MLP(LightweightModule): + def __init__( + self, mesh_device, args, state_dict, weight_cache_path, layer_num, dtype, model_config, state_dict_prefix=None + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.args = args + self.dim = args.dim + self.model_config = model_config + self.layer_num = layer_num + state_dict_prefix = state_dict_prefix or args.get_state_dict_prefix(self.__class__.__name__, layer_num) + torch_weight = lambda name: torch.transpose(self.state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1) + pad_hidden_dim = lambda tensor, dim: pad_to_size(tensor, dim=dim, size=args.hidden_dim) + # If pading was applied (e.g. 
via env var), add the unpadded hidden dim to the cache name to avoid loading incorrect weights + hidden_dim_string = f".hidden_dim_{args.hidden_dim}" if args.hidden_dim != args.unpadded_hidden_dim else "" + + if args.dummy_weights: + cache_name = lambda _: None + else: + cache_name = lambda name: weight_cache_path / f"{state_dict_prefix}.{name}{hidden_dim_string}" + + w1_w3_mem_config = args.create_dram_sharded_mem_config(args.dim, args.hidden_dim // args.num_devices) + w2_mem_config = args.create_dram_sharded_mem_config(args.hidden_dim // args.num_devices, args.dim) + + # TODO Clean up this code. With sharding, we load the normal weights and then shard them + as_sharded_tensor = lambda name, type, dims: ttnn.as_tensor( + pad_hidden_dim( + torch_weight(name[:2]), dims[0] if args.is_galaxy else dims[-1] + ), # Grab only the wX part of the name + dtype=type, + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=args.cluster_shape), + layout=ttnn.TILE_LAYOUT, + memory_config=( + ttnn.DRAM_MEMORY_CONFIG if args.is_galaxy else w2_mem_config if "w2" in name else w1_w3_mem_config + ), + cache_file_name=cache_name(name), + ) + + # Sharded weights + w1_dims = (-1, -2) if args.is_galaxy else (-2, -1) + w2_dims = (-2, -1) if args.is_galaxy else (-1, -2) + + layer_num = max(layer_num, 0) # cross_block uses the configutation of the first decoder + + ff1_3_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( + decoder_id=layer_num, tensor=TensorGroup.FF1_FF3 + ) + ff2_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( + decoder_id=layer_num, tensor=TensorGroup.FF2 + ) + + self.w1 = as_sharded_tensor( + "w1_sharded", ff1_3_dtype, dims=w1_dims + ) # bfp4 normally ok here but sub .99 pcc for llama 3.1 weights + self.w2 = as_sharded_tensor("w2_sharded", ff2_dtype, dims=w2_dims) + self.w3 = as_sharded_tensor("w3_sharded", ff1_3_dtype, dims=w1_dims) + + # Default activation is SILU + self.activation_type = self.args.mlp_activation_type + + def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: + """ + w1 -> gate_proj + w2 -> down_proj + w3 -> up_proj + HF reference: self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + """ + seq_len = x.shape[-2] + TG = self.args.is_galaxy + layer_num = max(self.layer_num, 0) # cross_block uses the configutation of the first decoder + activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( + decoder_id=layer_num, tensor=TensorGroup.ACTIVATION + ) + li_ff1_3_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( + decoder_id=layer_num, op=OpGroup.LI_FF1_FF3, configuration=self.args + ) + + if mode == "decode": # Sharded config + if TG: # TODO: Fix this when TG supports DRAM sharded matmuls + pc_1 = self.model_config["FF1_3_TG_PROGCFG"] if self.dim >= 4096 else None + pc_2 = self.model_config["FF2_TG_PROGCFG"] if self.dim >= 4096 else None + pc_3 = self.model_config["FF1_3_TG_PROGCFG"] if self.dim >= 4096 else None + else: + pc_1 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"] + pc_2 = self.model_config["DECODE_MLP_W2_PRG_CONFIG"] + pc_3 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"] + else: # Update the program configs based for prefill + if seq_len >= self.args.prefill_len_cutoff: # 512 if Blackhole, 1024 if Wormhole + # Reshape input to to fit on device and parallelize computation + x = ttnn.reshape(x, [1, seq_len // self.args.prefill_len_cutoff, self.args.prefill_len_cutoff, -1]) + pc_1 = 
self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"](seq_len) + pc_2 = self.model_config["PREFILL_MLP_W2_PRG_CONFIG"](seq_len) + pc_3 = self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"](seq_len) + + # In decode mode (seqlen <= 32) do DRAM sharded matmuls + # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4 + memory_config = ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + w1_out = ttnn.linear( + x, + self.w1, + dtype=ttnn.bfloat8_b if TG else activation_dtype or ttnn.bfloat16, + core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_1 else None, + compute_kernel_config=li_ff1_3_compute_kernel_cfg, + program_config=pc_1, + memory_config=memory_config, + ) + + w3_out = ttnn.linear( + x, + self.w3, + dtype=ttnn.bfloat8_b if TG else activation_dtype or ttnn.bfloat16, + core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_3 else None, + compute_kernel_config=li_ff1_3_compute_kernel_cfg, + program_config=pc_3, + memory_config=memory_config, + ) + ttnn.deallocate(x) + + if TG: + # if mode == "decode" and self.dim!=8192: + # w1_out = ttnn.to_memory_config(w1_out, ttnn.DRAM_MEMORY_CONFIG) + # w3_out = ttnn.to_memory_config(w3_out, ttnn.DRAM_MEMORY_CONFIG) + if self.dim == 8192 or mode == "prefill": + input_mem_cfg = w1_out.memory_config() + w1_out = ttnn.reduce_scatter( + w1_out, + dim=3, + math_op=ttnn.ReduceType.Sum, + num_links=self.args.num_reduce_scatter_links, + cluster_axis=1, + mesh_device=self.mesh_device, + topology=ttnn.Topology.Linear, + memory_config=self.model_config["FF1_OUT_REDUCE_SCATTER_MEMCFG"] if mode == "decode" else None, + ) + w3_out = ttnn.reduce_scatter( + w3_out, + dim=3, + math_op=ttnn.ReduceType.Sum, + num_links=1, + cluster_axis=1, + mesh_device=self.mesh_device, + topology=ttnn.Topology.Linear, + memory_config=self.model_config["FF1_OUT_REDUCE_SCATTER_MEMCFG"] if mode == "decode" else None, + ) + else: + w1_out = tt_all_reduce( + w1_out, + self.mesh_device, + cluster_axis=1, + num_all_gather_links=2, + sharded=True if mode == "decode" else False, + topology=self.args.ccl_topology(), + memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None, + ) + w3_out = tt_all_reduce( + w3_out, + self.mesh_device, + cluster_axis=1, + num_all_gather_links=2, + sharded=True if mode == "decode" else False, + topology=self.args.ccl_topology(), + memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None, + ) + + w2_in = ttnn.mul( + w1_out, + w3_out, + input_tensor_a_activations=[self.activation_type], + dtype=activation_dtype or ttnn.bfloat8_b, + memory_config=w1_out.memory_config(), + ) + + if mode == "decode" and not TG: + # w2 may use a different core grid, this is a no-op if they already match + w2_in = ttnn.to_memory_config(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"]) + + ttnn.deallocate(w3_out) + ttnn.deallocate(w1_out) + + if TG and (self.dim == 8192 or mode == "prefill"): + w2_in = ttnn.all_gather( + w2_in, + 3, + num_links=2, + cluster_axis=1, + mesh_device=self.mesh_device, + topology=ttnn.Topology.Linear, + memory_config=input_mem_cfg, + ) + if mode == "decode": + w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG) + + li_ff2_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( + decoder_id=layer_num, op=OpGroup.LI_FF2, configuration=self.args + ) + w2_out = ttnn.linear( + w2_in, + self.w2, + compute_kernel_config=li_ff2_compute_kernel_cfg, + 
dtype=self.args.ccl_dtype if TG else activation_dtype or ttnn.bfloat16, + program_config=pc_2, + memory_config=memory_config, + core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None, + ) + ttnn.deallocate(w2_in) + # if mode == "decode" and not TG: + # w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.DRAM_MEMORY_CONFIG) + w2_out_reduced = tt_all_reduce( + w2_out, + self.mesh_device, + cluster_axis=0, + dim=0 if (TG and self.dim < 8192) else 3, + num_reduce_scatter_links=self.args.num_reduce_scatter_links, + num_all_gather_links=self.args.num_all_gather_links, + sharded=(mode == "decode"), + memory_config=( + (self.model_config["FF2_OUT_REDUCE_SCATTER_MEMCFG"] if TG else w2_out.memory_config()) + if mode == "decode" + else ttnn.DRAM_MEMORY_CONFIG + ), + dtype=self.args.ccl_dtype, + use_composite=True if self.dim == 8192 else False, + topology=self.args.ccl_topology(), + ) + + # Ensure dim 0 and 1 are 1 + original_shape = w2_out_reduced.shape + w2_out_reduced = ttnn.reshape( + w2_out_reduced, (1, 1, original_shape[-4] * original_shape[-3] * original_shape[-2], original_shape[-1]) + ) + if mode == "decode": + w2_out_reduced = ttnn.to_memory_config( + w2_out_reduced, + self.model_config["SHARDED_ATTN_INPUT_MEMCFG"] if TG else self.model_config["DECODE_RESIDUAL_MEMCFG"], + ) + + # ttnn.deallocate(w2_out) + return w2_out_reduced From ade214f660c441b5dfca7e518ca83c46b87429b5 Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Tue, 29 Jul 2025 12:48:40 +0000 Subject: [PATCH 2/4] Add experimental Gemma-3-1b-it Model bringup --- .../gemma3_1b/tests/test_attention.py | 277 +++++++++++ .../gemma3_1b/tests/test_decoder.py | 198 ++++++++ .../experimental/gemma3_1b/tests/test_mlp.py | 98 ++++ .../gemma3_1b/tests/test_model.py | 334 ++++++++++++++ .../gemma3_1b/tests/test_rmsnorm.py | 116 +++++ models/experimental/gemma3_1b/tt/attention.py | 34 +- models/experimental/gemma3_1b/tt/decoder.py | 226 +++++++++ models/experimental/gemma3_1b/tt/lm_head.py | 22 +- models/experimental/gemma3_1b/tt/mlp.py | 20 +- models/experimental/gemma3_1b/tt/model.py | 432 ++++++++++++++++++ models/experimental/gemma3_1b/tt/rmsnorm.py | 177 +++++++ models/tt_transformers/tt/common.py | 30 +- models/tt_transformers/tt/load_checkpoints.py | 6 +- models/tt_transformers/tt/model_config.py | 98 +++- 14 files changed, 2026 insertions(+), 42 deletions(-) create mode 100644 models/experimental/gemma3_1b/tests/test_attention.py create mode 100644 models/experimental/gemma3_1b/tests/test_decoder.py create mode 100644 models/experimental/gemma3_1b/tests/test_mlp.py create mode 100644 models/experimental/gemma3_1b/tests/test_model.py create mode 100644 models/experimental/gemma3_1b/tests/test_rmsnorm.py create mode 100644 models/experimental/gemma3_1b/tt/decoder.py create mode 100644 models/experimental/gemma3_1b/tt/model.py create mode 100644 models/experimental/gemma3_1b/tt/rmsnorm.py diff --git a/models/experimental/gemma3_1b/tests/test_attention.py b/models/experimental/gemma3_1b/tests/test_attention.py new file mode 100644 index 000000000000..d488150ccd27 --- /dev/null +++ b/models/experimental/gemma3_1b/tests/test_attention.py @@ -0,0 +1,277 @@ +""" Test for Gemma-3-1b-it Attention """ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.experimental.gemma3_1b.tt.attention import Attention +from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs +from models.tt_transformers.tt.rope import RotarySetup +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + False, + ), + ids=( + "paged_attention", + "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (256,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_attention_inference( + max_seq_len, + batch_size, + paged_attention, + page_params, + mesh_device, + reset_seeds, +): + dtype = ttnn.bfloat16 + pcc = 0.99 + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) + model_args.n_layers = 1 # For the unit test, just run a single layer + + state_dict = model_args.load_state_dict() + + first_layer_prefix = model_args.get_state_dict_prefix("Attention", 0) + "." + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_attention() + reference_model.load_state_dict(partial_state_dict) + + seq_len = 1 + + generation_start_pos = 0 + generation_length = 10 + all_tests_pass = True + + # Setup RoPE transformation matrices + rope_setup = RotarySetup( + mesh_device, + batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.rope_scaling_factor, + model_args.orig_context_len, + ) + + transformation_mats = rope_setup.get_both_trans_mats() + + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + tt_model = Attention( + mesh_device, + state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + layer_num=0, + dtype=dtype, + transformation_mats=transformation_mats, + configuration=model_args, + paged_attention_config=paged_attention_config, + ) + + cos, sin = 
precompute_freqs( + model_args.head_dim, + model_args.max_seq_len * 2, + model_args.rope_theta, + model_args.rope_scaling_factor, + model_args.orig_context_len, + ) + freqs_cis = torch.complex(cos, sin) + + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + for i in range(generation_length): + # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1 + pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim) # Qwen2.5 0.5B sees 0.1 to 2.1 + + tt_attention_input = pt_attention_input.clone() + + attention_input = model_args.prepare_residual_tensor_decode( + tt_attention_input, + model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"], + force_replicated=False if model_args.is_galaxy else True, + ) + + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) + + tt_out = tt_model( + attention_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) + # multi-device attention module returns replicated output + tt_out = ttnn.to_torch( + tt_out, + mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape), + ) + tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim) + + # In this test all users have the same position (if using batch > 1) + freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0) + + reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info(f"[pos={current_pos[0]}] Attention Passed!") + else: + logger.warning(f"[pos={current_pos[0]}] Attention Failed!") + all_tests_pass = False + + # Increment position + current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + check_kv_cache = True + if check_kv_cache: + # PyTorch output -------------------------------------------------------------------- + pytorch_layer_present = [ + reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] + reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] + ] + # TT hardware execution ------------------------------------------------------------- + if paged_attention: + tt_layer_present = [ + ( + ttnn.to_torch( + cache, + mesh_composer=ttnn.ConcatMesh2dToTensor( + mesh_device, + dims=(1, 3) if model_args.is_galaxy else (0, 1), + mesh_shape=model_args.cluster_shape, + ), + )[reverse_permutation][:, : model_args.n_kv_heads, :, : model_args.head_dim] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + 
model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch_size, ... + ] + ) + for cache in tt_model.layer_past + ] + else: + tt_layer_present = [ + ttnn.to_torch( + cache, + mesh_composer=ttnn.ConcatMesh2dToTensor( + mesh_device, + dims=(1, 0) if model_args.is_galaxy else (0, 1), + mesh_shape=model_args.cluster_shape, + ), + )[:batch_size, :, :, :] + for cache in tt_model.layer_past + ] + for label, cache_pt, cache_tt in zip(["K", "V"], pytorch_layer_present, tt_layer_present): + cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + i + 1) + cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] + cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] + does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc) + logger.info(f"{label} cache output: {output_pcc}") + if does_pass: + logger.info(f"{label} cache Passed!") + else: + logger.warning(f"{label} Cache Failed! PCC value is lower than {pcc}") + all_tests_pass = False + + if all_tests_pass: + logger.info("Attention output Passed!") + else: + logger.warning("Attention output Failed!") + assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" diff --git a/models/experimental/gemma3_1b/tests/test_decoder.py b/models/experimental/gemma3_1b/tests/test_decoder.py new file mode 100644 index 000000000000..4bce0d46159f --- /dev/null +++ b/models/experimental/gemma3_1b/tests/test_decoder.py @@ -0,0 +1,198 @@ +""" Test for Gemma-3-1b-it Decoder """ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +from loguru import logger +import os +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.gemma3_1b.tt.decoder import TransformerBlock +from models.utility_functions import ( + comp_pcc, + comp_allclose, +) +from models.utility_functions import skip_for_grayskull +from models.tt_transformers.tt.common import PagedAttentionConfig +from models.tt_transformers.tt.rope import RotarySetup + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "paged_attention", + (True, False), + ids=("paged_attention", "default_attention"), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (256,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_decoder_inference( + max_seq_len, + batch_size, + paged_attention, + page_params, + mesh_device, + reset_seeds, +): + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) + model_args.n_layers = 1 + + state_dict = model_args.load_state_dict() + + reference_model = model_args.reference_decoder() + + generation_start_pos = 0 + generation_length = 10 + all_tests_pass = True + + rope_setup = RotarySetup( + mesh_device, + batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.rope_scaling_factor, + model_args.orig_context_len, + ) + transformation_mats = rope_setup.get_both_trans_mats() + + # 
Prepare page table for paged attention + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + # Initialize TT model + tt_model = TransformerBlock( + args=model_args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + layer_num=0, + weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, + ) + + seqlen = 1 + + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + for i in range(generation_length): + pt_decode_input = (torch.rand(batch_size, seqlen, model_args.dim) * 2) - 1 + logger.info(f"[Decoder] Generating token {i}") + + tt_decode_input = pt_decode_input.clone() + + decode_input = model_args.prepare_residual_tensor_decode( + tt_decode_input, + # ttnn.DRAM_MEMORY_CONFIG, + model_args.model_config["DECODE_RESIDUAL_MEMCFG"], + ) + + # Get cos/sin matrices for the current position of each user + rot_mat_global = rope_setup.get_rot_mats(current_pos) + rot_mat_local = rope_setup.get_rot_mats(current_pos) + + # Run TT model + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=[rot_mat_global, rot_mat_local], + mode="decode", + page_table=page_table_tt, + ) + tt_out = ttnn.to_torch( + tt_out, + mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape), + ) + + tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim) + + # Reference model + ref_output = reference_model(pt_decode_input, current_pos[0], None, mask=None) + + passing, pcc_message = comp_pcc(ref_output, tt_output_torch) + + logger.info(comp_allclose(ref_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("Decoder Block Passed!") + else: + logger.warning("Decoder Block Failed!") + all_tests_pass = False + + # Increment position + current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + if all_tests_pass: + logger.info(f"All {generation_length} decode iterations Passed!") + else: + logger.warning("One or more iterations of decode 
Failed!") + assert all_tests_pass, f"PCC value is lower than {0.99} for some of the outputs. Check Warnings!" diff --git a/models/experimental/gemma3_1b/tests/test_mlp.py b/models/experimental/gemma3_1b/tests/test_mlp.py new file mode 100644 index 000000000000..7e25632dfaf8 --- /dev/null +++ b/models/experimental/gemma3_1b/tests/test_mlp.py @@ -0,0 +1,98 @@ +""" Test for Gemma-3-1b-it MLP """ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from loguru import logger + +import torch +import pytest +import os +import ttnn + +from models.experimental.gemma3_1b.tt.mlp import MLP +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (1152,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_mlp_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs(device, max_batch_size=batch_size, max_seq_len=128) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + # # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "layers.0.feed_forward" + partial_state_dict = { + k[len(first_layer_prefix) + 1 :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = tt_model_args.reference_mlp() # Gemma3 MLP + reference_model.load_state_dict(partial_state_dict) + + tt_model = MLP( + mesh_device=device, + args=tt_model_args, + state_dict=state_dict, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + layer_num=0, + dtype=dtype, + model_config=tt_model_args.get_model_config(), + state_dict_prefix=first_layer_prefix, + ) + torch_input = torch.randn(1, 1, seq_len) + reference_output = reference_model(torch_input) + + tt_input = ttnn.from_torch( + torch_input, + device=device, + mesh_mapper=ttnn.ShardTensor2dMesh( + device, dims=(None, 3) if tt_model_args.is_galaxy else (None, None), mesh_shape=tt_model_args.cluster_shape + ), # When both dims are None, the mapper used is `ReplicateTensorToMesh` + dtype=dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + logger.info("Run MLP") + tt_output = tt_model(tt_input, mode) + + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor(device, dims=(1, 3), mesh_shape=tt_model_args.cluster_shape), + ) + + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch[0])) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info("MLP Passed!") + else: + logger.warning("MLP Failed!") + + assert passing, f"MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." 
diff --git a/models/experimental/gemma3_1b/tests/test_model.py b/models/experimental/gemma3_1b/tests/test_model.py new file mode 100644 index 000000000000..f0c3730eda57 --- /dev/null +++ b/models/experimental/gemma3_1b/tests/test_model.py @@ -0,0 +1,334 @@ +""" Test for Gemma-3-1b-it End-to-End Model""" +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +from loguru import logger +import os +import ttnn +from models.tt_transformers.tt.common import ( + encode_prompt_hf, + sample_host, + PagedAttentionConfig, +) +from models.tt_transformers.tt.model_config import DecodersPrecision + +from models.experimental.gemma3_1b.tt.model import Gemma3Transformer +from models.utility_functions import ( + comp_pcc, + comp_allclose, +) +from models.utility_functions import skip_for_grayskull, skip_for_blackhole +from models.tt_transformers.tt.model_config import HfModelWrapper +from models.tt_transformers.tt.model_config import ModelArgs + +import re + + +def parse_chat_output(text): + pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" + matches = re.finditer(pattern, text, re.DOTALL) + return [(match.group("role"), match.group("message").strip()) for match in matches] + + +def display_chat(logger, conversation): + for role, message in conversation: + if role == "user": + logger.info(f"👤 User: {message}") + elif role == "assistant": + logger.info(f"🤖 Assistant: {message}") + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") +@pytest.mark.timeout(1800) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "weights, layers", + [ + ("instruct", None), + ], + ids=["full"], +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + False, + ), + ids=( + "paged_attention", + "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (256,), # For decode-only unit test, there's no need to run with large sequence lengths +) +@pytest.mark.parametrize( + "optimizations", + [ + lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), # poem + ], + ids=["accuracy"], +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_model_inference( + weights, + layers, + max_seq_len, + batch_size, + paged_attention, + page_params, + optimizations, + mesh_device, + reset_seeds, + request, + parse_chat=True, +): + run_ref_pt = True # Flag to run reference PyTorch model and compare PCC + + dtype = ttnn.bfloat16 + + test_id = request.node.callspec.id + + mode_accuracy = "accuracy" in test_id + instruct = True if weights == "instruct" else False + + model_args = ModelArgs( + mesh_device=mesh_device, + instruct=instruct, + optimizations=optimizations, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + ) + + # Expected PCC for the model + pcc = 0.86 + + # Number of decode iterations to run for the model + iterations = 20 + + if layers is not None: + model_args.n_layers = layers + state_dict = model_args.load_state_dict() + state_dict_prefix = model_args.get_state_dict_prefix("", None) + + prompts = ["Consider the sequence of prime numbers: 2, 3, 5, 
7, count till 100"] * model_args.max_batch_size + + tokenizer = model_args.tokenizer + if instruct: + encoded_prompts = encode_prompt_hf(tokenizer=tokenizer, prompt_text=prompts[0]) + else: + encoded_prompts = [model_args.encode_prompt(prompt, instruct=False) for prompt in prompts] + + if run_ref_pt: + reference_transformer_model = model_args.reference_transformer(wrap=False) + reference_model = HfModelWrapper(reference_transformer_model, model_args.head_dim) + logger.info("Finished loading reference model.") + + # Embedding on host + embd = model_args.reference_embedding(reference_transformer_model) + else: + # Embedding on host + embd = model_args.reference_embedding() + generation_start_pos = 0 + generation_length = iterations + + page_table_tt = None + paged_attention_config = None + + # Prepare page table for paged attention + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if batch_size > 1 else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + # Load TTNN model + tt_model = Gemma3Transformer( + args=model_args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + paged_attention_config=paged_attention_config, + ) + + logger.info("Model and caches loaded.") + + if run_ref_pt: + all_tests_pass = True + + seqlen = 1 # Generating one token per user at a time + batch = model_args.max_batch_size + + # Select the first token from the prompts for initial decoding + encoded_prompts_tensor = torch.tensor(encoded_prompts).unsqueeze(0) # [:,0] + pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1) + tt_decode_input = pt_decode_input + + # Keep track of generated outputs to print out later + all_outputs = [] + if run_ref_pt: + all_outputs_ref = [] + + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + for i in range(generation_length): + logger.info(f"[Model] Generating token {i}") + + decode_input = model_args.prepare_residual_tensor_decode( + tt_decode_input, + model_args.model_config["DECODE_RESIDUAL_MEMCFG"], + ) + + # Get cos/sin matrices for the current position of each user + # rot_mats = tt_model.rope_setup.get_rot_mats(current_pos ) # TODO Fix for sliding window attention #TODO Fix for Gemma3 4B + # rot_mats_local = tt_model.rope_setup_local.get_rot_mats(current_pos) #TODO Fix for sliding window attention #TODO Fix for Gemma3 4B + rot_mats_global = tt_model.rope_setup.get_rot_mats(current_pos) # default + rot_mats_local = tt_model.rope_setup_local.get_rot_mats(current_pos) # default + rot_mats = 
[rot_mats_global, rot_mats_local] + + # Run TT model + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, # should contain both for slidig window and without it #TODO Fix for Gemma3 4B + mode="decode", + page_table=page_table_tt, + ) + + # Convert ttnn tensor to torch tensor + mesh_composer = ttnn.ConcatMesh2dToTensor( + mesh_device, dims=(3, 1) if model_args.is_galaxy else (1, -1), mesh_shape=model_args.cluster_shape + ) + tt_output_torch = ( + ttnn.to_torch(tt_out, mesh_composer=mesh_composer) + .permute(2, 1, 0, 3) + .squeeze(2)[: model_args.max_batch_size, 0:1, : model_args.vocab_size] + ) + + ttnn.deallocate(tt_out) + + if run_ref_pt: # Run reference model + # In this test all users have the same position + ref_output = reference_model(pt_decode_input, current_pos[0]) + + # Increment position + current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + # Append the generated token to the list of outputs /prefill + if i in range(len(encoded_prompts)): + # While in "prefill" mode, use the prompt tokens as the output + all_outputs.append(encoded_prompts[i]) # Update list of TT outputs + if run_ref_pt: + all_outputs_ref.append(encoded_prompts[i]) # Update list of ref outputs + + tt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) + if run_ref_pt: + pt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) + else: + # Greedy decode (temperature = 0) the generated token and save it to print out later + if run_ref_pt: + # Sample from reference model first + _, pt_out_tok = sample_host(ref_output, temperature=0, top_p=0.8) + pt_decode_input = embd(pt_out_tok) + all_outputs_ref.append(pt_out_tok.squeeze(1).tolist()[0]) + + # Use the same token for TT model (teacher forcing) + tt_decode_input = pt_decode_input + all_outputs.append(pt_out_tok.squeeze(1).tolist()[0]) + else: + # If not running reference model, sample from TT model directly + _, tt_out_tok = sample_host(tt_output_torch, temperature=0, top_p=0.8) + tt_decode_input = embd(tt_out_tok) + all_outputs.append(tt_out_tok.squeeze(1).tolist()[0]) + + # Measure PCC if also running reference model + if run_ref_pt: + passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc) + # Decode the output tokens back to text + decoded_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in all_outputs] + logger.info(f"TTNN Decoded Outputs: {decoded_texts}") + decoded_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in all_outputs_ref] + logger.info(f"Torch Decoded Outputs: {decoded_texts}") + + logger.info(comp_allclose(ref_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("Model Passed!") + else: + logger.warning("Model Failed!") + if not passing: + all_tests_pass = False + + if parse_chat: + conversation = parse_chat_output(tokenizer.decode(all_outputs).replace("\n", "\\n")) + display_chat(logger, conversation) + + if run_ref_pt: + if all_tests_pass: + logger.info(f"All {generation_length} decode iterations Passed!") + else: + logger.warning("One or more iterations of decode had bad PCC") + assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" 
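Note on parse_chat_output in test_model.py above: the regex string appears to have lost its named-group brackets during extraction, since the function body calls match.group("role") and match.group("message"), which require (?P<role>...) and (?P<message>...) groups. A working sketch of the helper, checked against a hypothetical toy transcript:

import re

PATTERN = r"<\|(?P<role>user|assistant)\|>\s*(?P<message>.*?)(?=<\|(?:user|assistant|end)\|>|$)"

def parse_chat_output(text):
    # Return (role, message) pairs in the order they appear in the decoded output.
    return [(m.group("role"), m.group("message").strip()) for m in re.finditer(PATTERN, text, re.DOTALL)]

# Hypothetical transcript, only to exercise the pattern:
sample = "<|user|> Write a haiku. <|assistant|> Sure, here it is. <|end|>"
assert parse_chat_output(sample) == [("user", "Write a haiku."), ("assistant", "Sure, here it is.")]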
diff --git a/models/experimental/gemma3_1b/tests/test_rmsnorm.py b/models/experimental/gemma3_1b/tests/test_rmsnorm.py new file mode 100644 index 000000000000..54168b9c7b94 --- /dev/null +++ b/models/experimental/gemma3_1b/tests/test_rmsnorm.py @@ -0,0 +1,116 @@ +""" Test for Gemma-3-1b-it RMSNorm """ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from loguru import logger + +import torch +import pytest +import os + +import ttnn +from models.experimental.gemma3_1b.tt.rmsnorm import RMSNorm +from models.tt_transformers.tt.distributed_norm import DistributedNorm + + +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_rms_norm() # Gemma3 RMSNorm + + state_dict_prefix = tt_model_args.get_state_dict_prefix("", 0) + first_layer_prefix = state_dict_prefix + "attention_norm." + + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_inner_norm = RMSNorm( + device=device, + dim=tt_model_args.dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="attention_norm", + weight_dtype=dtype, + is_distributed=tt_model_args.is_distributed_norm, + sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], + ) + + # Wrap it in DistributedNorm + tt_model = DistributedNorm(tt_inner_norm, tt_model_args, TG=tt_model_args.is_galaxy) + + input = torch.rand(1, 1, tt_model_args.dim) + + reference_output = reference_model(input) + + # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape), + memory_config=( + tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + ), + ) + + tt_output = tt_model(tt_input, mode=mode) + + # DistributedNorm outputs are replicated across devices + tt_output_torch = ttnn.to_torch( + tt_output, + mesh_composer=ttnn.ConcatMesh2dToTensor( + device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape + ), + )[:1, :, :] + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch[0]) + + logger.info(comp_allclose(reference_output, tt_output_torch[0])) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("rms_norm Passed!") + else: + logger.warning("rms_norm Failed!") + + assert passing, 
f"rms_norm output does not meet PCC requirement {0.99}." diff --git a/models/experimental/gemma3_1b/tt/attention.py b/models/experimental/gemma3_1b/tt/attention.py index 47ba6a7d95fd..338e52dbec82 100644 --- a/models/experimental/gemma3_1b/tt/attention.py +++ b/models/experimental/gemma3_1b/tt/attention.py @@ -1,4 +1,13 @@ -# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +""" +This is the attention implementation of the Gemma-3-1b-it + +We have re-used the Attention implementation of the TT-Transformers with few modifications. +This implementation has Changes in Datatype (Bfloat16) that supports the RMSNorm, +Sliding Window support. + +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -8,7 +17,8 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.common.rmsnorm import RMSNorm + +from models.experimental.gemma3_1b.tt.rmsnorm import RMSNorm from models.tt_transformers.tt.ccl import tt_all_gather, tt_all_reduce from models.tt_transformers.tt.model_config import OpGroup, TensorGroup @@ -27,6 +37,7 @@ def __init__( use_paged_kv_cache=False, ): super().__init__() + self.is_sliding = bool((layer_num + 1) % configuration.sliding_window_pattern) self.state_dict = state_dict self.mesh_device = mesh_device @@ -110,22 +121,22 @@ def __init__( decoder_id=layer_num, tensor=TensorGroup.KV_CACHE ) self.li_qkv_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_QKV_DECODE, configuration=configuration + decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration ) self.sdpa_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.SDPA_DECODE, configuration=configuration + decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration ) self.li_o_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_O_DECODE, configuration=configuration + decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration ) self.sdpa_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.SDPA_PREFILL, configuration=configuration + decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration ) self.li_qkv_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_QKV_PREFILL, configuration=configuration + decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration ) self.li_o_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_O_PREFILL, configuration=configuration + decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration ) layer_name = configuration.get_state_dict_prefix(self.__class__.__name__, layer_num) @@ -498,6 +509,7 @@ def forward_decode( # For example, a prompt w/ 1 user vs, the same prompt repeated N times for N users, will produce different outputs # This is because the SDPA op in decode mode has different number of reductions depending on batch size # Which leads to slightly different outputs from attention (due to accumulated errors) + q_heads_1BQD = ttnn.to_memory_config(q_heads_1BQD, ttnn.DRAM_MEMORY_CONFIG) if page_table: attn_output_1G4D = ttnn.transformer.paged_scaled_dot_product_attention_decode( q_heads_1BQD, @@ -584,7 +596,7 
@@ def forward_decode( core_grid=ttnn.CoreGrid(y=4, x=8) if self.TG else None, program_config=self.model_config["ATTN_OUTPUT_PROGCFG"] if not self.TG else None, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - dtype=ttnn.bfloat8_b if self.TG else None, + dtype=ttnn.bfloat8_b if self.TG else ttnn.bfloat16, compute_kernel_config=self.li_o_decode_compute_kernel_cfg, ) @@ -772,7 +784,7 @@ def forward_prefill( ttnn.deallocate(v_fill) # SDPA - q_heads_1QSD_8b = ttnn.typecast(q_heads_1QSD, dtype=self.activation_dtype or ttnn.bfloat8_b) + q_heads_1QSD_8b = ttnn.typecast(q_heads_1QSD, dtype=self.activation_dtype or ttnn.bfloat16) ttnn.deallocate(q_heads_1QSD) if chunk_start_idx is not None: @@ -829,7 +841,7 @@ def forward_prefill( attn_output_11SH, self.wo, compute_kernel_config=self.li_o_prefill_compute_kernel_cfg, - dtype=self.activation_dtype or ttnn.bfloat8_b, + dtype=self.activation_dtype or ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, program_config=self.model_config["WO_PREFILL_PROGCFG"](seq_len), ) diff --git a/models/experimental/gemma3_1b/tt/decoder.py b/models/experimental/gemma3_1b/tt/decoder.py new file mode 100644 index 000000000000..2cf46643c599 --- /dev/null +++ b/models/experimental/gemma3_1b/tt/decoder.py @@ -0,0 +1,226 @@ +""" + +This is the Decoder block for the Gemma 3-1b-it model +We couldn't use the existing implementation in TT-Transformers because the usage of submodules is different + +In Gemma-3-1b-it, The decoder Block has Additional pre_feedforward_layernorm and post_feedforward_layernorm, +And the logic of implementation is different from the existing implementation in TT-Transformers. + +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn + +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.distributed_norm import DistributedNorm +from models.experimental.gemma3_1b.tt.rmsnorm import RMSNorm + +from models.experimental.gemma3_1b.tt.attention import Attention + +from models.experimental.gemma3_1b.tt.mlp import MLP +from models.tt_transformers.tt.model_config import TensorGroup + + +class TransformerBlock(LightweightModule): + def __init__( + self, + args, + mesh_device, + dtype, + state_dict, + layer_num, + weight_cache_path, + transformation_mats, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + + self.args = args + self.hidden_size = args.dim + self.n_heads = args.n_heads + self.head_dim = self.hidden_size // self.n_heads + self.max_seq_len = args.max_seq_len + self.dim = args.dim + self.max_batch_size = args.max_batch_size + self.n_kv_heads = args.n_kv_heads + self.current = 0 + self.model_config = args.get_model_config() + + self.layer_num = layer_num + + self.attention = Attention( + mesh_device=mesh_device, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + layer_num=layer_num, + dtype=dtype, + transformation_mats=transformation_mats, + configuration=args, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + + self.feed_forward = MLP( + mesh_device=mesh_device, + args=args, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + layer_num=layer_num, + dtype=dtype, + model_config=self.model_config, + ) + + self.attention_norm = DistributedNorm( # input_layernorm + RMSNorm( + device=mesh_device, + dim=args.dim, + eps=args.norm_eps, + state_dict=state_dict, + 
state_dict_prefix=args.get_state_dict_prefix("", layer_num), + weight_cache_path=None if args.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key="attention_norm", + is_distributed=self.args.is_distributed_norm, + sharded_program_config=self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], + sharded_output_config=self.model_config["SHARDED_ATTN_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), + ), + args, + TG=args.is_galaxy, + ) + + self.ff_norm = DistributedNorm( # post_attention_layernorm + RMSNorm( + device=mesh_device, + dim=args.dim, + eps=args.norm_eps, + state_dict=state_dict, + state_dict_prefix=args.get_state_dict_prefix("", layer_num), + weight_cache_path=None if args.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key="ffn_norm", + is_distributed=self.args.is_distributed_norm, + sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], + sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), + ), + args, + TG=args.is_galaxy, + ) + + self.pre_ff_norm = DistributedNorm( # pre_feedforward_layernorm + RMSNorm( + device=mesh_device, + dim=args.dim, + eps=args.norm_eps, + state_dict=state_dict, + state_dict_prefix=args.get_state_dict_prefix("", layer_num), + weight_cache_path=None if args.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key="pre_feedforward_layernorm", + is_distributed=self.args.is_distributed_norm, + sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], + sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), + ), + args, + TG=args.is_galaxy, + ) + + self.post_ff_norm = DistributedNorm( # post_feedforward_layernorm + RMSNorm( + device=mesh_device, + dim=args.dim, + eps=args.norm_eps, + state_dict=state_dict, + state_dict_prefix=args.get_state_dict_prefix("", layer_num), + weight_cache_path=None if args.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key="post_feedforward_layernorm", + is_distributed=self.args.is_distributed_norm, + sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], + sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), + ), + args, + TG=args.is_galaxy, + ) + + def forward( + self, + hidden_states: ttnn.Tensor, + current_pos, + rot_mats=None, + user_id=0, + mode="decode", + page_table=None, + chunk_page_table=None, + chunk_start_idx=None, + kv_cache=None, + ): + TG = self.args.is_galaxy + skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + + assert ( + hidden_states.memory_config() == skip_mem_cfg + ), f"decoder input memcfg mismatch: {hidden_states.memory_config()} != {skip_mem_cfg}" + residual = hidden_states + + attn_in = self.attention_norm(hidden_states, mode) + + if self.attention.is_sliding: + position_embeddings = rot_mats[1] + else: + position_embeddings = rot_mats[0] + + attn_out = self.attention.forward( + attn_in, + current_pos, + position_embeddings, + user_id, + mode, + page_table=page_table, + chunk_page_table=chunk_page_table, + chunk_start_idx=chunk_start_idx, + kv_cache=kv_cache, + ) + + hidden_states = self.ff_norm(attn_out, mode) + + ttnn.deallocate(attn_out) + ttnn.deallocate(attn_in) + + hidden_states = ttnn.add(hidden_states, residual, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16) + + residual = hidden_states + + hidden_states = 
self.pre_ff_norm(hidden_states, mode) + + if TG and mode == "decode": + hidden_states = ttnn.to_memory_config(hidden_states, memory_config=self.model_config["MLP_ACT_MEMCFG"]) + + hidden_states = self.feed_forward.forward(hidden_states, mode) + + activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( + decoder_id=self.layer_num, tensor=TensorGroup.ACTIVATION + ) + + hidden_states = self.post_ff_norm(hidden_states, mode) + + hidden_states = ttnn.add( + hidden_states, + residual, + memory_config=skip_mem_cfg, + dtype=self.args.ccl_dtype + if TG and not self.args.is_distributed_norm(mode) + else activation_dtype or ttnn.bfloat16, + ) + + return hidden_states diff --git a/models/experimental/gemma3_1b/tt/lm_head.py b/models/experimental/gemma3_1b/tt/lm_head.py index 3be020957904..5a62229111c4 100644 --- a/models/experimental/gemma3_1b/tt/lm_head.py +++ b/models/experimental/gemma3_1b/tt/lm_head.py @@ -1,4 +1,11 @@ -# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +""" +This is the implementation of lm_head of the Gemma-3-1b-it. + +We have re-used the lm_head implementation of the TT-Transformers library along with few modifications. +This implementation has changes in Memory Configurations (DRAM Memory Config) and Data Type (bfloat16). +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -99,11 +106,12 @@ def __init__( ) self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi2, - math_approx_mode=False, - fp32_dest_acc_en=False, + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, packer_l1_acc=True, + dst_full_sync_en=False, ) + if args.is_galaxy: self.program_configs = [ ( @@ -138,12 +146,12 @@ def forward(self, x: ttnn.Tensor): compute_kernel_config=self.compute_kernel_config, program_config=pc, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat16, ) - outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.L1_MEMORY_CONFIG)) + outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.DRAM_MEMORY_CONFIG)) # Concatenate the outputs - output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG) + output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.DRAM_MEMORY_CONFIG) output = tt_all_reduce( output, diff --git a/models/experimental/gemma3_1b/tt/mlp.py b/models/experimental/gemma3_1b/tt/mlp.py index 9893ec2440e4..2abd227fb2ca 100644 --- a/models/experimental/gemma3_1b/tt/mlp.py +++ b/models/experimental/gemma3_1b/tt/mlp.py @@ -1,4 +1,12 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +""" +This is the implementation of MLP (feed-forward) submodule of Gemma-3-1b-it. + +We have re-used the MLP implementation of the TT-Transformers library with few modifications. +This implementation has changes in Data Type (bfloat16). +""" + + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 @@ -72,7 +80,9 @@ def __init__( self.w3 = as_sharded_tensor("w3_sharded", ff1_3_dtype, dims=w1_dims) # Default activation is SILU - self.activation_type = self.args.mlp_activation_type + self.activation_type = ( + args.mlp_activation_type if hasattr(args, "mlp_activation_type") else ttnn.UnaryOpType.SILU + ) def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: """ @@ -88,7 +98,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: decoder_id=layer_num, tensor=TensorGroup.ACTIVATION ) li_ff1_3_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_FF1_FF3, configuration=self.args + decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=self.args ) if mode == "decode": # Sharded config @@ -182,7 +192,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: w1_out, w3_out, input_tensor_a_activations=[self.activation_type], - dtype=activation_dtype or ttnn.bfloat8_b, + dtype=activation_dtype or ttnn.bfloat16, memory_config=w1_out.memory_config(), ) @@ -207,7 +217,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG) li_ff2_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.LI_FF2, configuration=self.args + decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=self.args ) w2_out = ttnn.linear( w2_in, diff --git a/models/experimental/gemma3_1b/tt/model.py b/models/experimental/gemma3_1b/tt/model.py new file mode 100644 index 000000000000..d59adb37de3c --- /dev/null +++ b/models/experimental/gemma3_1b/tt/model.py @@ -0,0 +1,432 @@ +""" + +This is the end-to-end implementation of the Gemma-3-1b-it model. + +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from tqdm import tqdm +import torch + +from models.experimental.gemma3_1b.tt.rmsnorm import RMSNorm + +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.embedding import Embedding +from models.tt_transformers.tt.rope import RotarySetup + +from models.experimental.gemma3_1b.tt.decoder import TransformerBlock +from models.tt_transformers.tt.distributed_norm import DistributedNorm +from models.experimental.gemma3_1b.tt.lm_head import LMHead +from models.tt_transformers.tt.model_config import TensorGroup +from models.tt_transformers.tt.common import copy_host_to_device + + +class Gemma3Transformer(LightweightModule): + def __init__( + self, + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + assert self.vocab_size > 0 + self.n_layers = args.n_layers + self.mesh_device = mesh_device + self.dtype = dtype + self.model_config = args.get_model_config() + self.grid_size = self.args.max_grid_size + state_dict_prefix = args.get_state_dict_prefix("", None) + + self.embd = Embedding( + mesh_device=mesh_device, + args=args, + weight_cache_path=args.weight_cache_path(dtype), + state_dict=state_dict, + dtype=ttnn.bfloat16, # Row major layout requires bfloat16 + ) + + self.rope_setup = RotarySetup( + mesh_device, + args.max_batch_size, + args.head_dim, + args.max_seq_len, + args.rope_theta, + args.rope_scaling_factor, + args.orig_context_len, + ) + + self.rope_setup_local = RotarySetup( + mesh_device, + args.max_batch_size, + args.head_dim, + args.max_seq_len, + 10000, # Rope theta local + None, # Rope Scaling Factor + args.orig_context_len, + ) + + self.trans_mats_dict = self.rope_setup.get_both_trans_mats() + + self.layers = [ + TransformerBlock( + args=args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + layer_num=i, + transformation_mats=self.trans_mats_dict, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + for i in tqdm(range(self.n_layers)) + ] + self.norm = DistributedNorm( + RMSNorm( + device=mesh_device, + dim=args.dim, + eps=args.norm_eps, + state_dict=state_dict, + state_dict_prefix=args.get_state_dict_prefix("", None), + weight_cache_path=None if args.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key="norm", + add_unit_offset=True, + is_distributed=self.args.is_distributed_norm, + sharded_program_config=self.model_config["SHARDED_NORM_LM_HEAD_PRGM_CFG"], + sharded_output_config=self.model_config["LM_HEAD_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), + ), + args, + args.is_galaxy, + ) + + self.lm_head = LMHead( + args=args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_cache_path=weight_cache_path, + max_columns_per_device=self.args.max_columns_per_device_lm_head, + ) + + self.embed_scale = args.dim**0.5 + + def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. 
+ TODO: Debate whether this function is responsible for padding + """ + tokens = tokens.reshape(1, 1, 1, -1) + S = tokens.shape[-1] + tokens = ttnn.from_torch( + tokens, + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + tokens_embd = self.embd(tokens) + tokens_embd = ttnn.multiply( + tokens_embd, self.embed_scale + ) # TODO In UT, Without Multiply we got passing with better PCC, Lets debug this in pipeline + + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) + + # Slice the rot mats to the prefill seqlen + assert ( + self.rope_setup.cos_matrix.shape[2] >= start_pos + S + ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" + + tt_rot_mats_prefill_global = [ + self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + tt_rot_mats_prefill_local = [ + self.rope_setup_local.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup_local.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + + if page_table is not None: + tt_page_table = ttnn.from_torch( + page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_page_table = None + + if chunk_page_table is not None: + tt_chunk_page_table = ttnn.from_torch( + chunk_page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_chunk_page_table = None + + return tokens_embd, [tt_rot_mats_prefill_global, tt_rot_mats_prefill_local], tt_page_table, tt_chunk_page_table + + def prepare_inputs_decode(self, *inputs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. + Its implementation can take advantage of a few other functions which the + model must implement. + """ + host_inputs = self.prepare_decode_inputs_host(*inputs) + device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device) # Helper function + transformed_device_inputs = self.transform_decode_inputs_device(*device_inputs) + return transformed_device_inputs + + def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): + """ + Inputs are torch tensors or python types. Outputs are ttnn tensors on host. 
+ NOTE: Tokens and current_pos are padded to batch + """ + B = tokens.shape[0] + assert current_pos.shape[0] == B, "Batch size mismatch" + assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size" + + # Necessary padding to be full tile sized when on device + tokens = torch.nn.functional.pad(tokens.view(-1), (0, 32 - len(tokens)), "constant", 0) + tokens = ttnn.from_torch( + tokens, + device=None, + dtype=ttnn.uint32, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + tokens = ttnn.unsqueeze_to_4D(tokens) + + rot_current_pos = torch.maximum( + current_pos, torch.tensor(0, dtype=torch.int64) + ) # Ensure position indices are non-negative + rope_idxs = self.rope_setup.get_rot_idxs(rot_current_pos, on_host=True) + current_pos_tt = ttnn.from_torch( + current_pos, + device=None, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + self.mesh_device, + dims=(None, 0) if (self.args.is_galaxy and B > 1) else (None, None), + mesh_shape=self.args.cluster_shape, + ), + ) + + if page_table is not None: + page_table = ttnn.from_torch( + page_table, + device=None, + dtype=ttnn.int32, + mesh_mapper=ttnn.ShardTensor2dMesh( + self.mesh_device, + dims=(None, -2) if (self.args.is_galaxy and B > 1) else (None, None), + mesh_shape=self.args.cluster_shape, + ), + ) + return tokens, current_pos_tt, rope_idxs, page_table + + def transform_decode_inputs_device(self, tokens, current_pos, rope_idxs, page_table=None): + """ + Inputs are ttnn tensors on device. This function applies any on-device + transformations which should happen before forward decode. + For example: tilize, reshape, shard. + Return transformed device tensors + + Get rope sin/cos + Embed tokens + """ + tt_rot_mats = self.rope_setup.get_rot_mats(rope_idxs) + tt_rot_mats_local = self.rope_setup_local.get_rot_mats(rope_idxs) + + tt_tokens = self.embd(tokens) + tt_tokens = ttnn.multiply(tt_tokens, self.embed_scale) + + tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens) + tt_tokens = ttnn.to_memory_config( + tt_tokens, + self.args.model_config["DECODE_RESIDUAL_MEMCFG"], + ) + return tt_tokens, current_pos, [tt_rot_mats, tt_rot_mats_local], page_table + + def process_output_prefill(self, tt_out, last_token_idx): + """ + Input is ttnn device tensor of logits. Output is torch logits tensor. + NOTE: In this model, prefill always uses get_last_token + """ + logits = ttnn.to_torch( + tt_out, + mesh_composer=ttnn.ConcatMesh2dToTensor( + self.mesh_device, dims=(3, 1) if self.args.is_galaxy else (1, -1), mesh_shape=self.args.cluster_shape + ), + )[0, 0, last_token_idx, : self.vocab_size] + return logits + + def process_output_decode(self, tt_out, B, S=1, is_tokens=False): + """ + Input is ttnn device tensor of logits if is_tokens=False, otherwise tokens. Output is the corresponding torch tensor. 
+ """ + if is_tokens: + tt_out = ttnn.to_torch( + tt_out, # tt_out.cpu(blocking=True, cq_id=1), + mesh_composer=ttnn.ConcatMesh2dToTensor( + self.mesh_device, + dims=(3, 1) if self.args.is_galaxy else (1, -1), + mesh_shape=self.args.cluster_shape, + ), + )[0, 0, :B, 0] + return tt_out + + if self.args.num_devices > 1: + tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() + else: + tt_out = ttnn.to_torch(tt_out).float() + tt_out = tt_out[:, :, :B, : self.vocab_size].view(B, S, -1) + return tt_out + + def ttnn_prefill_forward( + self, + x, + rot_mats, + user_id, + page_table=None, + chunk_page_table=None, + chunk_start_idx=None, + get_last_token=-1, + kv_cache=None, + ): + """ + This method will take device tensors and any other args to run forward. + It returns ttnn device tensors. + """ + return self.forward( + x, + current_pos=None, + rot_mats=rot_mats, + user_id=user_id, + mode="prefill", + page_table=page_table, + chunk_page_table=chunk_page_table, + chunk_start_idx=chunk_start_idx, + get_last_token=get_last_token, + kv_cache=kv_cache, + ) + + def ttnn_decode_forward( + self, + x, + current_pos, + rot_mats, + page_table=None, + kv_cache=None, + argmax_on_device=False, + ): + """ + This method will take device tensors and any other args to run forward. + It returns ttnn device tensors. + """ + + tt_logits = self.forward( + x, + current_pos, + rot_mats=rot_mats, + mode="decode", + page_table=page_table, + kv_cache=kv_cache, + ) + + # Gather the output across all devices and untilize the tensor (for argmax) + if self.args.num_devices > 1: + if self.args.is_galaxy: + tt_logits = ttnn.all_gather( + tt_logits, + dim=3, + num_links=2, + cluster_axis=0, + mesh_device=self.mesh_device, + topology=self.args.ccl_topology(), + ) + else: + tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=self.args.ccl_topology()) + tt_logits = ttnn.untilize(tt_logits, use_multicore=True) + + if argmax_on_device: + tt_logits = ttnn.argmax(tt_logits, dim=3, keepdim=True, use_multicore=True) + else: + # Send output logits to DRAM so L1 is not reserved for ttnn tracing and can be used by subsequent operations + if not self.args.is_galaxy: + tt_logits = ttnn.to_memory_config(tt_logits, ttnn.DRAM_MEMORY_CONFIG) + + return tt_logits + + def forward( + self, + x: ttnn.Tensor, + current_pos, + rot_mats=None, + user_id=0, + mode="decode", + page_table=None, + chunk_page_table=None, + chunk_start_idx=None, + get_last_token=-1, + kv_cache=None, + ): + for i, layer in enumerate(self.layers): + # No-op if callers already provide the right memory config + activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( + decoder_id=i, tensor=TensorGroup.ACTIVATION + ) + if mode == "decode" and not self.args.is_galaxy: + x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"], activation_dtype) + elif activation_dtype is not None and x.dtype != activation_dtype: + x = ttnn.typecast(x, activation_dtype) + + x = layer( + x, + current_pos, + rot_mats, + user_id, + mode, + page_table, + chunk_page_table=chunk_page_table, + chunk_start_idx=chunk_start_idx, + kv_cache=kv_cache[i] if kv_cache is not None else None, + ) + + if mode == "prefill" and get_last_token == -1: + return x + + # Slicing the tensor to the nearest ceiling/floor multiples of 32 for the prefill_len, to get the last token + if get_last_token != -1: + x = ttnn.slice(x, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, x.shape[-1])) + + # Output norm + x = self.norm(x, mode=mode) + + if mode == 
"prefill" and self.model_config["LM_HEAD_INPUT_MEMCFG"].is_sharded(): + x = ttnn.interleaved_to_sharded(x, self.model_config["LM_HEAD_INPUT_MEMCFG"]) + + x = self.lm_head(x) + + if mode == "prefill": + x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT) + x = ttnn.to_memory_config(x, memory_config=ttnn.DRAM_MEMORY_CONFIG) + return x diff --git a/models/experimental/gemma3_1b/tt/rmsnorm.py b/models/experimental/gemma3_1b/tt/rmsnorm.py new file mode 100644 index 000000000000..2c3f9dabce6d --- /dev/null +++ b/models/experimental/gemma3_1b/tt/rmsnorm.py @@ -0,0 +1,177 @@ +""" +This is the modified version of the RMSNorm for Gemma-3-1b-it model. + +We have modified the RMSNorm implementation equivalent to RMSNorm in Gemma-3-1b-it. +We have handled the unit offset addition in the RMSNorm implementation directly into the TTNN Weights + +""" + +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + """ + RMSNorm supporting replication over a MeshDevice and sharding within devices. + + This class implements a Root Mean Square Normalization (RMSNorm) that can be + distributed across multiple devices and cores. If the `device` parameter is a + MeshDevice, the weights and computations are replicated across all devices in + the mesh. Expects an interleaved input tensor, can optionally output a sharded tensor. + + Args: + device: The device or MeshDevice on which to perform the computations. + state_dict: The state dictionary containing the model parameters. + dim: Input dimension (e.g. model hidden dimension size). + layer_num: The layer number to determine the weight key in the state dictionary. + weight_key: The key for retrieving the weight from the state dictionary. + weight_cache_path: Optional path for caching the tilized weights. + weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. + weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. + model_config: Optional configuration dictionary for the model. + eps (float): Small value to avoid division by zero in normalization, default is 1e-05. + + If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG + and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. 
+ """ + + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-06, + add_unit_offset=True, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + ): + super().__init__() + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + + # # Add offset before caching + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + if add_unit_offset: + self.weight = ttnn.add(self.weight, 1.0) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + if add_unit_offset: + self.weight_distributed = ttnn.add(self.weight_distributed, 1.0) # Add offset to distributed weight + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = self._distributed_rmsnorm + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, + epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _distributed_rmsnorm( + self, inp, epsilon=1e-6, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + """ + TODO: We are using Primitive RMSNorm. + This will be replaced once the ttnn.rms_norm atol issue is fixed. 
+ issue: https://github.com/tenstorrent/tt-metal/issues/25883 + """ + inp = ttnn.sharded_to_interleaved(inp) + + xnorm = ttnn.pow(inp, 2) + + xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) + + xnorm = ttnn.rsqrt(xnorm + epsilon) + + xnorm = ttnn.multiply(inp, xnorm) + + weight = ttnn.reshape(weight, [1, 1, 1, -1]) + + output = ttnn.multiply(xnorm, weight) + + if memory_config is not None: + output = ttnn.to_memory_config(output, memory_config) + + return output diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 5eebf47ce735..17655ea39770 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +import os import re from enum import Enum from typing import Optional @@ -216,9 +217,8 @@ def encode_prompt_hf(tokenizer, prompt_text, system_prompt_text=None): return tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True) -def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): - # FIXME: Llama-3.x specific scaling - we need to support yarn for Qwen2.5 models - # Values obtained from grid search +def compute_llama3_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + """Llama-3.x specific scaling for rotary embeddings.""" low_freq_factor = 1 high_freq_factor = 4 @@ -238,6 +238,24 @@ def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: in return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) +def compute_default_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + """Default scaling for rotary embeddings.""" + return freqs + + +def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + # FIXME: Llama-3.x specific scaling - we need to support yarn for Qwen2.5 models + + hf_model_env = os.getenv("HF_MODEL") + + if hf_model_env == "google/gemma-3-1b-it": + freqs = compute_default_parameters(freqs, scale_factor, orig_context_len) + elif "LLAMA_DIR" in os.environ or (hf_model_env and "llama" in hf_model_env.lower()): + freqs = compute_llama3_parameters(freqs, scale_factor, orig_context_len) + + return freqs + + def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len): """ Precompute the frequency tensor for sine and cosine values with given dimensions. 
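A quick illustration of the apply_scaling dispatch added above: when HF_MODEL names the Gemma-3 checkpoint, the rotary frequencies pass through unchanged (compute_default_parameters), while Llama checkpoints keep the Llama-3.x smoothing (compute_llama3_parameters). A hedged sketch, assuming the patched models.tt_transformers.tt.common is importable and using placeholder scale_factor / orig_context_len values:

import os
import torch
from models.tt_transformers.tt.common import apply_scaling

freqs = torch.tensor([0.1, 0.01, 0.001])

os.environ["HF_MODEL"] = "google/gemma-3-1b-it"
# Gemma-3-1b-it takes the default (identity) path, so the frequencies come back unchanged.
assert torch.equal(apply_scaling(freqs, scale_factor=8.0, orig_context_len=8192), freqs)

os.environ["HF_MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # illustrative Llama checkpoint name
# Llama checkpoints fall through to compute_llama3_parameters and get low/high-frequency scaling.
scaled = apply_scaling(freqs, scale_factor=8.0, orig_context_len=8192)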
@@ -585,7 +603,11 @@ def create_tt_model( state_dict=None, num_layers=None, ): - from models.tt_transformers.tt.model import Transformer + if "HF_MODEL" in os.environ and "gemma-3" in os.environ["HF_MODEL"].lower(): + from models.experimental.gemma3_1b.tt.model import Gemma3Transformer as Transformer + else: + from models.tt_transformers.tt.model import Transformer + from models.tt_transformers.tt.model_config import ModelArgs tt_model_args = ModelArgs( diff --git a/models/tt_transformers/tt/load_checkpoints.py b/models/tt_transformers/tt/load_checkpoints.py index 355b1c3af2be..5f9c181c57df 100644 --- a/models/tt_transformers/tt/load_checkpoints.py +++ b/models/tt_transformers/tt/load_checkpoints.py @@ -161,10 +161,8 @@ def map_hf_to_meta_keys(loaded_weights): "model.layers.{layer}.mlp.gate_proj.weight": "layers.{layer}.feed_forward.w1.weight", "model.layers.{layer}.mlp.up_proj.weight": "layers.{layer}.feed_forward.w3.weight", "model.layers.{layer}.mlp.down_proj.weight": "layers.{layer}.feed_forward.w2.weight", - # Full path MLP bias mappings - "model.layers.{layer}.mlp.gate_proj.bias": "layers.{layer}.feed_forward.w1.bias", - "model.layers.{layer}.mlp.up_proj.bias": "layers.{layer}.feed_forward.w3.bias", - "model.layers.{layer}.mlp.down_proj.bias": "layers.{layer}.feed_forward.w2.bias", + "model.layers.{layer}.pre_feedforward_layernorm.weight": "layers.{layer}.pre_feedforward_layernorm.weight", + "model.layers.{layer}.post_feedforward_layernorm.weight": "layers.{layer}.post_feedforward_layernorm.weight", } meta_state_dict = {} diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index f6be1f2172a4..659f729f28c6 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -69,6 +69,7 @@ class OpGroup(Enum): LI_QKV_PREFILL = "li_qkv_prefill" LI_O_PREFILL = "li_o_prefill" SDPA_PREFILL = "sdpa_prefill" + ACCURACY = "accuracy" class MathFidelitySetting(Enum): @@ -77,6 +78,7 @@ class MathFidelitySetting(Enum): HIFI2_NA = "hifi2na" # na specified `packer_l1_acc=False` and `fp32_dest_acc_en=False` in compute kernel config HIFI2_FP16 = "hifi2fp16" # fp16 specified `fp32_dest_acc_en=False` in compute kernel config HIFI4 = "hifi4" + HIFI4_FP32 = "hifi4fp32" class ModelOptimizations: @@ -233,7 +235,7 @@ def _default_settings(self): # Attention TensorGroup.WQKV: PrecisionSetting.BFP8, TensorGroup.WO: PrecisionSetting.BFP8, - TensorGroup.KV_CACHE: PrecisionSetting.BFP8, + TensorGroup.KV_CACHE: PrecisionSetting.BF16, # Upgraded from BFP8 to prevent accumulation errors # Activation across whole model TensorGroup.ACTIVATION: None, # this signals that original dtype should be used }, @@ -243,11 +245,12 @@ def _default_settings(self): OpGroup.LI_FF2: MathFidelitySetting.HIFI2_FP16, # Attention operators -- linear and scaled_dot_product_attention, in decode and prefill modes OpGroup.LI_QKV_DECODE: MathFidelitySetting.HIFI2, - OpGroup.SDPA_DECODE: MathFidelitySetting.HIFI2, + OpGroup.SDPA_DECODE: MathFidelitySetting.HIFI4, # Upgraded from HIFI2 for better precision OpGroup.LI_O_DECODE: MathFidelitySetting.HIFI2, OpGroup.LI_QKV_PREFILL: MathFidelitySetting.HIFI2, OpGroup.SDPA_PREFILL: MathFidelitySetting.HIFI4, OpGroup.LI_O_PREFILL: MathFidelitySetting.HIFI2, # FP32 accumulate is important here + OpGroup.ACCURACY: MathFidelitySetting.HIFI4_FP32, }, } @@ -561,6 +564,7 @@ def __init__( "Qwen2.5-VL-32B": {"N150": None, "N300": None, "T3K": 64, "TG": None, "P150x4": None}, "Qwen2.5-VL-72B": {"N150": None, "N300": None, 
"T3K": 32, "TG": None, "P150x4": None}, "Phi-3.5-mini-instruct": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, + "gemma-3-1b-it": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "QwQ-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, "Qwen3-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, } @@ -581,9 +585,10 @@ def __init__( max_prefill_chunk_size_div1024 = int(max_prefill_chunk_size_div1024) self.max_prefill_chunk_size = max_prefill_chunk_size_div1024 * 1024 - if (self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B"] and self.device_name == "N150") or ( - self.base_model_name in ["Qwen2.5-7B"] and self.device_name == "N300" - ): + if ( + self.base_model_name in ["Llama-3.1-8B", "Llama-3.2-11B", "Mistral-7B", "gemma-3-1b-it"] + and self.device_name == "N150" + ) or (self.base_model_name in ["Qwen2.5-7B"] and self.device_name == "N300"): logger.info(f"Reducing prefill_len_cutoff to 512 for {self.model_name} on {self.device_name}") self.prefill_len_cutoff = 512 @@ -658,6 +663,12 @@ def __init__( fp32_dest_acc_en=True, packer_l1_acc=True, ) + self.compute_kernel_config_hifi4_en_fp32 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) self.compute_kernel_config_hifi2_na = ttnn.WormholeComputeKernelConfig( math_fidelity=ttnn.MathFidelity.HiFi2, math_approx_mode=False, @@ -803,7 +814,7 @@ def __init__( k=k_dim, n=n_dim, grid_size=self.find_prefill_grid(prefill_rows, k_dim // self.tile_size), - in0_block_w=1 if self.is_galaxy else None, + in0_block_w=32 if "gemma-3" in self.model_name else 1 if self.is_galaxy else None, fuse_batch=seq_len <= 1024, per_core_N=math.ceil(n_dim / (self.tile_size * dram_shard_grid_width)) if dram_sharded_wo else None, ) @@ -1257,7 +1268,7 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len): self.num_all_gather_links = ( 2 if self.is_galaxy else 1 ) # TODO: try out 3 for short axis and 4 for long axis (TG only) <- should work but untested in model - self.ccl_dtype = ttnn.bfloat8_b + self.ccl_dtype = ttnn.bfloat16 if "gemma-3" in self.model_name else ttnn.bfloat8_b logger.info(f"Attention grid: {attn_input_grid}") logger.info(f"MLP grid: {mlp_core_grid}") @@ -1396,6 +1407,14 @@ def _set_model_specific_params(self): def _set_params_from_dict(self, config, is_hf=False): # Try to get text_config, if it doesn't exist everything is text config + eos_token_id = config.get("eos_token_id", None) + + self.eos_token_id = ( + None if isinstance(eos_token_id, int) else eos_token_id + ) # Gemma like models can have a list of eos token ids + + self.sliding_window_pattern = config.get("sliding_window_pattern", 1) + text_config = config.get("text_config", config) # Common params with different names between Meta and HF @@ -2103,7 +2122,7 @@ def create_tokenizer(self): # Add meta-compatible stop token list to the HF tokenizer if not "stop_tokens" in tokenizer.__dict__: - tokenizer.stop_tokens = [tokenizer.eos_token_id] + tokenizer.stop_tokens = self.eos_token_id if self.eos_token_id is not None else [tokenizer.eos_token_id] return tokenizer def encode_prompt(self, prompt_text, system_prompt_text=None, instruct=True): @@ -2227,7 +2246,13 @@ def reference_decoder(self): else: model = self.reference_transformer(wrap=False) layer = model.model.layers[0] - wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb) + rotary_emb = model.model.rotary_emb + + if "gemma-3" in 
self.model_name: + wrapper = HfGemmaDecoderWrapper(layer, self.head_dim, rotary_emb) + else: + wrapper = HfDecoderWrapper(layer, self.head_dim, rotary_emb) + return wrapper def reference_attention(self): @@ -2238,7 +2263,12 @@ def reference_attention(self): else: model = self.reference_transformer(wrap=False) layer = model.model.layers[0].self_attn - use_position_embeddings = layer.__class__.__name__ in ("Qwen3Attention", "MistralAttention") + use_position_embeddings = layer.__class__.__name__ in ( + "Qwen3Attention", + "MistralAttention", + "Gemma3Attention", + ) + wrapper = HfAttentionWrapper( layer, self.head_dim, model.model.rotary_emb if use_position_embeddings else None ) @@ -2402,7 +2432,12 @@ def __init__(self, decoder, head_dim, rotary_emb): def forward(self, x, start_pos, freqs_cis_i, mask=None): position_ids = torch.tensor([list(range(start_pos, start_pos + x.shape[1]))] * x.shape[0]) - position_embeddings = self.rotary_emb(x, position_ids) + # TODO: Generalize for other HF models + model_name_env = os.getenv("HF_MODEL") + if model_name_env is not None and "mistral" in model_name_env.lower(): + position_embeddings = self.rotary_emb(x, x.shape[1]) + else: + position_embeddings = self.rotary_emb(x, position_ids) if mask is not None: while len(mask.shape) < 4: @@ -2425,6 +2460,46 @@ def load_state_dict(self, state_dict): return self.decoder.load_state_dict(convert_meta_to_hf(state_dict, self.head_dim)) +class HfGemmaDecoderWrapper: + def __init__(self, decoder, head_dim, rotary_emb): + from transformers import DynamicCache + + self.decoder = decoder + self.head_dim = head_dim + self.rotary_emb = rotary_emb + self.past_key_values = DynamicCache() + + def forward(self, x, start_pos, freqs_cis_i, mask=None): + position_ids = torch.tensor([list(range(start_pos, start_pos + x.shape[1]))] * x.shape[0]) + # TODO: Generalize for other HF models + model_name_env = os.getenv("HF_MODEL") + if model_name_env is not None and "mistral" in model_name_env.lower(): + position_embeddings = self.rotary_emb(x, x.shape[1]) + else: + position_embeddings = self.rotary_emb(x, position_ids) + + if mask is not None: + while len(mask.shape) < 4: + mask = mask.unsqueeze(0) + result = self.decoder.forward( + x, + position_embeddings_global=position_embeddings, + position_embeddings_local=position_embeddings, + past_key_value=self.past_key_values, + use_cache=True, + position_ids=position_ids, + attention_mask=mask, + ) + output = result[0] + return output + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def load_state_dict(self, state_dict): + return self.decoder.load_state_dict(convert_meta_to_hf(state_dict, self.head_dim)) + + class HfModelWrapper: def __init__(self, model, head_dim): from transformers import DynamicCache @@ -2520,6 +2595,7 @@ def get_math_fidelity(self, decoder_id, op: OpGroup, configuration: ModelArgs): MathFidelitySetting.HIFI2_NA: configuration.compute_kernel_config_hifi2_na, MathFidelitySetting.HIFI2_FP16: configuration.compute_kernel_config_hifi2_fp16, MathFidelitySetting.HIFI4: configuration.compute_kernel_config_hifi4, + MathFidelitySetting.HIFI4_FP32: configuration.compute_kernel_config_hifi4_en_fp32, } return math_fidelity_setting_lookup[self.decoder_optimizations[decoder_id].op_fidelity_settings[op]] From d742f83bee5b095d851bc780c9eb93d6f3dc9d33 Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Tue, 5 Aug 2025 09:46:43 +0000 Subject: [PATCH 3/4] Rebase Rotary Setup --- models/experimental/gemma3_1b/tt/model.py | 37 +++++++++++------------ 1 
file changed, 17 insertions(+), 20 deletions(-) diff --git a/models/experimental/gemma3_1b/tt/model.py b/models/experimental/gemma3_1b/tt/model.py index d59adb37de3c..479ba6a6b729 100644 --- a/models/experimental/gemma3_1b/tt/model.py +++ b/models/experimental/gemma3_1b/tt/model.py @@ -35,6 +35,7 @@ def __init__( weight_cache_path, paged_attention_config=None, use_paged_kv_cache=False, + rope_setup_class=None, ): super().__init__() self.args = args @@ -54,25 +55,23 @@ def __init__( state_dict=state_dict, dtype=ttnn.bfloat16, # Row major layout requires bfloat16 ) - - self.rope_setup = RotarySetup( - mesh_device, - args.max_batch_size, - args.head_dim, - args.max_seq_len, - args.rope_theta, - args.rope_scaling_factor, - args.orig_context_len, + ActualRopeSetupClass = rope_setup_class if rope_setup_class is not None else RotarySetup + self.rope_setup = ActualRopeSetupClass( + device=mesh_device, + batch_size=args.max_batch_size, + head_dim=args.head_dim, + max_seq_len=args.max_seq_len, + rope_theta=args.rope_theta, + rope_scaling=args.rope_scaling, ) - self.rope_setup_local = RotarySetup( - mesh_device, - args.max_batch_size, - args.head_dim, - args.max_seq_len, - 10000, # Rope theta local - None, # Rope Scaling Factor - args.orig_context_len, + self.rope_setup_local = ActualRopeSetupClass( + device=mesh_device, + batch_size=args.max_batch_size, + head_dim=args.head_dim, + max_seq_len=args.max_seq_len, + rope_theta=10000, # Rope theta local + rope_scaling=None, # Rope Scaling Factor ) self.trans_mats_dict = self.rope_setup.get_both_trans_mats() @@ -139,9 +138,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) tokens_embd = self.embd(tokens) - tokens_embd = ttnn.multiply( - tokens_embd, self.embed_scale - ) # TODO In UT, Without Multiply we got passing with better PCC, Lets debug this in pipeline + tokens_embd = ttnn.multiply(tokens_embd, self.embed_scale) tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) From 4eda471bfe829fa472e6a4404abfe813339b92ea Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Tue, 5 Aug 2025 09:47:39 +0000 Subject: [PATCH 4/4] Fix RMSNorm unit offset --- models/experimental/gemma3_1b/tt/rmsnorm.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/models/experimental/gemma3_1b/tt/rmsnorm.py b/models/experimental/gemma3_1b/tt/rmsnorm.py index 2c3f9dabce6d..f796aed86701 100644 --- a/models/experimental/gemma3_1b/tt/rmsnorm.py +++ b/models/experimental/gemma3_1b/tt/rmsnorm.py @@ -77,6 +77,8 @@ def __init__( torch_weight = ( state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) ) + if add_unit_offset: + torch_weight = torch_weight + 1.0 # # Add offset before caching cache_name = None if weight_cache_path is None else weight_cache_path / weight_name @@ -93,9 +95,6 @@ def __init__( cache_file_name=cache_name, mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, ) - if add_unit_offset: - self.weight = ttnn.add(self.weight, 1.0) - if self.is_distributed: self.weight_distributed = ttnn.as_tensor( torch_weight, @@ -108,9 +107,6 @@ def __init__( if is_mesh_device else None, ) - if add_unit_offset: - self.weight_distributed = ttnn.add(self.weight_distributed, 1.0) # Add offset to distributed weight - self.sharded_output_config = sharded_output_config self.sharded_program_config = sharded_program_config self.output_mem_config = output_mem_config
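Patch 4/4 folds Gemma's (1 + w) RMSNorm parameterization into the torch weight before it is tiled and cached, instead of adding the offset to the ttnn tensors afterwards, so the cached weight already matches what forward multiplies by. A minimal torch reference of the resulting math, written to mirror the _distributed_rmsnorm fallback above (the function name is illustrative, not part of the patch):

import torch


def gemma_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, add_unit_offset: bool = True) -> torch.Tensor:
    """RMSNorm with Gemma's unit offset folded into the weight once, up front."""
    w = weight + 1.0 if add_unit_offset else weight  # done in __init__, before caching
    # Same op sequence as the ttnn fallback: square, mean, rsqrt, two multiplies.
    xnorm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return xnorm * w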
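The eos_token_id handling added to model_config.py above exists because Gemma-style HF configs expose a list of end-of-sequence ids rather than a single int, and create_tokenizer then uses that list directly as stop_tokens. A small sketch of the same logic, with an illustrative helper name and illustrative ids:

def resolve_stop_tokens(config: dict, tokenizer_eos_token_id: int) -> list:
    """Keep a list of eos ids as-is; fall back to the tokenizer's single id."""
    eos_token_id = config.get("eos_token_id", None)
    # A plain int is dropped so the tokenizer fallback below kicks in,
    # mirroring the patched _set_params_from_dict / create_tokenizer pair.
    eos_token_id = None if isinstance(eos_token_id, int) else eos_token_id
    return eos_token_id if eos_token_id is not None else [tokenizer_eos_token_id]


assert resolve_stop_tokens({"eos_token_id": [1, 106]}, 2) == [1, 106]  # list kept as-is
assert resolve_stop_tokens({"eos_token_id": 2}, 2) == [2]              # single id falls back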