@@ -102,6 +102,53 @@ def __init__(
             args,
             TG=args.is_galaxy,
         )
+        if f"layers.{layer_num}.pre_feedforward_layernorm.weight" in self.state_dict:
+            self.pre_ff_norm = DistributedNorm(  # pre_feedforward_layernorm
+                RMSNorm(
+                    device=mesh_device,
+                    dim=args.dim,
+                    eps=args.norm_eps,
+                    state_dict=state_dict,
+                    add_unit_offset=self.args.rms_norm_add_unit_offset,
+                    state_dict_prefix=args.get_state_dict_prefix("", layer_num),
+                    weight_cache_path=None if args.dummy_weights else weight_cache_path,
+                    weight_dtype=ttnn.bfloat16,
+                    weight_key="pre_feedforward_layernorm",
+                    is_distributed=self.args.is_distributed_norm,
+                    sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"],
+                    sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"],
+                    ccl_topology=self.args.ccl_topology(),
+                ),
+                args,
+                TG=args.is_galaxy,
+            )
+        else:
+            # If pre_feedforward_layernorm is not in the state_dict, we do not use it
+            self.pre_ff_norm = None
+
+        if f"layers.{layer_num}.post_feedforward_layernorm.weight" in self.state_dict:
+            self.post_ff_norm = DistributedNorm(  # post_feedforward_layernorm
+                RMSNorm(
+                    device=mesh_device,
+                    dim=args.dim,
+                    eps=args.norm_eps,
+                    add_unit_offset=self.args.rms_norm_add_unit_offset,
+                    state_dict=state_dict,
+                    state_dict_prefix=args.get_state_dict_prefix("", layer_num),
+                    weight_cache_path=None if args.dummy_weights else weight_cache_path,
+                    weight_dtype=ttnn.bfloat16,
+                    weight_key="post_feedforward_layernorm",
+                    is_distributed=self.args.is_distributed_norm,
+                    sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"],
+                    sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"],
+                    ccl_topology=self.args.ccl_topology(),
+                ),
+                args,
+                TG=args.is_galaxy,
+            )
+        else:
+            # If post_feedforward_layernorm is not in the state_dict, we do not use it
+            self.post_ff_norm = None

     def forward(
         self,
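Note on the hunk above: the layer only instantiates `pre_feedforward_layernorm` / `post_feedforward_layernorm` when the checkpoint actually ships those weights, so Llama-style checkpoints keep the original two-norm layout while Gemma-style "sandwich-norm" checkpoints get the extra norms. A minimal host-side sketch of that gating, using a plain PyTorch RMSNorm stand-in instead of the `DistributedNorm`/`RMSNorm` wrappers above (`SimpleRMSNorm`, `build_optional_ff_norms`, and the `state_dict` layout are illustrative assumptions, not the actual device code):

```python
import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    """Plain RMSNorm stand-in for the ttnn-backed norm used in the diff."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


def build_optional_ff_norms(state_dict: dict, layer_num: int, dim: int):
    """Instantiate pre/post feedforward norms only if the checkpoint has them."""
    pre_key = f"layers.{layer_num}.pre_feedforward_layernorm.weight"
    post_key = f"layers.{layer_num}.post_feedforward_layernorm.weight"

    pre_ff_norm = SimpleRMSNorm(dim) if pre_key in state_dict else None
    post_ff_norm = SimpleRMSNorm(dim) if post_key in state_dict else None
    return pre_ff_norm, post_ff_norm
```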
@@ -116,6 +163,7 @@ def forward(
         kv_cache=None,
     ) -> ttnn.Tensor:
         TG = self.args.is_galaxy
+        residual = x
         # x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode)
         skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
         assert (
@@ -124,36 +172,53 @@ def forward(
         # Norms take fractured inputs and output replicated across devices
         attn_in = self.attention_norm(x, mode)
         # Attention takes replicated inputs and produces fractured outputs
+        if self.attention.is_sliding:
+            position_embeddings = rot_mats[1]
+        else:
+            position_embeddings = rot_mats[0]
+
         attn_out = self.attention.forward(
             attn_in,
             current_pos,
-            rot_mats,
+            position_embeddings,
             user_id,
             mode,
             page_table=page_table,
             chunk_page_table=chunk_page_table,
             chunk_start_idx=chunk_start_idx,
             kv_cache=kv_cache,
         )
-        # Here x and attn_out are both fractured across devices
-        h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None)
-        ttnn.deallocate(attn_out)
+        if self.pre_ff_norm is None:
+            attn_out = ttnn.add(x, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None)
+
+            residual = attn_out
+
+        hidden_states = self.ff_norm(attn_out, mode)
+        if self.pre_ff_norm is not None:
+            hidden_states = ttnn.add(hidden_states, residual, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16)
+
+            residual = hidden_states
+
+            hidden_states = self.pre_ff_norm(hidden_states, mode)
+
         if mode == "prefill":
             x.deallocate(True)

-        # Norms take fractured inputs and output replicated across devices
-        ff_in = self.ff_norm(h, mode)
+        # ttnn.deallocate(attn_out)
+
         if TG and mode == "decode":
-            ff_in = ttnn.to_memory_config(ff_in, memory_config=self.model_config["MLP_ACT_MEMCFG"])
+            hidden_states = ttnn.to_memory_config(hidden_states, memory_config=self.model_config["MLP_ACT_MEMCFG"])
         # MLP takes replicated inputs and produces fractured outputs
-        ff_out = self.feed_forward.forward(ff_in, mode)
-        # ff_out and h are both fractured across devices
+        hidden_states = self.feed_forward.forward(hidden_states, mode)
         activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype(
             decoder_id=self.layer_num, tensor=TensorGroup.ACTIVATION
         )
+        if self.post_ff_norm is not None:
+            hidden_states = self.post_ff_norm(hidden_states, mode)
+
         out = ttnn.add(
-            h,
-            ff_out,
+            residual,
+            hidden_states,
             memory_config=skip_mem_cfg,
             dtype=self.args.ccl_dtype
             if TG and not self.args.is_distributed_norm(mode)
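Taken together, the forward-pass hunks change two things: sliding-window attention layers now pick up the local rotary tables (`rot_mats[1]`) while full-attention layers keep the global ones (`rot_mats[0]`), and the residual add moves depending on whether the extra norms exist. Without `pre_ff_norm` the residual is added right after attention as before; with it, `ff_norm` acts on the attention output alone, the residual is added afterwards, and `pre_ff_norm`/`post_ff_norm` bracket the MLP. A simplified control-flow sketch, with `attention`, `ff_norm`, `feed_forward`, etc. standing in for the ttnn modules and memory configs omitted (an illustration of the dataflow, not the device implementation):

```python
def decoder_forward_sketch(layer, x, rot_mats, mode):
    """Control-flow sketch of the residual stream after this change."""
    residual = x

    # Sliding-window layers use the local RoPE tables, full-attention layers the global ones.
    position_embeddings = rot_mats[1] if layer.attention.is_sliding else rot_mats[0]

    attn_out = layer.attention(layer.attention_norm(x, mode), position_embeddings, mode)

    if layer.pre_ff_norm is None:
        # Llama-style: residual add right after attention, single post-attention norm.
        attn_out = x + attn_out
        residual = attn_out

    hidden_states = layer.ff_norm(attn_out, mode)

    if layer.pre_ff_norm is not None:
        # Gemma-style sandwich: ff_norm normalizes attn_out alone, then the residual
        # is added, and pre_ff_norm re-normalizes before the MLP.
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = layer.pre_ff_norm(hidden_states, mode)

    hidden_states = layer.feed_forward(hidden_states, mode)

    if layer.post_ff_norm is not None:
        hidden_states = layer.post_ff_norm(hidden_states, mode)

    return residual + hidden_states
```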