diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 9061a64db57c..c53bc4cfd3b8 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -289,6 +289,28 @@ def compute_probs( # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, # which is slow for large vocab sizes. This may cause performance issues. logits = apply_top_k_top_p(logits, top_k, top_p) + + # Apply min_p before softmax if provided + if sampling_metadata.min_p is not None and (sampling_metadata.min_p > 0).any(): + min_p = expand_batch_to_tokens( + sampling_metadata.min_p, + cu_num_draft_tokens, + num_tokens, + ) + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, + dim=-1, + keepdim=True) + # Reshape min_p for broadcasting + adjusted_min_p = min_p.unsqueeze(1) * max_probabilities + # Identify valid tokens using threshold comparison + valid_token_mask = probability_values >= adjusted_min_p + valid_token_mask = valid_token_mask.bool() + # Apply mask using boolean indexing + logits = logits.masked_fill(~valid_token_mask, -float('inf')) + output_prob = logits.softmax(dim=-1, dtype=torch.float32) return output_prob diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index ce81a40ee3ae..705bc11115e8 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 from vllm.v1.worker.gpu_input_batch import InputBatch +import os def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: if req_id in input_batch.min_p_reqs: # Spec decode doesn't support min_p sampling. - return False + return os.environ.get("VLLM_SPEC_DECODE_MIN_P_SUPPORTED", "false").lower() == "true" elif (req_id in input_batch.frequency_penalties_reqs or req_id in input_batch.presence_penalties_reqs or req_id in input_batch.repetition_penalties_reqs):