4 changes: 2 additions & 2 deletions install_vllm_tt.sh
@@ -1,8 +1,8 @@
#!/bin/bash

source /home/jungwook/tt-metal_moreh/python_env/bin/activate
#source /home/jungwook/tt-metal_moreh/python_env/bin/activate

export VLLM_TARGET_DEVICE="tt"

pip uninstall -y vllm
pip install -e . --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e . --extra-index-url https://download.pytorch.org/whl/cpu
5 changes: 5 additions & 0 deletions vllm/platforms/tt.py
@@ -134,6 +134,7 @@ def check_tt_model_supported(model):
"openai/gpt-oss-20b",
"openai/gpt-oss-120b",
"deepseek-ai/DeepSeek-R1-0528",
"Qwen/Qwen3-30B-A3B",
"Qwen/Qwen3-235B-A22B",
]
assert model in supported_models, (
@@ -189,6 +190,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if not arch_names[i].startswith("TT"):
arch_names[i] = "TT" + arch_names[i]

cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
cache_config.block_size = 128

# Setting attributes on the class level is kind of hacky, but
# it's the only way to make validate_request depend on vllm_config
# This is needed to catch incompatible requests early enough
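Note: the new cache_config hunk only fills in a default when vLLM leaves block_size unset, so an explicit user setting still wins. A minimal sketch of that behaviour, using a simplified stand-in for the real CacheConfig (the dataclass below is an assumption, not vLLM's actual class):

from dataclasses import dataclass
from typing import Optional

@dataclass
class CacheConfig:
    block_size: Optional[int] = None  # may be left unset on the TT platform

def apply_tt_block_size_default(cache_config: Optional[CacheConfig]) -> Optional[CacheConfig]:
    # Mirrors the hunk above: default to 128 only when nothing was configured.
    if cache_config and cache_config.block_size is None:
        cache_config.block_size = 128
    return cache_config

assert apply_tt_block_size_default(CacheConfig()).block_size == 128
assert apply_tt_block_size_default(CacheConfig(block_size=64)).block_size == 64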
2 changes: 1 addition & 1 deletion vllm/worker/tt_device.py
@@ -4,7 +4,7 @@
# from tests.scripts.common import get_updated_device_params
import ttnn


Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/worker/tt_device.py:7:81: Line too long (124 > 80)
# Device helpers to be shared with top-level conftest.py and other conftest.py files that will handle open/close of devices.


@@ -17,18 +17,18 @@
dispatch_core_type = new_device_params.pop("dispatch_core_type", None)
fabric_tensix_config = new_device_params.get("fabric_tensix_config", None)

if ttnn.device.is_blackhole():

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/worker/tt_device.py:20:81: Line too long (133 > 80)
# Only when both fabric_config and fabric_tensix_config are set, we can use ROW dispatch, otherwise force to use COL dispatch
fabric_config = new_device_params.get("fabric_config", None)
if not (fabric_config and fabric_tensix_config):
# When not both are set, force COL dispatch
if dispatch_core_axis == ttnn.DispatchCoreAxis.ROW:
logger.warning(

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/worker/tt_device.py:26:81: Line too long (110 > 80)
"ROW dispatch requires both fabric and tensix config, using DispatchCoreAxis.COL instead."
)
dispatch_core_axis = ttnn.DispatchCoreAxis.COL
elif fabric_config and fabric_tensix_config:
logger.warning(

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/worker/tt_device.py:31:81: Line too long (131 > 80)
f"Blackhole with fabric_config and fabric_tensix_config enabled, using fabric_tensix_config={fabric_tensix_config}"
)

@@ -58,13 +58,13 @@
except (ValueError, AttributeError) as e:
raise ValueError(f"Invalid TT_MESHDEVICE_SHAPE format '{env_shape}': expected 'rows,cols' (e.g., '2,4')")


def create_mesh_device(device_params: Optional[Dict] = None):

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/worker/tt_device.py:62:81: Line too long (99 > 80)
"""Create mesh device with appropriate mesh shape based on available devices.

Check failure (GitHub Actions / pre-commit, Ruff B904): vllm/worker/tt_device.py:61:9: Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling

Mesh shape selection:
- TT_MESHDEVICE_SHAPE env var (e.g., "2,4" or "1,8") if set
- Galaxy (32 devices): 4x8

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/worker/tt_device.py:67:81: Line too long (81 > 80)
- T3000 (8+ devices): 2x4
- Single/Few devices: 1x{num_devices}

@@ -73,7 +73,7 @@

Returns:
Initialized mesh device
"""

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/worker/tt_device.py:76:81: Line too long (90 > 80)
global _printed_device_info

params = dict(device_params or {})
@@ -84,7 +84,7 @@

updated_device_params = get_updated_device_params(params)
device_ids = ttnn.get_device_ids()

env_mesh_shape = parse_mesh_shape_from_env()
if env_mesh_shape:
default_mesh_shape = env_mesh_shape
@@ -111,7 +111,7 @@
if "trace_region_size" in params:
trace_mb = params["trace_region_size"] // (1024 * 1024)
print(f" Trace Region Size: {trace_mb}MB")
print(f"=" * 60)

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/worker/tt_device.py:114:81: Line too long (84 > 80)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created with shape {mesh_device.shape}")

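Note: the docstring above spells out the mesh-shape selection order (TT_MESHDEVICE_SHAPE env var first, then Galaxy 4x8, T3000 2x4, otherwise 1 x num_devices). A standalone sketch of the env-var parsing, not the PR's exact code (parse_shape is a hypothetical name), which also shows the "raise ... from err" form the B904 warning above asks for:

import os
from typing import Optional, Tuple

def parse_shape(env_var: str = "TT_MESHDEVICE_SHAPE") -> Optional[Tuple[int, int]]:
    raw = os.environ.get(env_var)
    if not raw:
        return None
    try:
        rows, cols = (int(part) for part in raw.split(","))
    except ValueError as err:
        # Chaining with "from err" keeps the original parse error visible.
        raise ValueError(
            f"Invalid {env_var} format '{raw}': expected 'rows,cols' (e.g., '2,4')"
        ) from err
    return rows, cols

# TT_MESHDEVICE_SHAPE="2,4" -> (2, 4); unset -> None, so the device-count
# heuristics (32 devices -> 4x8, 8+ -> 2x4, otherwise 1 x num_devices) apply.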
164 changes: 158 additions & 6 deletions vllm/worker/tt_model_runner.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

@@ -179,8 +180,8 @@ def __init__(
self.cached_req_data: Dict[int, Dict[str, Any]] = {}
self.previous_seq_ids: Set[int] = set()

vocab_size = self.model_config.get_vocab_size()
self.logits_processor = LogitsProcessor(vocab_size,
self.vocab_size = self.model_config.get_vocab_size()
self.logits_processor = LogitsProcessor(self.vocab_size,
logits_as_input=True)
# We are relying on having our logits shaped correctly,
# as if they came from a regular vLLM model
@@ -189,6 +190,7 @@ def __init__(
# we need to fully match the relevant parts of
# SamplingMetadata.selected_token_indices logic.
self.sampler = get_sampler()
self.prefill_sequence_order: Optional[List[int]] = None

def load_model(self) -> None:
# Note: using custom TT loader
@@ -832,6 +834,51 @@ def _make_sampler_output(
CompletionSequenceGroupOutput(seq_outputs, None))
return SamplerOutput(sampler_outputs)

def _extract_last_token_logits(self, logits_tensor: torch.Tensor,
unpadded_batch_size: int) -> torch.Tensor:
vocab_size = self.vocab_size

if logits_tensor.ndim == 3:
return logits_tensor[:unpadded_batch_size, -1, :]

if logits_tensor.ndim == 2:
last_dim = logits_tensor.shape[-1]
if last_dim == vocab_size:
return logits_tensor[:unpadded_batch_size, :]
if last_dim % vocab_size == 0:
seq_len = last_dim // vocab_size
reshaped = logits_tensor.view(logits_tensor.shape[0], seq_len, vocab_size)
return reshaped[:unpadded_batch_size, -1, :]
return logits_tensor[:unpadded_batch_size, :]

if logits_tensor.ndim == 1:
reshaped = self._reshape_flat_logits(logits_tensor, unpadded_batch_size)
return reshaped[:unpadded_batch_size, -1, :]

raise ValueError(f"{tuple(logits_tensor.shape)=}")

def _reshape_flat_logits(self, logits_tensor: torch.Tensor,
unpadded_batch_size: int) -> torch.Tensor:
vocab_size = self.vocab_size
total = logits_tensor.numel()
remainder = total % vocab_size
if remainder != 0:
pad = vocab_size - remainder
logits_tensor = F.pad(logits_tensor, (0, pad))
total += pad

seq_prod = total // vocab_size
if seq_prod == 0:
return logits_tensor.view(0, 1, vocab_size)

reshaped = logits_tensor.view(seq_prod, 1, vocab_size)

if seq_prod < unpadded_batch_size:
repeat_factor = math.ceil(unpadded_batch_size / seq_prod)
reshaped = reshaped.repeat(repeat_factor, 1, 1)

return reshaped

def _send_prev_step_async_out(self, model_input: TTModelInput, step_idx):
if step_idx > 0:
step_output = self.cached_step_outputs.pop(0)
@@ -883,6 +930,9 @@ def _execute_model_single_step(self,
int), ("unpadded_batch_size must be an int")

if not is_decode:
if not self.dp_kv_cache:
self.prefill_sequence_order = list(model_input.seq_groups)

if self.dp_kv_cache:
slots_to_allocate = self.empty_slots[:model_input.
unpadded_batch_size]
@@ -939,6 +989,7 @@ def _execute_model_single_step(self,
if self.model_config.is_encoder_decoder:
assert self.cached_req_data


# Use encoder-decoder data from prefill step
prefill_cross_attention_masks = [
self.cached_req_data[seq_id]
@@ -985,6 +1036,90 @@ def _execute_model_single_step(self,
else:
enc_dec_kwargs = {}

reorder_rows_for_current_sequences = None
if (not self.dp_kv_cache
and self.prefill_sequence_order is not None
and execute_model_kwargs["tokens"].shape[0] ==
len(self.prefill_sequence_order)):
fixed_row_for_sequence_id = {
sequence_id: row_index
for row_index, sequence_id in enumerate(
self.prefill_sequence_order)
}
missing_sequence_ids = [
sequence_id for sequence_id in model_input.seq_groups
if sequence_id not in fixed_row_for_sequence_id
]
if not missing_sequence_ids:
sequence_id_to_current_row = {
sequence_id: row_index
for row_index, sequence_id in enumerate(
model_input.seq_groups)
}
reordered_tokens = torch.zeros_like(
execute_model_kwargs["tokens"])
reordered_positions = torch.zeros_like(
execute_model_kwargs["start_pos"])
reordered_page_table = torch.zeros_like(
execute_model_kwargs["page_table"])

for sequence_id, fixed_row_index in fixed_row_for_sequence_id.items(
):
current_row_index = sequence_id_to_current_row.get(
sequence_id)
if current_row_index is None:
continue
reordered_tokens[fixed_row_index] = \
execute_model_kwargs["tokens"][current_row_index]
reordered_positions[fixed_row_index] = \
execute_model_kwargs["start_pos"][current_row_index]
reordered_page_table[fixed_row_index] = \
execute_model_kwargs["page_table"][current_row_index]

execute_model_kwargs["tokens"] = reordered_tokens
execute_model_kwargs["start_pos"] = reordered_positions
execute_model_kwargs["page_table"] = reordered_page_table

if ("sampling_params" in execute_model_kwargs
and execute_model_kwargs["sampling_params"]
is not None):
sampling_params = execute_model_kwargs[
"sampling_params"]
if isinstance(sampling_params.temperature, list):
reordered_temperature = [
PADDING_TEMPERATURE
] * len(self.prefill_sequence_order)
reordered_top_k = [
PADDING_TOP_K
] * len(self.prefill_sequence_order)
reordered_top_p = [
PADDING_TOP_P
] * len(self.prefill_sequence_order)
for sequence_id, fixed_row_index in \
fixed_row_for_sequence_id.items():
current_row_index = sequence_id_to_current_row.get(
sequence_id)
if current_row_index is None:
continue
reordered_temperature[
fixed_row_index] = sampling_params.temperature[
current_row_index]
reordered_top_k[fixed_row_index] = \
sampling_params.top_k[current_row_index]
reordered_top_p[fixed_row_index] = \
sampling_params.top_p[current_row_index]
execute_model_kwargs[
"sampling_params"] = TTSamplingParams(
temperature=reordered_temperature,
top_k=reordered_top_k,
top_p=reordered_top_p)

reorder_rows_for_current_sequences = [
fixed_row_for_sequence_id[sequence_id]
for sequence_id in model_input.seq_groups
]
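# Reviewer note (added as comments to keep the diff readable):
# fixed_row_for_sequence_id pins each sequence to the row it occupied at
# prefill, so the decode inputs above are permuted into those fixed rows
# before the device call; reorder_rows_for_current_sequences is the map
# used further below to put the device output back into the current
# model_input.seq_groups order.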


if self.dp_kv_cache:
# Calculate perm_table_tensor:
# perm_table_tensor[new_idx] = current_slot_idx
@@ -1070,12 +1205,28 @@ def _execute_model_single_step(self,
tt_out = self.model.process_decode_output_host(
tt_out, is_tokens=model_input.perform_device_sampling)
if self.dp_kv_cache:
tt_out = tt_out[perm_table_tensor]
if isinstance(tt_out, tuple):
reordered_primary = tt_out[0][perm_table_tensor]
tt_out = (reordered_primary, ) + tt_out[1:]
else:
tt_out = tt_out[perm_table_tensor]
elif reorder_rows_for_current_sequences is not None:
if isinstance(tt_out, tuple):
reordered_primary = tt_out[0][
reorder_rows_for_current_sequences]
tt_out = (reordered_primary, ) + tt_out[1:]
else:
tt_out = tt_out[reorder_rows_for_current_sequences]

def _unwrap_logits(output):
if isinstance(output, tuple):
return output[0]
return output

if model_input.compat_sampling_used:
# compat sampling is only supported on host
tt_logits = tt_out[:model_input.unpadded_batch_size,
-1, :] # [unpadded batch, vocab]
tt_logits = self._extract_last_token_logits(
_unwrap_logits(tt_out), model_input.unpadded_batch_size)
#This is coincidentally the same shape as the logits
# we would get from a regular vllm model,
# assuming we have no prompt logprobs, and one sequence per group.
@@ -1095,7 +1246,8 @@
else:
if not model_input.perform_device_sampling:
# unpadded batch, vocab of last token
next_logits = tt_out[:model_input.unpadded_batch_size, -1, :]
next_logits = self._extract_last_token_logits(
_unwrap_logits(tt_out), model_input.unpadded_batch_size)
assert model_input.tt_sampling_params is not None
assert isinstance(
model_input.tt_sampling_params, TTSamplingParams), (
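Note: _extract_last_token_logits normalises whatever shape the device returns into [unpadded batch, vocab] for the last token. A standalone toy sketch of the three cases it handles (vocab_size and the tensors below are made-up values, not taken from the PR):

import torch

vocab_size, unpadded_batch = 8, 2

three_d = torch.randn(4, 5, vocab_size)   # [padded batch, seq, vocab]
two_d = torch.randn(4, 5 * vocab_size)    # [padded batch, seq * vocab]
flat = torch.randn(4 * vocab_size)        # fully flattened

# 3-D: slice the last position of each unpadded row.
assert three_d[:unpadded_batch, -1, :].shape == (unpadded_batch, vocab_size)

# 2-D with last dim a multiple of vocab_size: recover [batch, seq, vocab] first.
reshaped = two_d.view(two_d.shape[0], -1, vocab_size)
assert reshaped[:unpadded_batch, -1, :].shape == (unpadded_batch, vocab_size)

# 1-D: pad to a multiple of vocab_size, view as [rows, 1, vocab], then slice;
# this is what _reshape_flat_logits does before the final [:, -1, :] selection.
rows = flat.numel() // vocab_size
assert flat.view(rows, 1, vocab_size)[:unpadded_batch, -1, :].shape == (unpadded_batch, vocab_size)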
30 changes: 28 additions & 2 deletions vllm/worker/tt_worker.py
@@ -504,6 +504,7 @@ def get_fabric_config(override_tt_config, num_devices):
"FABRIC_1D": ttnn.FabricConfig.FABRIC_1D,
"FABRIC_1D_RING": ttnn.FabricConfig.FABRIC_1D_RING,
"FABRIC_2D": ttnn.FabricConfig.FABRIC_2D,
"FABRIC_2D_TORUS_XY": ttnn.FabricConfig.FABRIC_2D_TORUS_XY,
"CUSTOM": ttnn.FabricConfig.CUSTOM,
}
fabric_config = fabric_config_map.get(fabric_config_str)
@@ -541,6 +542,7 @@ def set_fabric(override_tt_config, num_devices):
logger.info("Setting fabric config: %s, reliability mode: %s",
fabric_config, reliability_mode)
ttnn.set_fabric_config(fabric_config, reliability_mode)
return fabric_config


# From tt-metal/conftest.py:
@@ -619,6 +621,25 @@ def get_mesh_grid(local_dp_rank=0):
return mesh_grid


def get_dispatch_core_axis(override_tt_config):
dispatch_core_axis: ttnn.DispatchCoreAxis = ttnn.DispatchCoreAxis.ROW

if override_tt_config is None:
return dispatch_core_axis

dispatch_core_axis_config = override_tt_config.get("dispatch_core_axis", None)

if dispatch_core_axis_config is None:
return dispatch_core_axis

assert dispatch_core_axis_config in ["row", "col"], (
f"Invalid dispatch_core_axis: {dispatch_core_axis_config}. "
"Expected: row, col.")
dispatch_core_axis = (ttnn.DispatchCoreAxis.COL
if dispatch_core_axis_config == "col"
else ttnn.DispatchCoreAxis.ROW)
return dispatch_core_axis

def open_mesh_device(override_tt_config, trace_mode, local_dp_rank=0):
assert local_dp_rank == 0, "open_mesh_device must run on local DP rank 0"
mesh_grid = get_mesh_grid(local_dp_rank)
@@ -629,14 +650,19 @@ def open_mesh_device(override_tt_config, trace_mode, local_dp_rank=0):

# Set fabric before opening the device
num_devices_requested = mesh_grid[0] * mesh_grid[1]
set_fabric(override_tt_config, num_devices_requested)
fabric_config = set_fabric(override_tt_config, num_devices_requested)

# mesh_device = ttnn.open_mesh_device(
# ttnn.MeshShape(*mesh_grid),
# dispatch_core_config=get_dispatch_core_config(override_tt_config),
# **device_params,
# )
device_params = {"trace_region_size": 95449088, "fabric_config": ttnn.FabricConfig.FABRIC_1D_RING}

device_params = {"trace_region_size": 95449088}
if fabric_config:
device_params["fabric_config"] = fabric_config
device_params["dispatch_core_axis"] = get_dispatch_core_axis(override_tt_config)

mesh_device = create_mesh_device(device_params)
# set_and_get_device_cache(mesh_device)
# logger.info("multidevice with %d devices and grid %s is created",
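Note: with set_fabric now returning the resolved fabric config and get_dispatch_core_axis reading override_tt_config, both knobs flow into device_params instead of being hard-coded. A sketch of what the caller-facing config might look like (key names are taken from the hunks above; everything else is an assumption):

override_tt_config = {
    "fabric_config": "FABRIC_2D_TORUS_XY",  # newly accepted by get_fabric_config
    "dispatch_core_axis": "col",            # "row" (the default) or "col"
}

# open_mesh_device(...) then builds device_params roughly as in the hunk above:
#   device_params = {"trace_region_size": 95449088}
#   if fabric_config:                       # only set when a fabric config was resolved
#       device_params["fabric_config"] = fabric_config
#   device_params["dispatch_core_axis"] = get_dispatch_core_axis(override_tt_config)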