@@ -245,14 +245,27 @@ def test_tt_model_acc(
245245 theta = model_args .rope_theta ,
246246 rope_scaling = model_args .rope_scaling ,
247247 )
248+
249+ if model_args .rope_local_theta is not None :
250+ # If local theta is set, use it to compute the local rope matrices
251+ rot_mats_local = get_rot_mats (
252+ head_dim = model_args .head_dim ,
253+ device = mesh_device ,
254+ seq_len = prefill_lens [0 ],
255+ theta = model_args .rope_local_theta ,
256+ rope_scaling = None ,
257+ )
258+ else :
259+ rot_mats_local = None
260+
248261 prefill_input = model_args .prepare_residual_tensor_prefill (
249262 pt_prefill_input [batch_id ],
250263 )
251264
252265 tt_out = tt_model (
253266 prefill_input ,
254267 current_pos = None ,
255- rot_mats = rot_mats_prefill ,
268+ rot_mats = [ rot_mats_prefill , rot_mats_local ] ,
256269 user_id = batch_id ,
257270 mode = "prefill" ,
258271 page_table = page_table_tt ,
@@ -280,7 +293,7 @@ def test_tt_model_acc(
280293
281294 # Get cos/sin matrices for the current position of each user
282295 rot_mats = tt_model .rope_setup .get_rot_mats (current_pos )
283-
296+ rot_mats_local = None if tt_model . rope_setup_local is None else tt_model . rope_setup . get_rot_mats ( current_pos )
284297 # Print table header
285298 if use_reference_file :
286299 logger .info (f"{ 'Progress' :<15} { 'Correct' :<8} { 'True' :<15} { 'Actual' :<15} { 'Top 5 Predictions' :<75} " )
@@ -310,7 +323,7 @@ def test_tt_model_acc(
310323 tt_out = tt_model (
311324 decode_input ,
312325 current_pos_tensor ,
313- rot_mats = rot_mats ,
326+ rot_mats = [ rot_mats , rot_mats_local ] ,
314327 mode = "decode" ,
315328 page_table = page_table_tt ,
316329 )
@@ -351,7 +364,9 @@ def test_tt_model_acc(
351364 # Update rot_mats for next iteration
352365 current_pos += 1
353366 rot_mats = tt_model .rope_setup .get_rot_mats (current_pos )
354-
367+ rot_mats_local = (
368+ tt_model .rope_setup_local .get_rot_mats (current_pos ) if tt_model .rope_setup_local is not None else None
369+ )
355370 # Modify the accuracy checking section when using reference text
356371 if not use_reference_file :
357372 # Get probabilities from model output
0 commit comments