diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 0b240b7d434e..3345ea11eccd 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -159,10 +159,8 @@ def test_eagle_correctness( attn_backend: str, ): if attn_backend == "TREE_ATTN": - # TODO: Fix this flaky test pytest.skip( - "TREE_ATTN is flaky in the test disable for now until it can be " - "resolved (see https://github.com/vllm-project/vllm/issues/22922)") + "TREE_ATTN is tested separately in test_tree_eagle_correctness.") # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) @@ -223,3 +221,83 @@ def test_eagle_correctness( del spec_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("model_setup", [ + ("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), + ("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), +], + ids=[ + "llama3_eagle", + "llama3_eagle3", + ]) +@pytest.mark.parametrize( + "spec_token_tree", + [ + [(0, )], # A single token + [(0, ), (0, 0), (0, 0, 0)], # Chain + [(0, ), (1, ), (2, )], # Parallel + [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), + (2, 1)], # Tree + ]) +def test_tree_eagle_correctness( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, + model_setup: tuple[str, str, str, int], + spec_token_tree: list[tuple[int, ...]], +): + # Generate test prompts inside the function instead of using fixture + test_prompts = get_test_prompts(False) + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using eagle speculative decoding. + model_setup: (method, model_name, eagle_model_name, tp_size) + ''' + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_ATTENTION_BACKEND", "TREE_ATTN") + method, model_name, spec_model_name, tp_size = model_setup + + ref_llm = LLM(model=model_name, + max_model_len=2048, + tensor_parallel_size=tp_size) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + spec_llm = LLM( + model=model_name, + trust_remote_code=True, + tensor_parallel_size=tp_size, + speculative_config={ + "method": method, + "model": spec_model_name, + "num_speculative_tokens": len(spec_token_tree), + "spec_token_tree": str(spec_token_tree), + "max_model_len": 2048, + }, + max_model_len=2048, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 50% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. This + # threshold is lower than the other tests because the tree attention + # backend uses triton kernels, which seem to introduce more floating + # point non-determinism when compared to FA3. 
+ assert matches > int(0.50 * len(ref_outputs)) + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 4e912f98f376..714b3a3c56a2 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Any, Optional +from unittest.mock import Mock import pytest import torch @@ -11,6 +12,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, RejectionSampler) +from vllm.v1.sample.sampler import Sampler, SamplerOutput from vllm.v1.spec_decode.metadata import SpecDecodeMetadata DEVICE = current_platform.device_type @@ -18,7 +20,24 @@ @pytest.fixture def rejection_sampler(): - return RejectionSampler() + mock_main_sampler = Mock(spec=Sampler) + return RejectionSampler(mock_main_sampler, DEVICE) + + +def mock_main_sampler_output(rejection_sampler: RejectionSampler, + bonus_token_ids: torch.Tensor): + rejection_sampler.main_sampler.return_value = SamplerOutput( + sampled_token_ids=bonus_token_ids, logprobs_tensors=None) + + +def create_spec_decode_metadata(spec_tokens: list[list[int]], + logits: torch.Tensor) -> SpecDecodeMetadata: + metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device) + metadata.target_logits_indices = torch.arange(logits.shape[0]) + # Output bonus token ids are mocked, so the bonus logit indices should + # be empty. + metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32) + return metadata def create_logits_tensor(output_token_ids: list[list[int]], @@ -83,20 +102,19 @@ def test_perfect_match(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_main_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_early_mismatch(rejection_sampler): @@ -108,14 +126,13 @@ def test_early_mismatch(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_main_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( @@ -123,7 +140,7 @@ def test_early_mismatch(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_multiple_sequences(rejection_sampler): @@ -136,20 +153,19 @@ def test_multiple_sequences(rejection_sampler): logits 
= create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_main_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_single_token_sequence(rejection_sampler): @@ -161,18 +177,17 @@ def test_single_token_sequence(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_main_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_empty_sequence(rejection_sampler): @@ -184,18 +199,17 @@ def test_empty_sequence(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_main_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_multiple_mismatches(rejection_sampler): @@ -208,14 +222,13 @@ def test_multiple_mismatches(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_main_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( @@ -224,7 +237,7 @@ def test_multiple_mismatches(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) @pytest.mark.parametrize( @@ -242,20 +255,19 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], 
device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_main_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) - assert torch.equal(output, expected_tensor) + assert torch.equal(output.sampled_token_ids, expected_tensor) ########################### Tests for Random Sampling ################### @@ -305,17 +317,18 @@ def test_deterministic_when_seeded( sampling_metadata = create_sampling_metadata(all_greedy=False, temperature=temperature, generators=seeded_seqs) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=DEVICE) + spec_decode_metadata = create_spec_decode_metadata( + draft_token_ids.tolist(), target_logits) + + mock_main_sampler_output(rejection_sampler, bonus_token_ids) rep_result = rejection_sampler( spec_decode_metadata, - draft_probs=draft_probs, - target_logits=target_logits, - bonus_token_ids=bonus_token_ids, + draft_probs=None, + logits=target_logits, sampling_metadata=sampling_metadata, ) - results.append(rep_result) + results.append(rep_result.sampled_token_ids) for i in range(batch_size): if seeded_mask[i]: @@ -424,7 +437,9 @@ def estimate_rejection_sampling_pdf( Returns: Estimated probability distribution of the output tokens. """ - rejection_sampler = RejectionSampler() + # Mock the main_sampler that TreeRejectionSampler uses + mock_main_sampler = Mock(spec=Sampler) + rejection_sampler = RejectionSampler(mock_main_sampler, DEVICE) num_tokens = num_samples * k # Repeat draft probs num_samples * k times. 
draft_probs = draft_probs.reshape(1, 1, @@ -447,16 +462,17 @@ def estimate_rejection_sampling_pdf( temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE) sampling_metadata = create_sampling_metadata(all_greedy=False, temperature=temperature) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=bonus_token_ids.device) - output_token_ids = rejection_sampler( + spec_decode_metadata = create_spec_decode_metadata( + draft_token_ids.tolist(), target_logits) + + mock_main_sampler_output(rejection_sampler, bonus_token_ids) + sampler_output = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, - target_logits=target_logits, - bonus_token_ids=bonus_token_ids, + logits=target_logits, sampling_metadata=sampling_metadata, ) - output_token_ids = output_token_ids[:, :-1].flatten() + output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten() hist = torch.histogram(output_token_ids.to(dtype=torch.float, device="cpu"), @@ -496,22 +512,20 @@ def _test_masked_logits( device=DEVICE) # Create spec decode metadata - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids, - device=DEVICE, - ) + spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, + target_logits) # Run rejection sampling - output_token_ids = rejection_sampler( + mock_main_sampler_output(rejection_sampler, bonus_token_ids) + output = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, - target_logits=target_logits, - bonus_token_ids=bonus_token_ids, + logits=target_logits, sampling_metadata=sampling_metadata, ) # Remove bonus tokens and reshape - output_token_ids = output_token_ids[:, :-1].flatten().tolist() + output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist() # Check that all sampled tokens are within the unmasked indices. for i in range(num_tokens): diff --git a/tests/v1/sample/test_tree_rejection_sampler.py b/tests/v1/sample/test_tree_rejection_sampler.py new file mode 100644 index 000000000000..8c2a5d5ada44 --- /dev/null +++ b/tests/v1/sample/test_tree_rejection_sampler.py @@ -0,0 +1,405 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch + +from vllm.platforms import current_platform +from vllm.v1.sample.logits_processor import LogitsProcessors +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler +from vllm.v1.sample.tree_rejection_sampler import (PLACEHOLDER_TOKEN_ID, + TreeRejectionSampler) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.tree_spec_decode.tree_drafter_params import TreeDrafterParams + +DEVICE = current_platform.device_type +VOCAB_SIZE = 100 +Node = tuple[int, ...] + +########################### Helper Functions ########################### + + +def create_tree_rejection_sampler(tree_structure: list[Node], + batch_size: int) -> TreeRejectionSampler: + tree_drafter_params = TreeDrafterParams.from_spec_token_tree( + str(tree_structure)) + return TreeRejectionSampler( + tree_drafter_params=tree_drafter_params, + max_batch_size=batch_size, + main_sampler=Sampler(), + device=DEVICE, + ) + + +def get_token_id(tree: list[Node], node: Node) -> int: + # Token id is just the position of this node in the tree. 
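+    # For example, with tree == [(), (0,), (1,), (0, 0)],
+    # get_token_id(tree, (1,)) returns 2, so every node maps to a distinct
+    # token id below len(tree).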
+ return tree.index(node) + + +def to_input_draft_token_ids(tree: list[Node], num_drafts: int, + draft_nodes: list[Node]) -> torch.Tensor: + """ + Creates a tensor of draft token ids to input into the rejection sampler. + Each given node is mapped to a unique token id. All other positions are + given a random token id. + """ + draft_token_ids = torch.randint( + # Offset the random token ids by the size of the tree. + low=len(tree), + high=VOCAB_SIZE, + size=(num_drafts, ), + device=DEVICE) + for draft_node in draft_nodes: + # Get the draft node's position in the tree, excluding the root node. + index = tree.index(draft_node) - 1 + # Assign unique token id to the node. + token_id = get_token_id(tree, draft_node) + draft_token_ids[index] = token_id + return draft_token_ids + + +def to_output_token_ids(tree: list[Node], num_drafts: int, + accepted: list[Node], bonus: Node) -> torch.Tensor: + """ + Creates a tensor where only the accepted and bonus nodes are mapped to + their token ids. + """ + output_token_ids = torch.empty(num_drafts + 1, device=DEVICE) + output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) + for accepted_node in accepted: + index = tree.index(accepted_node) - 1 + token_id = get_token_id(tree, accepted_node) + output_token_ids[index] = token_id + output_token_ids[-1] = get_token_id(tree, bonus) + return output_token_ids + + +def create_logits_tensor(tree: list[Node], num_logits: int, + sample_map: dict[Node, Node]) -> torch.Tensor: + """ + Helper function to create logits tensor that will produce the desired + token ids on argmax + """ + logits = torch.full((num_logits, VOCAB_SIZE), -100.0, device=DEVICE) + for index in range(num_logits): + node = tree[index] + if node not in sample_map: + continue + sampled_node = sample_map[node] + token_id = get_token_id(tree, sampled_node) + logits[index, token_id] = 100.0 + return logits + + +def create_sampling_metadata( + all_greedy: bool, + temperature: Optional[torch.Tensor] = None, + top_k: Optional[torch.Tensor] = None, + top_p: Optional[torch.Tensor] = None, + generators: Optional[dict[int, Any]] = None, +) -> SamplingMetadata: + """ + Create a v1 sampling metadata object with all_greedy set to the given + value. Either all greedy or all random sampling is used. + """ + generators = generators or {} + if all_greedy: + temperature = None + else: + assert temperature is not None + + return SamplingMetadata( + temperature=temperature, + all_greedy=all_greedy, + all_random=not all_greedy, + top_p=top_p, + top_k=top_k, + generators=generators, + max_num_logprobs=0, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=torch.tensor([]), + presence_penalties=torch.tensor([]), + repetition_penalties=torch.tensor([]), + output_token_ids=[], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessors(), + ) + + +def assert_rejection_sample( + draft_tree: list[Node], + spec_nodes: list[list[Node]], + target_sample_maps: list[dict[Node, Node]], + expected_accepted_nodes: list[list[Node]], + expected_bonus_nodes: list[Node], +): + num_drafts = len(draft_tree) + # Create tree rejection sampler. + tree_rejection_sampler = create_tree_rejection_sampler( + draft_tree, len(spec_nodes)) + + # Create the bonus level. + last_level = len(draft_tree[-1]) + leaves = [node for node in draft_tree if len(node) == last_level] + bonus_level = [leaf + (0, ) for leaf in leaves] + # Create tree with root node and bonus level added. 
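+    # For example, with draft_tree == [(0,), (1,)] the leaves are (0,) and
+    # (1,), so bonus_level == [(0, 0), (1, 0)] and the full tree becomes
+    # [(), (0,), (1,), (0, 0), (1, 0)].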
+ tree = [()] + draft_tree + bonus_level + + # Convert drafted tokens mapping to tensor representation. + input_draft_token_ids = torch.stack( + [to_input_draft_token_ids(tree, num_drafts, s) for s in spec_nodes]) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + input_draft_token_ids.tolist(), device=DEVICE) + + # Generate logits that deterministically produce the given sampled + # tokens. + logits = torch.cat([ + create_logits_tensor(tree, num_drafts + 1, sample_map) + for sample_map in target_sample_maps + ]) + + # Create greedy sampling metadata. + metadata = create_sampling_metadata(all_greedy=True) + + # Rejection sample. + output = tree_rejection_sampler( + spec_decode_metadata, + draft_probs=None, + logits=logits, + sampling_metadata=metadata, + ) + + # Compare with output with expected. + expected_tokens = torch.stack([ + to_output_token_ids(tree, num_drafts, a, b) + for a, b in zip(expected_accepted_nodes, expected_bonus_nodes) + ]) + assert torch.equal(output.sampled_token_ids, expected_tokens) + + +########################### Tests ########################### + + +def test_single_node(): + """ + Test exact match for a single node. + """ + draft_tree: list[Node] = [ + (0, ), + ] + drafted_tokens: list[list[Node]] = [ + [(0, )], + ] + target_sample_maps: list[dict[Node, Node]] = [{ + (): (0, ), + (0, ): (0, 0), + }] + expected_accepted_tokens: list[list[Node]] = [ + [(0, )], + ] + expected_bonus_tokens: list[Node] = [ + (0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_chain_full_acceptance(): + draft_tree: list[Node] = [ + (0, ), + (0, 0), + (0, 0, 0), + ] + drafted_tokens: list[list[Node]] = [ + [(0, ), (0, 0), (0, 0, 0)], + ] + target_sample_maps: list[dict[Node, Node]] = [{ + (): (0, ), + (0, ): (0, 0), + (0, 0): (0, 0, 0), + (0, 0, 0): (0, 0, 0, 0) + }] + expected_accepted_tokens: list[list[Node]] = [ + [(0, ), (0, 0), (0, 0, 0)], + ] + expected_bonus_tokens: list[Node] = [ + (0, 0, 0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_chain_partial_acceptance(): + draft_tree: list[Node] = [ + (0, ), + (0, 0), + (0, 0, 0), + ] + target_sample_maps: list[dict[Node, Node]] = [{ + (): (0, ), + (0, ): (0, 0), + (0, 0): (0, 0, 0), + }] + drafted_tokens: list[list[Node]] = [ + [(0, ), (0, 0), (0, 0)], # Mismatch for final draft (expected (0,0,0)) + ] + expected_accepted_tokens: list[list[Node]] = [ + [(0, ), (0, 0)], + ] + expected_bonus_tokens: list[Node] = [ + (0, 0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_full_acceptance(): + draft_tree: list[Node] = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + drafted_tokens: list[list[Node]] = [ + [(1, ), (1, 1)], + ] + target_sample_maps: list[dict[Node, Node]] = [{ + (): (1, ), + (1, ): (1, 1), + (1, 1): (1, 1, 0), + }] + expected_accepted_tokens: list[list[Node]] = [ + [(1, ), (1, 1)], + ] + expected_bonus_tokens: list[Node] = [ + (1, 1, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_partial_acceptance(): + draft_tree: list[Node] = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + drafted_tokens: list[list[Node]] = [ + [(0, ), (0, 0)], # Mismatch for final draft (expected (0,0)) + ] + target_sample_maps: list[dict[Node, 
Node]] = [{ + (): (0, ), + (0, ): (0, 1), + }] + expected_accepted_tokens: list[list[Node]] = [ + [(0, )], + ] + expected_bonus_tokens: list[Node] = [ + (0, 1), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_early_rejection(): + draft_tree: list[Node] = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + drafted_tokens: list[list[Node]] = [ + [(1, ), (0, 1)], # Mismatch for the first draft (expected (0,)) + ] + target_sample_maps: list[dict[Node, Node]] = [{ + (): (0, ), + (0, ): (0, 0), + (0, 0): (0, 0, 0), + }] + expected_accepted_tokens: list[list[Node]] = [ + [], + ] + expected_bonus_tokens: list[Node] = [ + (0, ), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_full_acceptance_multiple_sequences(): + draft_tree: list[Node] = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + drafted_tokens: list[list[Node]] = [ + [(0, ), (0, 1)], # Sequence 1 + [(1, ), (1, 0)], # Sequence 2 + ] + target_sample_maps: list[dict[Node, Node]] = [{ + (): (0, ), + (0, ): (0, 1), + (0, 1): (0, 1, 0), + }, { + (): (1, ), + (1, ): (1, 0), + (1, 0): (1, 0, 0), + }] + expected_accepted_tokens: list[list[Node]] = [ + [(0, ), (0, 1)], + [(1, ), (1, 0)], + ] + expected_bonus_tokens: list[Node] = [ + (0, 1, 0), + (1, 0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_tree_partial_acceptance_multiple_sequences(): + draft_tree: list[Node] = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)] + drafted_tokens: list[list[Node]] = [ + [(0, ), (0, 0)], # Mismatch for the second draft (expected (0,1)) + [(0, ), (0, 1)], # Mismatch for the first draft (expected (1,)) + ] + target_sample_maps: list[dict[Node, Node]] = [{ + (): (0, ), + (0, ): (0, 1), + }, { + (): (1, ), + (1, ): (1, 0), + }] + expected_accepted_tokens: list[list[Node]] = [ + [(0, )], + [], + ] + expected_bonus_tokens: list[Node] = [ + (0, 1), + (1, ), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) + + +def test_deep_tree_full_acceptance(): + draft_tree: list[Node] = [ + (0, ), + (1, ), # Level 1 + (0, 0), + (0, 1), + (1, 0), + (1, 1), # Level 2 + (0, 0, 0), + (0, 0, 1), + (0, 1, 0), + (0, 1, 1), + (1, 0, 0), + (1, 0, 1), + (1, 1, 0), + (1, 1, 1) # Level 3 + ] + drafted_tokens: list[list[Node]] = [ + [(1, ), (1, 1), (1, 1, 0)], + ] + target_sample_maps: list[dict[Node, Node]] = [{ + (): (1, ), + (0, ): (0, 1), + (1, ): (1, 1), + (0, 0): (0, 0, 0), + (1, 1): (1, 1, 0), + (1, 1, 0): (1, 1, 0, 0), + }] + expected_accepted_tokens: list[list[Node]] = [ + [(1, ), (1, 1), (1, 1, 0)], + ] + expected_bonus_tokens: list[Node] = [ + (1, 1, 0, 0), + ] + assert_rejection_sample(draft_tree, drafted_tokens, target_sample_maps, + expected_accepted_tokens, expected_bonus_tokens) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index ddedc61aae29..f0cc3d41c1ef 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -125,7 +125,7 @@ def test_prepare_inputs(): proposer = _create_proposer("eagle", 1) updated_metadata, token_indices = proposer.prepare_inputs( - common_attn_metadata, num_rejected_tokens.cpu()) + common_attn_metadata, num_rejected_tokens.cpu(), [], []) assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens) @@ 
-405,7 +405,9 @@ def create_deterministic_logits(token_ids): [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], # Tree ]) -def test_propose_tree(spec_token_tree): +def test_propose_tree(spec_token_tree, monkeypatch): + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TREE_ATTN") + # Get GPU device. device = torch.device(current_platform.device_type) @@ -444,7 +446,7 @@ def create_deterministic_logits(token_ids, k: int): # Mock the model forward calls. forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device), torch.zeros(total_tokens, hidden_size, device=device))] - for cu_num_drafts in proposer.cu_drafts_per_level: + for cu_num_drafts in proposer.cu_drafts_per_level[1:]: h_logits = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device) @@ -455,7 +457,7 @@ def create_deterministic_logits(token_ids, k: int): model_mock.side_effect = forward_returns # Mock the compute_logits calls. - cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level, + cu_num_drafts_tensor = torch.tensor(proposer.cu_drafts_per_level, dtype=torch.int32, device=device) logits_returns = [] diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index eacb2ad584ba..44fcf47787fa 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -120,10 +120,12 @@ def forward_attention( ) -def test_tree_attn_correctness() -> None: +def test_tree_attn_correctness(monkeypatch) -> None: torch.manual_seed(42) torch.cuda.manual_seed_all(42) + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TREE_ATTN") + device = "cuda" tree_attn_masks = { # Chain. diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 85e58c290b79..fe83fd3fb8e7 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -58,6 +58,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType, LazyLoader, common_broadcastable_dtype, random_uuid) +from vllm.v1.tree_spec_decode.tree_drafter_params import TreeDrafterParams if TYPE_CHECKING: from _typeshed import DataclassInstance @@ -1981,6 +1982,9 @@ class SpeculativeConfig: ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" + # params generated in the post-init stage for tree drafting. + tree_drafter_params: SkipValidation[TreeDrafterParams] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2241,18 +2245,22 @@ def __post_init__(self): f"num_speculative_tokens:{self.num_speculative_tokens}" f" must be divisible by {n_predict=}") - if self.speculative_token_tree is None: - # Generate chain of tokens. - self.speculative_token_tree = str([ - (i + 1) * (0, ) - for i in range(self.num_speculative_tokens) - ]) - else: - # Sort the token tree breadth-first. - tree_choices = ast.literal_eval( - self.speculative_token_tree) - self.speculative_token_tree = str( - sorted(tree_choices, key=lambda t: (len(t), t))) + if envs.VLLM_ATTENTION_BACKEND == "TREE_ATTN": + spec_token_tree = self.speculative_token_tree + if spec_token_tree is None: + # Generate chain of tokens. + spec_token_tree = str([ + (i + 1) * (0, ) + for i in range(self.num_speculative_tokens) + ]) + # Construct tree drafter params from the spec token tree. 
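+                    # For example, num_speculative_tokens=3 with no explicit
+                    # tree yields the chain "[(0,), (0, 0), (0, 0, 0)]", which
+                    # produces cu_drafts_per_level == [0, 1, 2, 3] and
+                    # child_drafts_per_level == [1, 1, 1].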
+ self.tree_drafter_params = ( + TreeDrafterParams.from_spec_token_tree(spec_token_tree) + ) + num_tree_drafts = len(self.tree_drafter_params.draft_nodes) + assert (num_tree_drafts == self.num_speculative_tokens), ( + "len(speculative_token_tree) must equal " + "num_speculative_tokens.") self.draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_tp( diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 10238f36455d..a26eb4baa73f 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with TreeAttention.""" -import ast from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -168,21 +167,23 @@ def __init__( super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.block_size = kv_cache_spec.block_size - spec_config = vllm_config.speculative_config - spec_token_tree = (spec := spec_config) and spec.speculative_token_tree - tree_choices: list[tuple[int, - ...]] = (ast.literal_eval(spec_token_tree) - if spec_token_tree is not None else - [(0, )]) - # Construct the tree attention bias. - depth_counts = _get_depth_counts(tree_choices) - self.tree_attn_bias = _prepare_tree_attn_bias( - tree_choices, - depth_counts, - dtype=torch.float32, - device=device, - ) + tree_drafter_params = (spec := + spec_config) and spec.tree_drafter_params + if tree_drafter_params is None: + # Standard decoding. + self.tree_attn_bias = torch.zeros((1, 1), + dtype=torch.float32, + device=device) + else: + # Spec decoding. + tree_attn_mask = torch.tensor(tree_drafter_params.attn_mask, + device=device) + self.tree_attn_bias = torch.where(tree_attn_mask, 0, -torch.inf) + # TODO (TheEpicDolphin): Find a better way to separate prefills and + # decodes for tree attention. Currently, prefills <= + # self.tree_attn_bias.shape[0] are misclassified as decodes. + self.__class__.reorder_batch_threshold = self.tree_attn_bias.shape[0] def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -251,58 +252,6 @@ def build_for_drafting( return attn_metadata -def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]: - # Count the number of choices at each depth of the tree. - depth_counts = [] - prev_depth = 0 - for path in sorted_tree_choices: - depth = len(path) - if depth != prev_depth: - depth_counts.append(0) - depth_counts[depth - 1] += 1 - prev_depth = depth - return depth_counts - - -def _prepare_tree_attn_bias( - sorted_tree_choices: list[tuple[int, ...]], - depth_counts: list[int], - dtype: Optional[torch.dtype], - device: Optional[torch.device], -) -> torch.Tensor: - # +1 comes from the additional root node. - tree_len = len(sorted_tree_choices) + 1 - tree_attn_mask = torch.full((tree_len, tree_len), - -torch.inf, - device=device, - dtype=dtype) - - # Set diagonal to all zeros. Each token should - # attend to itself. - mask_val = 0 - for i in range(tree_len): - tree_attn_mask[i, i] = mask_val - - # Set root to all zeros. All tokens attend to it. - tree_attn_mask[:, 0] = mask_val - - # Set all ancestors to zeros. - start = 0 - for i in range(len(depth_counts)): - for j in range(depth_counts[i]): - cur_tree_choice = sorted_tree_choices[start + j] - # Retrieve ancestor position. 
- if len(cur_tree_choice) == 1: - continue - ancestor_idx = [] - for c in range(len(cur_tree_choice) - 1): - ancestor_idx.append( - sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) - tree_attn_mask[j + start + 1, ancestor_idx] = mask_val - start += depth_counts[i] - return tree_attn_mask - - class TreeAttentionImpl(AttentionImpl): def __init__( diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 3d5e59addfcf..f956d486faca 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -2,13 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional +import numpy as np import torch import torch.nn as nn from vllm.logger import init_logger from vllm.triton_utils import tl, triton +from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata logger = init_logger(__name__) @@ -43,17 +46,24 @@ class RejectionSampler(nn.Module): output tokens = accepted tokens + recovered tokens + bonus tokens """ + def __init__( + self, + main_sampler: Sampler, + device: Optional[torch.device], + ): + super().__init__() + self.main_sampler = main_sampler + self.device = device + def forward( self, metadata: SpecDecodeMetadata, # [num_tokens, vocab_size] draft_probs: Optional[torch.Tensor], # [num_tokens, vocab_size] - target_logits: torch.Tensor, - # [batch_size, 1] - bonus_token_ids: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> SamplerOutput: ''' Args: metadata: @@ -62,19 +72,7 @@ def forward( Probability distribution for the draft tokens. Shape is [num_tokens, vocab_size]. Can be None if probabilities are not provided, which is the case for ngram spec decode. - target_logits (torch.Tensor): - Target model's logits probability distribution. - Shape is [num_tokens, vocab_size]. Here, probabilities from - different requests are flattened into a single tensor because - this is the shape of the output logits. - NOTE: `target_logits` can be updated in place to save memory. - bonus_token_ids (torch.Tensor): - A tensor containing bonus tokens. Shape is [batch_size, 1]. - Bonus tokens are added to the end of the sequence if all - proposed tokens are accepted. We generate the bonus tokens - outside of the rejection sampler with the default sampling - strategy. It allows for more flexibility in the sampling - process such as top_p, top_k sampling. + logits: Logits from the target model. sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata): Additional metadata needed for sampling, such as temperature, top-k/top-p parameters, or other relevant information. @@ -83,6 +81,25 @@ def forward( A tensor containing the final output token IDs. ''' assert metadata.max_spec_len <= MAX_SPEC_LEN + + bonus_logits_indices = metadata.bonus_logits_indices + target_logits_indices = metadata.target_logits_indices + + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. 
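+        # For example, if t = torch.arange(4.) and idx = torch.tensor([1, 2]),
+        # then t[idx] is a fresh tensor: t[idx].div_(2) leaves t unchanged,
+        # whereas the basic slice t[1:3] would be a view into t.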
+ bonus_logits = logits[bonus_logits_indices] + sampler_output = self.main_sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[target_logits_indices] + # [num_tokens, vocab_size] # NOTE(woosuk): `target_logits` can be updated in place inside the # `compute_probs` function. @@ -102,13 +119,15 @@ def forward( bonus_token_ids, sampling_metadata, ) - return output_token_ids + + sampler_output.sampled_token_ids = output_token_ids + return sampler_output @staticmethod def parse_output( output_token_ids: torch.Tensor, vocab_size: int, - ) -> list[list[int]]: + ) -> tuple[list[list[int]], list[list[int]]]: """Parse the output of the rejection sampler. Args: @@ -119,17 +138,21 @@ def parse_output( vocab_size: The size of the vocabulary. Returns: - A list of lists of token IDs. + outputs: A list of lists of token IDs. + indices: A list of lists of token indices. """ output_token_ids_np = output_token_ids.cpu().numpy() # Create mask for valid tokens. valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (output_token_ids_np < vocab_size)) - outputs = [ - row[valid_mask[i]].tolist() - for i, row in enumerate(output_token_ids_np) - ] - return outputs + + outputs = [] + indices = [] + for i, row in enumerate(output_token_ids_np): + idxs = np.nonzero(valid_mask[i])[0] + outputs.append(row[idxs].tolist()) + indices.append(idxs.tolist()) + return outputs, indices def rejection_sample( diff --git a/vllm/v1/sample/tree_rejection_sampler.py b/vllm/v1/sample/tree_rejection_sampler.py new file mode 100644 index 000000000000..9a07fa3f7614 --- /dev/null +++ b/vllm/v1/sample/tree_rejection_sampler.py @@ -0,0 +1,339 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import numpy as np +import torch + +from vllm.logger import init_logger +from vllm.triton_utils import tl, triton +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, + RejectionSampler) +from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.tree_spec_decode.tree_drafter_params import TreeDrafterParams + +logger = init_logger(__name__) + +BLOCK_SIZE: int = 64 + + +class TreeRejectionSampler(RejectionSampler): + + def __init__( + self, + tree_drafter_params: TreeDrafterParams, + max_batch_size: int, + main_sampler: Sampler, + device: Optional[torch.device], + ): + super().__init__(main_sampler, device) + self.tree_mask = torch.tensor(tree_drafter_params.attn_mask, + device=device)[:, 1:].contiguous() + # Cumulative # of tokens per level, including the root token. + self.cu_tokens_per_level = [ + num_drafts + 1 + for num_drafts in tree_drafter_params.cu_drafts_per_level + ] + + # Get tree depth (# levels) and width (# drafts at last level). + self.tree_depth = len(self.cu_tokens_per_level) + self.tree_width = self.cu_tokens_per_level[ + -1] - self.cu_tokens_per_level[-2] + self.tree_size = self.cu_tokens_per_level[-1] + + # Get per-level slices of draft indices, and per-level indices + # for their corresponding parents. 
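+        # For a two-level tree with 2 drafts at level 1 and 4 at level 2
+        # (cu_tokens_per_level == [1, 3, 7], child_drafts_per_level == [2, 2]),
+        # the loop below yields draft_slices == [(0, 0), (1, 3), (3, 7)] and
+        # parent_indices == [[], [0, 0], [1, 1, 2, 2]]: drafts 1-2 are children
+        # of the root (index 0), and drafts 3-6 are children of drafts 1 and 2.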
+ num_children_per_level = tree_drafter_params.child_drafts_per_level + self.draft_slices = [(0, 0)] + self.parent_indices: list[list[int]] = [[]] + parents_end = 0 + for level in range(1, self.tree_depth): + # Add slice of draft indices for this level. + self.draft_slices.append((self.cu_tokens_per_level[level - 1], + self.cu_tokens_per_level[level])) + # Add indices for this level's parents. + parents_start = parents_end + parents_end = self.cu_tokens_per_level[level - 1] + num_children = num_children_per_level[level - 1] + indices = [] + for parent_idx in range(parents_start, parents_end): + indices += [parent_idx] * num_children + self.parent_indices.append(indices) + + # Precompute indices for logits corresponding to tree-internal + # tokens across batches. + self.tree_internal_size = self.cu_tokens_per_level[-2] + self.tree_index_offsets = np.arange(self.tree_size) + + def forward( + self, + metadata: SpecDecodeMetadata, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + """ + Args: + metadata: + Metadata for spec decoding. + draft_probs (Optional[torch.Tensor]): + Probability distribution for the draft tokens. Shape is + [num_tokens, vocab_size]. Can be None if probabilities are + not provided, which is the case for ngram spec decode. + logits (torch.Tensor): + Target model's logits probability distribution. + Shape is [num_tokens, vocab_size]. Here, probabilities from + different requests are flattened into a single tensor because + this is the shape of the output logits. + sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata): + Additional metadata needed for sampling, such as temperature, + top-k/top-p parameters, or other relevant information. + Returns: + output_token_ids (torch.Tensor): + A tensor containing the final output token IDs. + """ + + # Example tree structure (2 levels of drafts): + # 0 (root) + # / \ + # 1 2 level 1 + # / | | \ + # 3 4 5 6 level 2 + + draft_tree_size = self.tree_size - 1 + tree_internal_index_offsets = ( + self.tree_index_offsets[:self.tree_internal_size]) + draft_index_offsets = self.tree_index_offsets[:draft_tree_size] + draft_token_ids = metadata.draft_token_ids + draft_token_ids_cpu = draft_token_ids.cpu() + + # 8 + num_reqs = len(metadata.num_draft_tokens) + # [1, 8, 8, 0, 0, 0, 0, 0] + num_draft_tokens = np.array(metadata.num_draft_tokens) + # [2, 9, 9, 1, 1, 1, 1, 1] + num_tokens = num_draft_tokens + 1 + # [0, 1, 1, 0, 0, 0, 0, 0] + is_tree_decode = num_draft_tokens == draft_tree_size + + # [0, 0, 0, 0, 0, 0, 0, 0] + start_indices = np.zeros(num_reqs, dtype=np.int32) + # [0, 2, 11, 20, 21, 22, 23, 24] + np.cumsum(num_tokens[:-1], out=start_indices[1:]) + + # Create output token ids buffer. + output_token_ids = torch.empty( + # +1 for the bonus token. + (num_reqs, draft_tree_size + 1), + dtype=torch. + int32, # Consistent with SamplerOutput.sampled_token_ids. + device=self.device, + ) + + # [0, 0, 0, 0, 0, 0, 0, 0] + accepted_index_offsets = np.zeros(num_reqs, dtype=np.int32) + + num_tree_decodes = is_tree_decode.sum() + if num_tree_decodes > 0: + # Compute target probabilities for all logits corresponding to + # internal nodes in the tree. 
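+            # For the example tree above, tree_internal_size ==
+            # cu_tokens_per_level[-2] == 3, so only the logits at offsets 0-2
+            # (the root and the two level-1 drafts) are compared against the
+            # drafted children; leaf logits are consumed only by the bonus
+            # sampler below.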
+ vocab_size = logits.shape[-1] + # [2, 11] + tree_decode_start_indices = start_indices[is_tree_decode] + # [[2, 3, 4], + # [11, 12, 13]] + tree_internal_indices = (tree_decode_start_indices[:, None] + + tree_internal_index_offsets) + # [2, 3, 4, 11, 12, 13] + tree_internal_indices_tensor = torch.from_numpy( + tree_internal_indices.flatten()).to(self.device) + target_probs = self.compute_tree_target_probs( + logits[tree_internal_indices_tensor], + is_tree_decode, + num_tree_decodes, + sampling_metadata, + ).view(num_tree_decodes, -1, vocab_size) + # Sample target token ids from the target probabilities. + # TODO(TheEpicDolphin): Add support for probabilistic-style + # rejection sampling, as used in EAGLE. + target_token_ids = target_probs.argmax(dim=-1) + + # Get the draft token ids for batches with full draft trees. + # [0, 0] + draft_start_indices = np.zeros(num_tree_decodes, dtype=np.int32) + # [1, 9] + np.cumsum(num_draft_tokens[is_tree_decode][:-1], + out=draft_start_indices[1:]) + # [[1, 2, 3, ... , 8] + # [9, 10, 11, ... , 16]] + tree_draft_indices = (draft_start_indices[:, None] + + draft_index_offsets) + # [[311, 8844, 2349, 387, 4732, 96618, 311, 334], + # [3634, 279, 323, 11, 438, 15861, 3634, 7016]] + draft_token_ids_np = draft_token_ids_cpu[tree_draft_indices].numpy( + ) + + # Move sampled target token ids to CPU. + # [[311, 6435, 96618], + # [279, 11, 15861]] + target_token_ids_np = target_token_ids.cpu().numpy() + + # For each tree decode batch, find longest path from the root node. + path_lengths = np.zeros( + # +1 for the root token. + (num_tree_decodes, draft_tree_size + 1), + dtype=np.int32) + path_lengths[:, 0] = 1 + for level in range(1, self.tree_depth): + # level 2: + # (3, 9) + start, end = self.draft_slices[level] + # [1, 1, 1, 2, 2, 2] + parent_indices = self.parent_indices[level] + # [[0, 0, 0, 0, 0, 0], + # [0, 1, 0, 1, 0, 0]] + sample_match = (draft_token_ids_np[:, start - 1:end - 1] == + target_token_ids_np[:, parent_indices]) + nonzero_length = path_lengths[:, parent_indices] > 0 + combined_mask = sample_match & nonzero_length + # [[1, 2, 0, 0, 0, 0, 0, 0, 0],-> [[1, 2, 0, 0, 0, 0, 0, 0, 0], + # [1, 0, 2, 0, 0, 0, 0, 0, 0]] [1, 0, 2, 0, 0, 0, 3, 0, 0]] + path_lengths[:, start:end][combined_mask] = level + 1 + # [1, 6, 0, 0, 0, 0, 0, 0] + accepted_index_offsets[is_tree_decode] = path_lengths.argmax( + axis=-1) + + # Calculate grid dimensions. + grid_dim_0 = num_reqs + grid_dim_1 = triton.cdiv(draft_tree_size, BLOCK_SIZE) + grid = (grid_dim_0, grid_dim_1) + + # Launch kernel to set accepted draft token ids in output buffer. + accepted_index_offsets_tensor = torch.from_numpy( + accepted_index_offsets).to(self.device, non_blocking=True) + _scatter_accepted_tokens_kernel[grid]( + draft_token_ids_ptr=draft_token_ids, + output_token_ids_ptr=output_token_ids, + accepted_offsets_ptr=accepted_index_offsets_tensor, + tree_mask_ptr=self.tree_mask, + draft_tree_size=draft_tree_size, + placeholder_token_id=PLACEHOLDER_TOKEN_ID, + block_size=BLOCK_SIZE, + ) + + # Sample and add a bonus token to the accepted paths. 
+ # [0, 2 + 1, 11 + 6, 20, 21, 22, 23, 24] + bonus_token_indices = start_indices + accepted_index_offsets + bonus_token_indices_tensor = torch.from_numpy(bonus_token_indices).to( + self.device, non_blocking=True) + bonus_sampler_output = self.main_sampler( + logits=logits[bonus_token_indices_tensor], + sampling_metadata=sampling_metadata, + ) + output_token_ids[:, + -1] = bonus_sampler_output.sampled_token_ids.view(-1) + return SamplerOutput( + sampled_token_ids=output_token_ids, + logprobs_tensors=bonus_sampler_output.logprobs_tensors, + ) + + def compute_tree_target_probs(self, logits: torch.Tensor, + is_tree_decode: torch.Tensor, + num_tree_decodes: int, + sampling_metadata: SamplingMetadata): + if sampling_metadata.all_greedy: + return logits + + # How many times to repeat the temperature, top-k, and top-p + # for each tree-decode batch. + num_repeats = logits.shape[0] // num_tree_decodes + + assert sampling_metadata.temperature is not None + temperature = sampling_metadata.temperature[is_tree_decode] + temperature = temperature.repeat_interleave(num_repeats) + logits.div_(temperature.view(-1, 1)) + + top_k = None + if sampling_metadata.top_k is not None: + top_k = sampling_metadata.top_k[is_tree_decode] + top_k = top_k.repeat_interleave(num_repeats) + top_p = None + if sampling_metadata.top_p is not None: + top_p = sampling_metadata.top_p[is_tree_decode] + top_p = top_p.repeat_interleave(num_repeats) + logits = apply_top_k_top_p(logits, top_k, top_p) + output_probs = logits.softmax(dim=-1, dtype=torch.float32) + return output_probs + + +@triton.jit +def _scatter_accepted_tokens_kernel( + draft_token_ids_ptr, + output_token_ids_ptr, + accepted_offsets_ptr, + tree_mask_ptr, + draft_tree_size, + placeholder_token_id: tl.constexpr, + block_size: tl.constexpr, +): + """ + For batches that correspond to tree decodes, accepted token ids from + draft_token_ids are scattered to the corresponding indices in + output_token_ids. All other indices are set to placeholder_token_id. + + Whether a token from draft_token_ids is accepted or not is determined by + indexing into tree_mask via accepted_offsets. + + Args: + draft_token_ids_ptr: [num_reqs, draft_tree_size] Draft token ids. + output_token_ids_ptr: [num_reqs, draft_tree_size + 1] Output buffer. + accepted_offsets_ptr: [num_reqs] - Indices into tree_mask rows. + tree_mask_ptr: [draft_tree_size + 1, draft_tree_size] - Boolean masks + for paths to each node in the tree. + draft_tree_size: Size of draft tree. + placeholder_token_id: Placeholder token id for rejected tokens. + block_size: Block size + + Grid: (num_reqs, ceil(draft_tree_size / BLOCK_SIZE)) + """ + + req_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Get the accepted token index offset for this request. + accepted_offset = tl.load(accepted_offsets_ptr + req_idx) + + # Calculate which tokens this block processes. + block_start = block_idx * block_size + token_offsets = block_start + tl.arange(0, block_size) + token_mask = token_offsets < draft_tree_size + + # Get accepted path mask. Index as tree_mask[accepted_offset, :]. + path_mask_base = accepted_offset * draft_tree_size + accepted_path_mask = tl.load(tree_mask_ptr + path_mask_base + + token_offsets, + mask=token_mask, + other=0) + + # Load draft tokens for this request. + draft_base = req_idx * draft_tree_size + draft_tokens = tl.load(draft_token_ids_ptr + draft_base + token_offsets, + mask=token_mask, + other=placeholder_token_id) + + # Select draft tokens based on the accepted path mask. 
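+    # For example, for the draft tree [(0,), (1,), (0, 0), (0, 1), (1, 0),
+    # (1, 1)], the tree_mask row selected by an accepted_offset pointing at
+    # (1, 1) is [F, T, F, F, F, T], so only the drafts (1,) and (1, 1) keep
+    # their token ids; every other slot is written as placeholder_token_id.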
+ output_tokens = tl.where(accepted_path_mask, draft_tokens, + placeholder_token_id) + + # Store to output at the same positions. + output_width = draft_tree_size + 1 + output_base = req_idx * output_width + tl.store(output_token_ids_ptr + output_base + token_offsets, + output_tokens, + mask=token_mask) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 7132d507c722..328670f4a783 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ast from dataclasses import replace from importlib.util import find_spec from typing import Optional, Protocol @@ -21,12 +20,13 @@ from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, - TreeAttentionMetadataBuilder) +from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadataBuilder from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.tree_spec_decode.utils import (apply_accepted_draft_indices, + copy_kv_cache_slots) logger = init_logger(__name__) @@ -118,32 +118,22 @@ def __init__( rocm_types.append(AiterFlashAttentionMetadata) self.allowed_attn_types = tuple(rocm_types) else: - self.allowed_attn_types = (FlashAttentionMetadata, - TreeAttentionMetadata) - - # Parse the speculative token tree. - spec_token_tree = self.speculative_config.speculative_token_tree - self.tree_choices: list[tuple[int, - ...]] = ast.literal_eval(spec_token_tree) - tree_depth = len(self.tree_choices[-1]) - # Precompute per-level properties of the tree. - num_drafts_per_level = [0] * tree_depth - for node in self.tree_choices: - num_drafts_per_level[len(node) - 1] += 1 - self.cu_drafts_per_level = [num_drafts_per_level[0]] - self.child_drafts_per_level = [num_drafts_per_level[0]] - for level in range(1, tree_depth): - self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + - num_drafts_per_level[level]) - self.child_drafts_per_level.append(num_drafts_per_level[level] // - num_drafts_per_level[level - 1]) - # Precompute draft position offsets in flattened tree. - self.tree_draft_pos_offsets = torch.arange( - 1, - len(self.tree_choices) + 1, - device=device, - dtype=torch.int32, - ).repeat(max_batch_size, 1) + self.allowed_attn_types = (FlashAttentionMetadata, ) + + # Get tree drafter params. + tree_drafter_params = self.speculative_config.tree_drafter_params + self.use_tree_spec_decode = tree_drafter_params is not None + if self.use_tree_spec_decode: + self.cu_drafts_per_level = tree_drafter_params.cu_drafts_per_level + self.child_drafts_per_level = ( + tree_drafter_params.child_drafts_per_level) + # Precompute draft token positions in flattened tree. + self.flattened_tree_positions = torch.arange( + 1, + len(tree_drafter_params.draft_nodes) + 1, + device=device, + dtype=torch.int32, + ).repeat(max_batch_size, 1) def propose( self, @@ -228,7 +218,7 @@ def propose( positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] - if isinstance(attn_metadata, TreeAttentionMetadata): + if self.use_tree_spec_decode: # Draft using tree attention. 
draft_token_ids_list = self.propose_tree( batch_size=batch_size, @@ -361,7 +351,6 @@ def propose_tree( TreeAttentionMetadataBuilder) total_num_drafts = self.cu_drafts_per_level[0] - level_num_drafts = total_num_drafts # Sample a draft token for each child at the tree root level. num_children = self.child_drafts_per_level[0] if num_children == 1: @@ -382,14 +371,19 @@ def propose_tree( tree_hidden_states = torch.empty(0, device=self.hidden_states.device, dtype=self.hidden_states.dtype) - # Precompute the draft token positions. - flattened_draft_positions = ( + # Precompute the draft token query positions. + flattened_query_positions = ( positions.view(batch_size, -1) + - self.tree_draft_pos_offsets[:batch_size, :]) + self.flattened_tree_positions[:batch_size, :]) tree_depth = len(self.cu_drafts_per_level) - for level in range(tree_depth - 1): + for level in range(1, tree_depth - 1): + # Update the # drafts counters for the current level. + level_num_drafts = self.cu_drafts_per_level[ + level] - total_num_drafts + total_num_drafts = self.cu_drafts_per_level[level] + # Get draft positions for RoPE. - draft_positions = positions + (level + 1) + draft_positions = positions + level exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. @@ -430,7 +424,7 @@ def propose_tree( ) attn_metadata = tree_attn_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, - draft_index=level + 1, + draft_index=level, ) # Apply new attention metadata to all layers. @@ -446,8 +440,7 @@ def propose_tree( attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. - query_positions = flattened_draft_positions[:, level:level + - query_len] + query_positions = flattened_query_positions[:, :query_len] block_numbers = query_positions // self.block_size block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers) @@ -461,8 +454,7 @@ def propose_tree( # Copy inputs to buffer for cudagraph. num_tokens = attn_metadata.num_actual_tokens - input_ids = tree_input_ids.view(-1) - self.input_ids[:num_tokens] = input_ids + self.input_ids[:num_tokens] = tree_input_ids.view(-1) self.positions[:num_tokens] = tree_positions.view(-1) self.hidden_states[:num_tokens] = tree_hidden_states.view( num_tokens, -1) @@ -498,7 +490,7 @@ def propose_tree( ) # Sample a draft token for each child at the next tree level. - num_children = self.child_drafts_per_level[level + 1] + num_children = self.child_drafts_per_level[level] if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: @@ -507,17 +499,15 @@ def propose_tree( batch_size, -1) draft_token_ids_list.append(draft_token_ids) - # Update the # drafts counters for the next tree level. - level_num_drafts = self.cu_drafts_per_level[level + - 1] - total_num_drafts - total_num_drafts = self.cu_drafts_per_level[level + 1] return draft_token_ids_list def prepare_inputs( self, common_attn_metadata: CommonAttentionMetadata, # [batch_size] - num_rejected_tokens: torch.Tensor + num_rejected_tokens: torch.Tensor, + sampled_token_indices: list[list[int]], + kv_caches: list[torch.Tensor], ) -> tuple[CommonAttentionMetadata, torch.Tensor]: """ This function is used to prepare the inputs for the spec decode. @@ -607,6 +597,27 @@ def prepare_inputs( causal=True, ) + if self.use_tree_spec_decode: + # During tree spec decoding, the accepted draft tokens may not be + # consecutive. 
In such cases, we must recompute the token indices, + # and update the KV cache slots. + + # Compute the accepted token indices. + apply_accepted_draft_indices(sampled_token_indices, + new_query_start_loc_np, + token_indices_np) + accepted_token_indices = torch.from_numpy(token_indices_np).to( + device, non_blocking=True) + # Copy the KVs for the accepted tokens to the beginning of + # the sequence. + copy_kv_cache_slots( + kv_caches, + common_attn_metadata.slot_mapping, + accepted_token_indices, + token_indices, + ) + return spec_common_attn_metadata, accepted_token_indices + return spec_common_attn_metadata, token_indices def load_model(self, target_model: nn.Module) -> None: diff --git a/vllm/v1/tree_spec_decode/tree_drafter_params.py b/vllm/v1/tree_spec_decode/tree_drafter_params.py new file mode 100644 index 000000000000..a4f3117b3a8e --- /dev/null +++ b/vllm/v1/tree_spec_decode/tree_drafter_params.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +from dataclasses import dataclass + +from vllm.v1.tree_spec_decode.utils import MAX_TREE_DEPTH + + +@dataclass +class TreeDrafterParams: + draft_nodes: list[tuple[int, ...]] + attn_mask: list[list[bool]] + # Cumulative number of drafts at each level. + cu_drafts_per_level: list[int] + # Number of child drafts that each token has at the given level. + child_drafts_per_level: list[int] + # Maps each draft token index to its level in the tree. + draft_levels: list[int] + + @staticmethod + def from_spec_token_tree(spec_token_tree: str) -> "TreeDrafterParams": + # Parse the speculative token tree. + draft_nodes: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) + # Sort the tree breadth-first. + draft_nodes.sort(key=lambda t: (len(t), t)) + assert len(draft_nodes + ) > 0, "speculative_token_tree must have at least one node." + # Only trees with fixed branching factor per level are + # currently supported for tree attention. + _assert_fixed_branching_factor_per_level(draft_nodes, spec_token_tree) + + tree_depth = len(draft_nodes[-1]) + 1 + assert tree_depth <= MAX_TREE_DEPTH + + # Precompute per-level properties of the tree. + num_nodes_per_level = [0] * tree_depth + num_nodes_per_level[0] = 1 + for node in draft_nodes: + num_nodes_per_level[len(node)] += 1 + + cu_drafts_per_level = [0] + child_drafts_per_level = [] + draft_levels = [] + for level in range(1, tree_depth): + cu_drafts_per_level.append(cu_drafts_per_level[-1] + + num_nodes_per_level[level]) + child_drafts_per_level.append(num_nodes_per_level[level] // + num_nodes_per_level[level - 1]) + draft_levels += [level] * num_nodes_per_level[level] + + # Construct the tree attention bias. + depth_counts = _get_depth_counts(draft_nodes) + attn_mask = _prepare_tree_attn_bias( + draft_nodes, + depth_counts, + ) + + return TreeDrafterParams( + draft_nodes=draft_nodes, + attn_mask=attn_mask, + cu_drafts_per_level=cu_drafts_per_level, + child_drafts_per_level=child_drafts_per_level, + draft_levels=draft_levels, + ) + + +def _has_fixed_branching_factor(tree_nodes, level): + """ + Checks if all nodes at the given level have the same number of children. 
+ """ + next_level_nodes = [node for node in tree_nodes if len(node) == level + 1] + if len(next_level_nodes) == 0: + return True + + level_nodes = [node for node in tree_nodes if len(node) == level] + child_counts = [] + for parent in level_nodes: + child_counts.append( + sum(1 for child in next_level_nodes if child[:-1] == parent)) + return len(set(child_counts)) <= 1 # All counts are the same. + + +def _assert_fixed_branching_factor_per_level(tree_nodes: list[tuple[int, ...]], + spec_token_tree: str) -> None: + """ + Asserts that each level of the tree has a fixed branching factor. That is, + the number of children per node is the same within a level, but can vary + across levels. + """ + tree_depth = len(tree_nodes[-1]) + 1 + for level in range(1, tree_depth): + assert _has_fixed_branching_factor(tree_nodes, level), ( + f"speculative_token_tree '{spec_token_tree}' has variable " + f"branching at level {level}. Tree spec decoding requires " + f"a uniform number of children per level.") + + +def _get_depth_counts(sorted_draft_nodes: list[tuple[int, ...]]) -> list[int]: + """ + Counts the number of choices at each depth of the tree. + """ + depth_counts = [] + prev_depth = 0 + for path in sorted_draft_nodes: + depth = len(path) + if depth != prev_depth: + depth_counts.append(0) + depth_counts[depth - 1] += 1 + prev_depth = depth + return depth_counts + + +def _prepare_tree_attn_bias( + sorted_draft_nodes: list[tuple[int, ...]], + depth_counts: list[int], +) -> list[list[bool]]: + # +1 comes from the additional root node. + tree_len = len(sorted_draft_nodes) + 1 + tree_attn_mask = [[False for _ in range(tree_len)] + for _ in range(tree_len)] + + mask_val = True + for i in range(tree_len): + # Set diagonal to all True. Each token should attend to itself. + tree_attn_mask[i][i] = mask_val + # Set root column to all True. All tokens attend to it. + tree_attn_mask[i][0] = mask_val + + # Set all ancestors to True. + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_tree_choice = sorted_draft_nodes[start + j] + if len(cur_tree_choice) == 1: + continue + + for c in range(len(cur_tree_choice) - 1): + ancestor_idx = sorted_draft_nodes.index( + cur_tree_choice[:c + 1]) + 1 + tree_attn_mask[j + start + 1][ancestor_idx] = mask_val + start += depth_counts[i] + return tree_attn_mask diff --git a/vllm/v1/tree_spec_decode/utils.py b/vllm/v1/tree_spec_decode/utils.py new file mode 100644 index 000000000000..dc250c22dcaa --- /dev/null +++ b/vllm/v1/tree_spec_decode/utils.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING + +import numpy as np +import torch + +from vllm.triton_utils import tl, triton + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + +HEAD_TILE_SIZE: int = 64 +MAX_TREE_DEPTH: int = 16 + + +def apply_draft_offsets( + tree_draft_offsets: list[int], + input_batch: "InputBatch", + scheduler_output: "SchedulerOutput", + query_start_loc_np: np.array, + token_positions_np: np.array, +): + """ + Updates the draft token positions with their offsets (levels) in the tree. + + Args: + tree_draft_offsets: Offsets to apply to the draft token positions. + input_batch: The input batch. + scheduler_output: The scheduler output. + query_start_loc_np: Start locations of the queries for each request. + token_positions_np: Token positions. Will be updated in-place. 
+ """ + + draft_token_offsets = np.array(tree_draft_offsets) + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): + if len(draft_token_ids) == 0: + continue + req_idx = input_batch.req_id_to_index[req_id] + start = query_start_loc_np[req_idx] + end = query_start_loc_np[req_idx + 1] + num_drafts = end - start - 1 + token_positions_np[start + 1:end] = (token_positions_np[start] + + draft_token_offsets[:num_drafts]) + + +def apply_accepted_draft_indices( + sampled_token_indices: list[list[int]], + query_start_loc_np: np.array, + token_indices_np: np.array, +): + """ + Updates token_indices_np with the sampled token indices. + + Args: + sampled_token_indices: The indices of the accepted draft tokens. + query_start_loc_np: Start locations of the queries for each request. + token_indices_np: Token indices. Will be updated in-place. + """ + + for req_idx, seq in enumerate(sampled_token_indices): + if len(seq) <= 1: + continue + start = query_start_loc_np[req_idx] + 1 + end = query_start_loc_np[req_idx + 1] + token_indices_np[start:end] = token_indices_np[start] + np.array( + seq[:-1]) + + +def copy_kv_cache_slots( + kv_caches: list[torch.Tensor], + slot_mapping: torch.Tensor, + from_token_indices: torch.Tensor, + to_token_indices: torch.Tensor, +): + """ + Copies K/Vs from from_token_indices to to_token_indices. Used for updating + the KV cache after tree rejection sampling. + + Args: + kv_caches: List of per-layer tensors, each with shape + (2, num_blocks, block_size, num_kv_heads, head_size) + slot_mapping: Tensor mapping token indices to positions in the KV + cache. + from_token_indices: Tensor containing token indices into slot_mapping + to copy from. + to_token_indices: Tensor containing token indices into slot_mapping to + copy to. + """ + if len(kv_caches) == 0: + # Nothing to do. + return + + # Get shape and stride from first kv_cache tensor. + first_kv_cache = kv_caches[0] + assert first_kv_cache.dtype == torch.bfloat16, ( + "Only bfloat16 is supported for now.") + device = first_kv_cache.device + KV, _, block_size, num_kv_heads, head_size = first_kv_cache.shape + s0, s1, s2, s3, s4 = first_kv_cache.stride() + assert KV == 2 + num_layers = len(kv_caches) + + # Prepare indices. + from_indices = from_token_indices.contiguous() + to_indices = to_token_indices.contiguous() + slot_mapping = slot_mapping.contiguous() + num_indices = from_indices.numel() + assert to_indices.numel() == num_indices + + # Create array of pointers to kv_cache tensors. + kv_cache_ptrs = torch.tensor( + [kv_cache.data_ptr() for kv_cache in kv_caches], + dtype=torch.int64, + device=device, + ) + + # Compute grid dimensions. + # Encode KV/layers/heads. + grid0 = 2 * num_layers * num_kv_heads + # Chunk size is set to the maximum tree depth to prevent a potential + # race condition of writing to while reading from the same slot. + chunk_size = MAX_TREE_DEPTH + # Chunk across indices. + grid1 = (num_indices + chunk_size - 1) // chunk_size + # Tile across head_size. + grid2 = (head_size + HEAD_TILE_SIZE - 1) // HEAD_TILE_SIZE + + # Launch single kernel for all layers. 
+ _kv_copy_chunked[grid0, grid1, grid2]( + kv_cache_ptrs.data_ptr(), + slot_mapping.data_ptr(), + from_indices.data_ptr(), + to_indices.data_ptr(), + num_indices, + block_size, + num_kv_heads, + head_size, + s0, + s1, + s2, + s3, + s4, + head_tile_size=HEAD_TILE_SIZE, + chunk_size=chunk_size, + ) + + +@triton.jit +def _kv_copy_chunked( + kv_ptrs_ptr, + slot_mapping_ptr, + from_indices_ptr, + to_indices_ptr, + num_indices, + block_size, + num_kv_heads, + head_size, + s0, + s1, + s2, + s3, + s4, + head_tile_size: tl.constexpr, + chunk_size: tl.constexpr, +): + pid0 = tl.program_id(0) + chunk_id = tl.program_id(1) + tile_id = tl.program_id(2) + + # Compute offsets in head dimension for this tile. + tile_offsets = tile_id * head_tile_size + tl.arange(0, head_tile_size) + mask_tile = tile_offsets < head_size + + # Decode pid0 -> KV, layer, head. + tmp = pid0 // num_kv_heads + KV = tmp % 2 + layer = tmp // 2 + head = pid0 % num_kv_heads + + # Load the pointer for this layer's kv_cache from the array. + # 8 bytes per pointer (64-bit). + kv_ptr_addr = kv_ptrs_ptr + layer * 8 + kv_ptr = tl.load(kv_ptr_addr.to(tl.pointer_type(tl.uint64))) + src_ptr = kv_ptr.to(tl.pointer_type(tl.bfloat16)) + # Copying to the same tensor, so dst_ptr == src_ptr. + dst_ptr = src_ptr + + # Fixed base depending on kv/head (no layer stride since we select the + # tensor). + base_fixed = KV * s0 + head * s3 + + # Compute chunk bounds. + chunk_start = chunk_id * chunk_size + chunk_end = tl.minimum(chunk_start + chunk_size, num_indices) + + # Cast pointers to appropriate types. + from_idx_base = from_indices_ptr.to(tl.pointer_type(tl.uint64)) + to_idx_base = to_indices_ptr.to(tl.pointer_type(tl.uint64)) + slot_mapping_base = slot_mapping_ptr.to(tl.pointer_type(tl.uint64)) + # Cast block_size to uint64 to match slot type. + block_size_u64 = block_size.to(tl.uint64) + + for idx in range(chunk_start, chunk_end): + # Load indices into slot_mapping. + from_idx = tl.load(from_idx_base + idx) + to_idx = tl.load(to_idx_base + idx) + + # Load actual KV cache slots from slot_mapping. + from_slot = tl.load(slot_mapping_base + from_idx) + to_slot = tl.load(slot_mapping_base + to_idx) + + # Skip copying if from_slot == to_slot. + should_copy = from_slot != to_slot + + # Convert slots to blocks and offsets. 
+ from_block = from_slot // block_size_u64 + from_offset = from_slot % block_size_u64 + to_block = to_slot // block_size_u64 + to_offset = to_slot % block_size_u64 + + src_base = base_fixed + from_block * s1 + from_offset * s2 + dst_base = base_fixed + to_block * s1 + to_offset * s2 + + src_addr = (src_ptr + src_base + tile_offsets * s4).to( + tl.pointer_type(tl.bfloat16)) + dst_addr = (dst_ptr + dst_base + tile_offsets * s4).to( + tl.pointer_type(tl.bfloat16)) + + copy_mask = mask_tile & should_copy + vals = tl.load(src_addr, mask=copy_mask, other=0) + tl.store(dst_addr, vals, mask=copy_mask) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6572e421b65b..a6515c205439 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -79,10 +79,12 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler +from vllm.v1.sample.tree_rejection_sampler import TreeRejectionSampler from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.tree_spec_decode.utils import apply_draft_offsets from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.kv_connector_model_runner_mixin import ( @@ -241,6 +243,12 @@ def __init__( # mm_hash -> encoder_output self.encoder_cache: dict[str, torch.Tensor] = {} + # Tree spec decoding. + self.tree_drafter_params = ( + self.speculative_config + and self.speculative_config.tree_drafter_params) or None + self.use_tree_spec_decode = self.tree_drafter_params is not None + self.use_aux_hidden_state_outputs = False # Set up speculative decoding. # NOTE(Jiayi): currently we put the entire draft model on @@ -261,7 +269,18 @@ def __init__( else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") - self.rejection_sampler = RejectionSampler() + if self.use_tree_spec_decode: + self.rejection_sampler = TreeRejectionSampler( + self.speculative_config.tree_drafter_params, + max_batch_size=self.max_num_reqs, + main_sampler=self.sampler, + device=self.device, + ) + else: + self.rejection_sampler = RejectionSampler( + main_sampler=self.sampler, + device=self.device, + ) # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -938,6 +957,18 @@ def _prepare_inputs( self.query_start_loc.copy_to_gpu() query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] + if self.use_tree_spec_decode: + assert self.tree_drafter_params is not None + # During tree spec decoding, token positions need to be offset + # by their levels in the draft tree. + apply_draft_offsets( + self.tree_drafter_params.draft_levels, + self.input_batch, + scheduler_output, + self.query_start_loc.np, + positions_np, + ) + self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) @@ -1839,31 +1870,15 @@ def _sample( sampling_metadata=sampling_metadata, ) else: - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. 
assert logits is not None - bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( + sampler_output = self.rejection_sampler( spec_decode_metadata, None, # draft_probs - target_logits, - bonus_token_ids, + logits, sampling_metadata, ) - sampler_output.sampled_token_ids = output_token_ids - self._update_states_after_model_execute(output_token_ids) + self._update_states_after_model_execute( + sampler_output.sampled_token_ids) return sampler_output @@ -1875,6 +1890,7 @@ def _bookkeeping_sync( dict[str, int], Optional[LogprobsLists], list[list[int]], + list[list[int]], dict[str, Optional[LogprobsTensors]], list[str], dict[str, int], @@ -1929,17 +1945,24 @@ def _bookkeeping_sync( if max_gen_len == 1: # No spec decode tokens. valid_sampled_token_ids = self._to_list(sampled_token_ids) + valid_sampled_token_indices = [ + list(range(len(seq))) for seq in valid_sampled_token_ids + ] else: # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) + valid_sampled_token_ids, valid_sampled_token_indices = ( + self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + )) + # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: valid_sampled_token_ids[i].clear() + valid_sampled_token_indices[i].clear() else: valid_sampled_token_ids = [] + valid_sampled_token_indices = [] invalid_req_indices = list(discard_sampled_tokens_req_indices) invalid_req_indices_set = set(invalid_req_indices) assert sampled_token_ids.shape[-1] == 1 @@ -1992,6 +2015,7 @@ def _bookkeeping_sync( num_nans_in_logits, logprobs_lists, valid_sampled_token_ids, + valid_sampled_token_indices, prompt_logprobs_dict, req_ids_output_copy, req_id_to_index_output_copy, @@ -2124,6 +2148,7 @@ def execute_model( num_nans_in_logits, logprobs_lists, valid_sampled_token_ids, + valid_sampled_token_indices, prompt_logprobs_dict, req_ids_output_copy, req_id_to_index_output_copy, @@ -2138,6 +2163,7 @@ def execute_model( self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, valid_sampled_token_ids, + valid_sampled_token_indices, self.input_batch.sampling_metadata, hidden_states, sample_hidden_states, @@ -2185,6 +2211,7 @@ def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", sampled_token_ids: list[list[int]], + sampled_token_indices: list[list[int]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, @@ -2261,7 +2288,8 @@ def propose_draft_token_ids( dtype=torch.int32) common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( - common_attn_metadata, num_rejected_tokens_cpu) + common_attn_metadata, num_rejected_tokens_cpu, + sampled_token_indices, self.kv_caches) target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. 
@@ -2894,21 +2922,14 @@ def _dummy_sampler_run( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn(num_tokens, - logits.shape[-1], - device=self.device, - dtype=logits.dtype) - # NOTE(woosuk): Here, we should use int32 because the sampler uses - # int32 for bonus_token_ids. If the dtype mismatches, re-compilation - # will occur at runtime. - bonus_token_ids = torch.zeros(num_reqs, - device=self.device, - dtype=torch.int32) + logits = torch.randn(num_tokens + num_reqs, + logits.shape[-1], + device=self.device, + dtype=logits.dtype) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, - target_logits, - bonus_token_ids, + logits, dummy_metadata, ) return sampler_output
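Note (illustration, not part of the patch): the sketch below walks through what `TreeDrafterParams.from_spec_token_tree` derives for one small example tree and rebuilds the tree attention mask in an equivalent per-ancestor form. The tree literal, the printed values, and the simplified mask loop are illustrative assumptions for this note only; the patch's own helpers (`_get_depth_counts`, `_prepare_tree_attn_bias`) remain the source of truth.

import ast

# Example tree (assumed for illustration): 2 drafts at level 1,
# each with 2 children at level 2. Branching is uniform per level,
# which is what _assert_fixed_branching_factor_per_level requires.
spec_token_tree = "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]"
draft_nodes: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
draft_nodes.sort(key=lambda t: (len(t), t))  # breadth-first order

tree_depth = len(draft_nodes[-1]) + 1  # +1 for the implicit root level

# Count nodes per level; the root is level 0.
num_nodes_per_level = [0] * tree_depth
num_nodes_per_level[0] = 1
for node in draft_nodes:
    num_nodes_per_level[len(node)] += 1

cu_drafts_per_level = [0]
child_drafts_per_level = []
draft_levels = []
for level in range(1, tree_depth):
    cu_drafts_per_level.append(cu_drafts_per_level[-1] +
                               num_nodes_per_level[level])
    child_drafts_per_level.append(num_nodes_per_level[level] //
                                  num_nodes_per_level[level - 1])
    draft_levels += [level] * num_nodes_per_level[level]

print(cu_drafts_per_level)     # [0, 2, 6]
print(child_drafts_per_level)  # [2, 2]
print(draft_levels)            # [1, 1, 2, 2, 2, 2]

# Tree attention mask, phrased per ancestor: index 0 is the root token,
# and draft i (1-based) may attend to the root, itself, and its ancestors.
# This produces the same mask as _prepare_tree_attn_bias for this tree.
tree_len = len(draft_nodes) + 1
attn_mask = [[False] * tree_len for _ in range(tree_len)]
for i in range(tree_len):
    attn_mask[i][0] = True  # every token attends to the root
    attn_mask[i][i] = True  # and to itself
for i, node in enumerate(draft_nodes, start=1):
    for c in range(1, len(node)):
        attn_mask[i][draft_nodes.index(node[:c]) + 1] = True

for row in attn_mask:
    print("".join("1" if v else "." for v in row))
# 1......
# 11.....
# 1.1....
# 11.1...
# 11..1..
# 1.1..1.
# 1.1...1

The breadth-first sort together with the fixed-branching-per-level check is what lets `propose_tree` draft each level with a single argmax/topk over a uniform number of children (`child_drafts_per_level[level]`), and `draft_levels` is the per-token position offset that `apply_draft_offsets` adds to the draft token positions in `_prepare_inputs`.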