From fa44d16f17777908f1a5b7bfc2ba277aeea46c29 Mon Sep 17 00:00:00 2001 From: chanhopark1 Date: Wed, 3 Dec 2025 23:51:13 +0000 Subject: [PATCH 1/8] Fix patch for lm eval --- vllm/platforms/tt.py | 5 +++ vllm/worker/tt_model_runner.py | 62 +++++++++++++++++++++++++++++++--- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/vllm/platforms/tt.py b/vllm/platforms/tt.py index 9e31d661224c..7ea41d964013 100644 --- a/vllm/platforms/tt.py +++ b/vllm/platforms/tt.py @@ -134,6 +134,7 @@ def check_tt_model_supported(model): "openai/gpt-oss-20b", "openai/gpt-oss-120b", "deepseek-ai/DeepSeek-R1-0528", + "Qwen/Qwen3-30B-A3B", "Qwen/Qwen3-235B-A22B", ] assert model in supported_models, ( @@ -189,6 +190,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if not arch_names[i].startswith("TT"): arch_names[i] = "TT" + arch_names[i] + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 128 + # Setting attributes on the class level is kind of hacky, but # it's the only way to make validate_request depend on vllm_config # This is needed to catch incompatible requests early enough diff --git a/vllm/worker/tt_model_runner.py b/vllm/worker/tt_model_runner.py index 4dfb562c392f..b4a5ad33710f 100644 --- a/vllm/worker/tt_model_runner.py +++ b/vllm/worker/tt_model_runner.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +import math from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Set, Type, Union @@ -179,8 +180,8 @@ def __init__( self.cached_req_data: Dict[int, Dict[str, Any]] = {} self.previous_seq_ids: Set[int] = set() - vocab_size = self.model_config.get_vocab_size() - self.logits_processor = LogitsProcessor(vocab_size, + self.vocab_size = self.model_config.get_vocab_size() + self.logits_processor = LogitsProcessor(self.vocab_size, logits_as_input=True) # We are relying on having our logits shaped correctly, # as if they came from a regular vLLM model @@ -832,6 +833,51 @@ def _make_sampler_output( CompletionSequenceGroupOutput(seq_outputs, None)) return SamplerOutput(sampler_outputs) + def _extract_last_token_logits(self, logits_tensor: torch.Tensor, + unpadded_batch_size: int) -> torch.Tensor: + vocab_size = self.vocab_size + + if logits_tensor.ndim == 3: + return logits_tensor[:unpadded_batch_size, -1, :] + + if logits_tensor.ndim == 2: + last_dim = logits_tensor.shape[-1] + if last_dim == vocab_size: + return logits_tensor[:unpadded_batch_size, :] + if last_dim % vocab_size == 0: + seq_len = last_dim // vocab_size + reshaped = logits_tensor.view(logits_tensor.shape[0], seq_len, vocab_size) + return reshaped[:unpadded_batch_size, -1, :] + return logits_tensor[:unpadded_batch_size, :] + + if logits_tensor.ndim == 1: + reshaped = self._reshape_flat_logits(logits_tensor, unpadded_batch_size) + return reshaped[:unpadded_batch_size, -1, :] + + raise ValueError(f"{tuple(logits_tensor.shape)=}") + + def _reshape_flat_logits(self, logits_tensor: torch.Tensor, + unpadded_batch_size: int) -> torch.Tensor: + vocab_size = self.vocab_size + total = logits_tensor.numel() + remainder = total % vocab_size + if remainder != 0: + pad = vocab_size - remainder + logits_tensor = F.pad(logits_tensor, (0, pad)) + total += pad + + seq_prod = total // vocab_size + if seq_prod == 0: + return logits_tensor.view(0, 1, vocab_size) + + reshaped = logits_tensor.view(seq_prod, 1, vocab_size) + + if seq_prod < unpadded_batch_size: + repeat_factor = math.ceil(unpadded_batch_size / seq_prod) + reshaped = reshaped.repeat(repeat_factor, 1, 1) + + return reshaped + def _send_prev_step_async_out(self, model_input: TTModelInput, step_idx): if step_idx > 0: step_output = self.cached_step_outputs.pop(0) @@ -1072,10 +1118,15 @@ def _execute_model_single_step(self, if self.dp_kv_cache: tt_out = tt_out[perm_table_tensor] + def _unwrap_logits(output): + if isinstance(output, tuple): + return output[0] + return output + if model_input.compat_sampling_used: # compat sampling is only supported on host - tt_logits = tt_out[:model_input.unpadded_batch_size, - -1, :] # [unpadded batch, vocab] + tt_logits = self._extract_last_token_logits( + _unwrap_logits(tt_out), model_input.unpadded_batch_size) #This is coincidentally the same shape as the logits # we would get from a regular vllm model, # assuming we have no prompt logprobs, and one sequence per group. @@ -1095,7 +1146,8 @@ def _execute_model_single_step(self, else: if not model_input.perform_device_sampling: # unpadded batch, vocab of last token - next_logits = tt_out[:model_input.unpadded_batch_size, -1, :] + next_logits = self._extract_last_token_logits( + _unwrap_logits(tt_out), model_input.unpadded_batch_size) assert model_input.tt_sampling_params is not None assert isinstance( model_input.tt_sampling_params, TTSamplingParams), ( From 8b518ec64d43a707cb7a87df6ca243958a914e49 Mon Sep 17 00:00:00 2001 From: chanhopark1 Date: Fri, 5 Dec 2025 00:04:10 +0000 Subject: [PATCH 2/8] Fix not forcing 1D_RING --- vllm/worker/tt_worker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/worker/tt_worker.py b/vllm/worker/tt_worker.py index 8a0600eba90d..4fad34459a2f 100644 --- a/vllm/worker/tt_worker.py +++ b/vllm/worker/tt_worker.py @@ -541,6 +541,7 @@ def set_fabric(override_tt_config, num_devices): logger.info("Setting fabric config: %s, reliability mode: %s", fabric_config, reliability_mode) ttnn.set_fabric_config(fabric_config, reliability_mode) + return fabric_config # From tt-metal/conftest.py: @@ -629,14 +630,16 @@ def open_mesh_device(override_tt_config, trace_mode, local_dp_rank=0): # Set fabric before opening the device num_devices_requested = mesh_grid[0] * mesh_grid[1] - set_fabric(override_tt_config, num_devices_requested) + fabric_config = set_fabric(override_tt_config, num_devices_requested) # mesh_device = ttnn.open_mesh_device( # ttnn.MeshShape(*mesh_grid), # dispatch_core_config=get_dispatch_core_config(override_tt_config), # **device_params, # ) - device_params = {"trace_region_size": 95449088, "fabric_config": ttnn.FabricConfig.FABRIC_1D_RING} + device_params = {"trace_region_size": 95449088} + if fabric_config: + device_params["fabric_config"] = fabric_config mesh_device = create_mesh_device(device_params) # set_and_get_device_cache(mesh_device) # logger.info("multidevice with %d devices and grid %s is created", From dfd58719e586f0185eadf05e14d061a071172c2d Mon Sep 17 00:00:00 2001 From: chanhopark1 Date: Wed, 31 Dec 2025 01:04:51 +0000 Subject: [PATCH 3/8] Fix removing reordering --- install_vllm_tt.sh | 4 +- vllm/worker/tt_model_runner.py | 114 ++++++++++++++++++++++++++++++++- 2 files changed, 115 insertions(+), 3 deletions(-) diff --git a/install_vllm_tt.sh b/install_vllm_tt.sh index 8a7892c51617..4b1b9259d524 100755 --- a/install_vllm_tt.sh +++ b/install_vllm_tt.sh @@ -1,8 +1,8 @@ #!/bin/bash -source /home/jungwook/tt-metal_moreh/python_env/bin/activate +#source /home/jungwook/tt-metal_moreh/python_env/bin/activate export VLLM_TARGET_DEVICE="tt" pip uninstall -y vllm -pip install -e . --extra-index-url https://download.pytorch.org/whl/cpu \ No newline at end of file +pip install -e . --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/vllm/worker/tt_model_runner.py b/vllm/worker/tt_model_runner.py index b4a5ad33710f..49a19c583de7 100644 --- a/vllm/worker/tt_model_runner.py +++ b/vllm/worker/tt_model_runner.py @@ -1031,6 +1031,107 @@ def _execute_model_single_step(self, else: enc_dec_kwargs = {} + reorder_rows_for_current_sequences = None + if (not self.dp_kv_cache + and self.prefill_sequence_order is not None + and execute_model_kwargs["tokens"].shape[0] == + len(self.prefill_sequence_order)): + fixed_row_for_sequence_id = { + sequence_id: row_index + for row_index, sequence_id in enumerate( + self.prefill_sequence_order) + } + missing_sequence_ids = [ + sequence_id for sequence_id in model_input.seq_groups + if sequence_id not in fixed_row_for_sequence_id + ] + if not missing_sequence_ids: + sequence_id_to_current_row = { + sequence_id: row_index + for row_index, sequence_id in enumerate( + model_input.seq_groups) + } + reordered_tokens = torch.zeros_like( + execute_model_kwargs["tokens"]) + reordered_positions = torch.zeros_like( + execute_model_kwargs["start_pos"]) + reordered_page_table = torch.zeros_like( + execute_model_kwargs["page_table"]) + + for sequence_id, fixed_row_index in fixed_row_for_sequence_id.items( + ): + current_row_index = sequence_id_to_current_row.get( + sequence_id) + if current_row_index is None: + continue + reordered_tokens[fixed_row_index] = \ + execute_model_kwargs["tokens"][current_row_index] + reordered_positions[fixed_row_index] = \ + execute_model_kwargs["start_pos"][current_row_index] + reordered_page_table[fixed_row_index] = \ + execute_model_kwargs["page_table"][current_row_index] + + execute_model_kwargs["tokens"] = reordered_tokens + execute_model_kwargs["start_pos"] = reordered_positions + execute_model_kwargs["page_table"] = reordered_page_table + + if ("sampling_params" in execute_model_kwargs + and execute_model_kwargs["sampling_params"] + is not None): + sampling_params = execute_model_kwargs[ + "sampling_params"] + if isinstance(sampling_params.temperature, list): + reordered_temperature = [ + PADDING_TEMPERATURE + ] * len(self.prefill_sequence_order) + reordered_top_k = [ + PADDING_TOP_K + ] * len(self.prefill_sequence_order) + reordered_top_p = [ + PADDING_TOP_P + ] * len(self.prefill_sequence_order) + for sequence_id, fixed_row_index in \ + fixed_row_for_sequence_id.items(): + current_row_index = sequence_id_to_current_row.get( + sequence_id) + if current_row_index is None: + continue + reordered_temperature[ + fixed_row_index] = sampling_params.temperature[ + current_row_index] + reordered_top_k[fixed_row_index] = \ + sampling_params.top_k[current_row_index] + reordered_top_p[fixed_row_index] = \ + sampling_params.top_p[current_row_index] + execute_model_kwargs[ + "sampling_params"] = TTSamplingParams( + temperature=reordered_temperature, + top_k=reordered_top_k, + top_p=reordered_top_p) + + reorder_rows_for_current_sequences = [ + fixed_row_for_sequence_id[sequence_id] + for sequence_id in model_input.seq_groups + ] + + debug_decode = self._should_log_page_table("decode") + if debug_decode: + block_tables = execute_model_kwargs["page_table"] + active_slots = None + if self.dp_kv_cache: + active_slots = [ + self.seq_groups_to_batch_slot[s] + for s in model_input.seq_groups + ] + logger.warning( + "PAGE_TABLE_DEBUG decode-pre batch=%s block_table_shape=%s seq_groups_sample=%s active_slots_sample=%s", + block_tables.shape[0], + tuple(block_tables.shape), + model_input.seq_groups[:DEBUG_PAGE_TABLE_ROWS], + active_slots[:DEBUG_PAGE_TABLE_ROWS] if active_slots is not None else None, + ) + self._log_page_table_rows("decode-pre", block_tables, model_input.seq_groups, active_slots) + if self.dp_kv_cache: # Calculate perm_table_tensor: # perm_table_tensor[new_idx] = current_slot_idx @@ -1116,7 +1217,18 @@ def _execute_model_single_step(self, tt_out = self.model.process_decode_output_host( tt_out, is_tokens=model_input.perform_device_sampling) if self.dp_kv_cache: - tt_out = tt_out[perm_table_tensor] + if isinstance(tt_out, tuple): + reordered_primary = tt_out[0][perm_table_tensor] + tt_out = (reordered_primary, ) + tt_out[1:] + else: + tt_out = tt_out[perm_table_tensor] + elif reorder_rows_for_current_sequences is not None: + if isinstance(tt_out, tuple): + reordered_primary = tt_out[0][ + reorder_rows_for_current_sequences] + tt_out = (reordered_primary, ) + tt_out[1:] + else: + tt_out = tt_out[reorder_rows_for_current_sequences] def _unwrap_logits(output): if isinstance(output, tuple): From 9703783c4a54cb675f4776ad2f99d7481586c691 Mon Sep 17 00:00:00 2001 From: chanhopark1 Date: Wed, 31 Dec 2025 01:10:57 +0000 Subject: [PATCH 4/8] Fix intializing value --- vllm/worker/tt_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/tt_model_runner.py b/vllm/worker/tt_model_runner.py index 49a19c583de7..bb782a729165 100644 --- a/vllm/worker/tt_model_runner.py +++ b/vllm/worker/tt_model_runner.py @@ -190,6 +190,7 @@ def __init__( # we need to fully match the relevant parts of # SamplingMetadata.selected_token_indices logic. self.sampler = get_sampler() + self.prefill_sequence_order: Optional[List[int]] = None def load_model(self) -> None: # Note: using custom TT loader From 465ce4f52c203ee1c696347c99508625de63ca3d Mon Sep 17 00:00:00 2001 From: chanhopark1 Date: Wed, 31 Dec 2025 01:16:50 +0000 Subject: [PATCH 5/8] Fix removing debug --- vllm/worker/tt_model_runner.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/vllm/worker/tt_model_runner.py b/vllm/worker/tt_model_runner.py index bb782a729165..93f51d24008c 100644 --- a/vllm/worker/tt_model_runner.py +++ b/vllm/worker/tt_model_runner.py @@ -1115,23 +1115,6 @@ def _execute_model_single_step(self, for sequence_id in model_input.seq_groups ] - debug_decode = self._should_log_page_table("decode") - if debug_decode: - block_tables = execute_model_kwargs["page_table"] - active_slots = None - if self.dp_kv_cache: - active_slots = [ - self.seq_groups_to_batch_slot[s] - for s in model_input.seq_groups - ] - logger.warning( - "PAGE_TABLE_DEBUG decode-pre batch=%s block_table_shape=%s seq_groups_sample=%s active_slots_sample=%s", - block_tables.shape[0], - tuple(block_tables.shape), - model_input.seq_groups[:DEBUG_PAGE_TABLE_ROWS], - active_slots[:DEBUG_PAGE_TABLE_ROWS] if active_slots is not None else None, - ) - self._log_page_table_rows("decode-pre", block_tables, model_input.seq_groups, active_slots) if self.dp_kv_cache: # Calculate perm_table_tensor: From eba796e845ba074f81e82e021935aebe7a350db3 Mon Sep 17 00:00:00 2001 From: chanhopark1 Date: Wed, 31 Dec 2025 01:25:11 +0000 Subject: [PATCH 6/8] Fix reverting vllm output --- vllm/worker/tt_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/worker/tt_model_runner.py b/vllm/worker/tt_model_runner.py index 93f51d24008c..dab86f81d5c4 100644 --- a/vllm/worker/tt_model_runner.py +++ b/vllm/worker/tt_model_runner.py @@ -930,6 +930,9 @@ def _execute_model_single_step(self, int), ("unpadded_batch_size must be an int") if not is_decode: + if not self.dp_kv_cache: + self.prefill_sequence_order = list(model_input.seq_groups) + if self.dp_kv_cache: slots_to_allocate = self.empty_slots[:model_input. unpadded_batch_size] @@ -986,6 +989,7 @@ def _execute_model_single_step(self, if self.model_config.is_encoder_decoder: assert self.cached_req_data + # Use encoder-decoder data from prefill step prefill_cross_attention_masks = [ self.cached_req_data[seq_id] From a40f16ab38d6930edba24d56d0ec7bfe7c73f1c2 Mon Sep 17 00:00:00 2001 From: namhyeong-kim <141107133+namhyeong-kim@users.noreply.github.com> Date: Thu, 15 Jan 2026 10:53:25 +0900 Subject: [PATCH 7/8] Merge pull request #7 from moreh-dev/namhyeong/LL namhyeong/LL --- vllm/worker/tt_device.py | 2 +- vllm/worker/tt_worker.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tt_device.py b/vllm/worker/tt_device.py index d8150b69ed83..893e9dab9269 100644 --- a/vllm/worker/tt_device.py +++ b/vllm/worker/tt_device.py @@ -84,7 +84,7 @@ def create_mesh_device(device_params: Optional[Dict] = None): updated_device_params = get_updated_device_params(params) device_ids = ttnn.get_device_ids() - + env_mesh_shape = parse_mesh_shape_from_env() if env_mesh_shape: default_mesh_shape = env_mesh_shape diff --git a/vllm/worker/tt_worker.py b/vllm/worker/tt_worker.py index 4fad34459a2f..afa41e097d7b 100644 --- a/vllm/worker/tt_worker.py +++ b/vllm/worker/tt_worker.py @@ -505,6 +505,8 @@ def get_fabric_config(override_tt_config, num_devices): "FABRIC_1D_RING": ttnn.FabricConfig.FABRIC_1D_RING, "FABRIC_2D": ttnn.FabricConfig.FABRIC_2D, "CUSTOM": ttnn.FabricConfig.CUSTOM, + "FABRIC_2D_DYNAMIC": ttnn.FabricConfig.FABRIC_2D_DYNAMIC, + "FABRIC_2D_DYNAMIC_TORUS_XY": ttnn.FabricConfig.FABRIC_2D_DYNAMIC_TORUS_XY, } fabric_config = fabric_config_map.get(fabric_config_str) assert fabric_config is not None, ( @@ -620,6 +622,25 @@ def get_mesh_grid(local_dp_rank=0): return mesh_grid +def get_dispatch_core_axis(override_tt_config): + dispatch_core_axis: ttnn.DispatchCoreAxis = ttnn.DispatchCoreAxis.ROW + + if override_tt_config is None: + return dispatch_core_axis + + dispatch_core_axis_config = override_tt_config.get("dispatch_core_axis", None) + + if dispatch_core_axis_config is None: + return dispatch_core_axis + + assert dispatch_core_axis_config in ["row", "col"], ( + f"Invalid dispatch_core_axis: {dispatch_core_axis_config}. " + "Expected: row, col.") + dispatch_core_axis = (ttnn.DispatchCoreAxis.COL + if dispatch_core_axis_config == "col" + else ttnn.DispatchCoreAxis.ROW) + return dispatch_core_axis + def open_mesh_device(override_tt_config, trace_mode, local_dp_rank=0): assert local_dp_rank == 0, "open_mesh_device must run on local DP rank 0" mesh_grid = get_mesh_grid(local_dp_rank) @@ -637,9 +658,12 @@ def open_mesh_device(override_tt_config, trace_mode, local_dp_rank=0): # dispatch_core_config=get_dispatch_core_config(override_tt_config), # **device_params, # ) + device_params = {"trace_region_size": 95449088} if fabric_config: device_params["fabric_config"] = fabric_config + device_params["dispatch_core_axis"] = get_dispatch_core_axis(override_tt_config) + mesh_device = create_mesh_device(device_params) # set_and_get_device_cache(mesh_device) # logger.info("multidevice with %d devices and grid %s is created", From 5ea6f7194089f92d9da03961681c64d833da98ae Mon Sep 17 00:00:00 2001 From: Heehoon Kim Date: Wed, 21 Jan 2026 03:26:16 +0000 Subject: [PATCH 8/8] Reflect removal of DYNAMIC in FabricConfig --- vllm/worker/tt_worker.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/worker/tt_worker.py b/vllm/worker/tt_worker.py index afa41e097d7b..8aa11a62f12c 100644 --- a/vllm/worker/tt_worker.py +++ b/vllm/worker/tt_worker.py @@ -504,9 +504,8 @@ def get_fabric_config(override_tt_config, num_devices): "FABRIC_1D": ttnn.FabricConfig.FABRIC_1D, "FABRIC_1D_RING": ttnn.FabricConfig.FABRIC_1D_RING, "FABRIC_2D": ttnn.FabricConfig.FABRIC_2D, + "FABRIC_2D_TORUS_XY": ttnn.FabricConfig.FABRIC_2D_TORUS_XY, "CUSTOM": ttnn.FabricConfig.CUSTOM, - "FABRIC_2D_DYNAMIC": ttnn.FabricConfig.FABRIC_2D_DYNAMIC, - "FABRIC_2D_DYNAMIC_TORUS_XY": ttnn.FabricConfig.FABRIC_2D_DYNAMIC_TORUS_XY, } fabric_config = fabric_config_map.get(fabric_config_str) assert fabric_config is not None, (