93 changes: 75 additions & 18 deletions tpu_inference/runner/compilation_manager.py
@@ -15,6 +15,8 @@
TPUSupportedSamplingMetadata
from tpu_inference.layers.jax.sharding import ShardingAxisName
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.jax_intermediate_tensor import \
JaxIntermediateTensors
from tpu_inference.utils import device_array

if TYPE_CHECKING:
@@ -81,6 +83,8 @@ def capture_model(self) -> None:
self._precompile_backbone_with_inputs_embeds()
if self.runner.scheduler_config.async_scheduling:
self._precompile_substitute_placeholder_token()
if not self.runner.is_last_rank:
return
self._precompile_select_from_array()
self._precompile_compute_logits()
self._precompile_disagg_utils()
@@ -120,8 +124,15 @@ def _precompile_input_embeddings_merger(self) -> None:
num_tokens=num_tokens,
)

def _precompile_backbone_helper(self, name, *, input_ids, positions,
inputs_embeds) -> None:
def _precompile_backbone_helper(self,
name,
*,
input_ids,
positions,
inputs_embeds,
intermediate_tensors=None,
is_first_rank=True,
is_last_rank=True) -> None:
num_tokens = None
if input_ids is not None:
num_tokens = input_ids.shape[0]
@@ -168,10 +179,14 @@ def model_fn_wrapper(
inputs_embeds,
layer_name_to_kvcache_index,
lora_metadata,
intermediate_tensors,
is_first_rank,
is_last_rank,
):
kv_caches, hidden_states, _ = self.runner.model_fn(
state, kv_caches, input_ids, attention_metadata, inputs_embeds,
layer_name_to_kvcache_index, lora_metadata)
layer_name_to_kvcache_index, lora_metadata,
intermediate_tensors, is_first_rank, is_last_rank)
self.runner.kv_caches = kv_caches
return hidden_states

@@ -189,6 +204,9 @@ def model_fn_wrapper(
inputs_embeds,
tuple(self.runner.layer_name_to_kvcache_index.items()),
lora_metadata,
intermediate_tensors,
is_first_rank,
is_last_rank,
num_tokens=num_tokens,
)

@@ -238,6 +256,7 @@ def _precompile_substitute_placeholder_token(self) -> None:
)

def _precompile_backbone_text_only(self) -> None:
hidden_size = self.runner.model_config.get_hidden_size()
for num_tokens in self.runner.num_tokens_paddings:
dp_sharding = NamedSharding(
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -247,10 +266,28 @@ def _precompile_backbone_text_only(self) -> None:
dp_sharding)
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
dp_sharding)
self._precompile_backbone_helper("backbone",
input_ids=input_ids,
positions=positions,
inputs_embeds=None)
is_first_rank = self.runner.is_first_rank
is_last_rank = self.runner.is_last_rank
if not is_first_rank:
hidden_states = self._create_dummy_tensor(
(num_tokens, hidden_size), jnp.bfloat16)
residual = self._create_dummy_tensor((num_tokens, hidden_size),
jnp.bfloat16)
intermediate_tensors = JaxIntermediateTensors(
tensors={
"hidden_states": hidden_states,
"residual": residual
})
else:
intermediate_tensors = None
self._precompile_backbone_helper(
f"worker{self.runner.rank} backbone",
input_ids=input_ids,
positions=positions,
inputs_embeds=None,
intermediate_tensors=intermediate_tensors,
is_first_rank=is_first_rank,
is_last_rank=is_last_rank)

def _precompile_backbone_with_inputs_embeds(self) -> None:
hidden_size = self.runner.model_config.get_hidden_size()
@@ -264,10 +301,28 @@ def _precompile_backbone_with_inputs_embeds(self) -> None:
else:
positions = self._create_dummy_tensor((num_tokens, ),
jnp.int32)
self._precompile_backbone_helper("backbone with embeds",
input_ids=None,
positions=positions,
inputs_embeds=inputs_embeds)
is_first_rank = self.runner.is_first_rank
is_last_rank = self.runner.is_last_rank
if not is_first_rank:
hidden_states = self._create_dummy_tensor(
(num_tokens, hidden_size), jnp.bfloat16)
residual = self._create_dummy_tensor((num_tokens, hidden_size),
jnp.bfloat16)
intermediate_tensors = JaxIntermediateTensors(
tensors={
"hidden_states": hidden_states,
"residual": residual
})
else:
intermediate_tensors = None
self._precompile_backbone_helper(
f"worker{self.runner.rank} backbone with embeds",
input_ids=None,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
is_first_rank=is_first_rank,
is_last_rank=is_last_rank)

def _precompile_select_from_array_helper(
self,
@@ -333,7 +388,7 @@ def _precompile_select_from_array(self) -> None:
PartitionSpec(ShardingAxisName.ATTN_DATA))
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
self._precompile_select_from_array_helper(
name="select all logits",
name=f"worker{self.runner.rank} select all logits",
source_paddings=self.runner.num_tokens_paddings,
indices_paddings=index_paddings,
hidden_dim=hsize,
@@ -344,15 +399,17 @@
if self.runner.speculative_config:
vocab_size = self.runner.model_config.get_vocab_size()
self._precompile_select_from_array_helper(
name="select bonus tokens for spec decoding",
name=
f"worker{self.runner.rank} select bonus tokens for spec decoding",
source_paddings=self.runner.num_logits_paddings,
indices_paddings=self.runner.num_reqs_paddings,
hidden_dim=vocab_size,
input_sharding=NamedSharding(self.runner.mesh,
PartitionSpec(None, "model")),
)
self._precompile_select_from_array_helper(
name="select target tokens for spec decoding",
name=
f"worker{self.runner.rank} select target tokens for spec decoding",
source_paddings=self.runner.num_logits_paddings,
indices_paddings=self.runner.num_logits_paddings,
hidden_dim=vocab_size,
@@ -375,7 +432,7 @@ def _precompile_compute_logits(self) -> None:
np.array([num_reqs], dtype=np.int32)):
lora_metadata = self.runner.lora_utils.extract_lora_metadata()
self._run_compilation(
"compute_logits",
f"worker{self.runner.rank} compute_logits",
self.runner.compute_logits_fn,
self.runner.state,
hidden_states,
@@ -417,7 +474,7 @@ def _precompile_sampling(self) -> None:
do_sampling=do_sampling,
)
self._run_compilation(
"sample",
f"worker{self.runner.rank} sample",
sample,
self.runner.rng_params_for_sampling,
self.runner.mesh,
@@ -458,7 +515,7 @@ def _precompile_gather_logprobs(self) -> None:
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
self._run_compilation(
"gather_logprobs",
f"worker{self.runner.rank} gather_logprobs",
self.runner._compute_and_gather_logprobs,
logits,
token_ids,
@@ -510,7 +567,7 @@ def _precompile_rejection_sampler(self) -> None:
do_sampling=do_sampling)

self._run_compilation(
compilation_name,
f"worker{self.runner.rank} {compilation_name}",
self.runner.rejection_sampler,
draft_token_ids,
num_draft_tokens,
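For reference, the dummy intermediate-tensor setup shared by the two precompile paths above boils down to the sketch below. make_dummy_intermediates is an illustrative name, not part of this PR, and plain jnp.zeros stands in for self._create_dummy_tensor, which in the real code also handles device placement and sharding.

import jax.numpy as jnp

from tpu_inference.models.jax.jax_intermediate_tensor import JaxIntermediateTensors


def make_dummy_intermediates(num_tokens: int, hidden_size: int,
                             is_first_rank: bool):
    # Illustrative helper (not in the PR): build the placeholder activations a
    # non-first pipeline-parallel rank receives from its upstream stage.
    if is_first_rank:
        # The first rank consumes token ids / input embeddings directly.
        return None
    # Later ranks are traced against bfloat16 activations of the padded shape.
    hidden_states = jnp.zeros((num_tokens, hidden_size), dtype=jnp.bfloat16)
    residual = jnp.zeros((num_tokens, hidden_size), dtype=jnp.bfloat16)
    return JaxIntermediateTensors(tensors={
        "hidden_states": hidden_states,
        "residual": residual,
    })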
41 changes: 40 additions & 1 deletion tpu_inference/runner/persistent_batch_manager.py
@@ -1,6 +1,7 @@
from typing import Dict

import jax
from vllm.distributed import get_pp_group
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput

from tpu_inference.logger import init_logger
@@ -173,15 +174,42 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput",
)

# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
num_output_tokens = req_data.num_output_tokens[i]

# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
req_index = self.input_batch.req_id_to_index.get(req_id)

if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker.
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
elif num_output_tokens < len(req_state.output_token_ids):
del req_state.output_token_ids[num_output_tokens:]
if req_index is not None:
end_idx = (self.input_batch.num_prompt_tokens[req_index] +
num_output_tokens)
self.input_batch.num_tokens[req_index] = end_idx
self.input_batch.num_tokens_no_spec[req_index] = end_idx

# Update the block IDs.
if not resumed_from_preemption:
if new_block_ids is not None:
# Append the new blocks to the existing block IDs.
@@ -194,7 +222,6 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput",
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids

req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
@@ -209,6 +236,18 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput",
self.input_batch.block_table.append_row(
new_block_ids, req_index)

# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = new_token_ids
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index

# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
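For reference, the non-last-rank token bookkeeping added above reduces to the following pure-function sketch. sync_output_tokens and its bare-list arguments are illustrative only; the PR applies the same arithmetic to CachedRequestState and the persistent input batch.

from typing import List


def sync_output_tokens(output_token_ids: List[int], num_tokens: int,
                       num_computed_tokens: int, new_token_ids: List[int],
                       num_output_tokens: int) -> None:
    # Sampled tokens reported by the scheduler that are not yet recorded
    # locally ("unverified" spec tokens are excluded upstream).
    num_new_tokens = num_computed_tokens + len(new_token_ids) - num_tokens
    if num_new_tokens == 1:
        # Common case: exactly one freshly sampled token.
        output_token_ids.append(new_token_ids[-1])
    elif num_new_tokens > 0:
        # Several verified tokens arrived at once.
        output_token_ids.extend(new_token_ids[-num_new_tokens:])
    elif num_output_tokens < len(output_token_ids):
        # The scheduler discarded some output tokens; trim the local copy.
        del output_token_ids[num_output_tokens:]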