diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 7e8ebe25c460..f87191d1ab50 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -71,17 +71,40 @@ def __init__(
         super().__init__()
 
     def result(self, timeout=None):
-        if timeout is not None:
-            raise RuntimeError("timeout not implemented")
+        """Return the result, first draining futures queued ahead of this one.
+
+        ``timeout`` (seconds; ``None`` waits indefinitely) bounds the whole drain.
+        ``TimeoutError`` is raised if that budget is exhausted before the next
+        queued future is awaited.
+        """
+        deadline = None if timeout is None else time.monotonic() + timeout
+
         # Drain any futures ahead of us in the queue.
         while not self.done():
+            # Check the deadline before popping so that a timed-out call leaves
+            # the pending entry in the queue instead of silently dropping it.
+            remaining = None if deadline is None else deadline - time.monotonic()
+            if remaining is not None and remaining <= 0:
+                raise TimeoutError("timeout while waiting for pending RPC futures")
             future, get_response = self.futures_queue.pop()
-            future.wait_for_response(get_response)
+            future.wait_for_response(get_response, timeout=remaining)
+
         return super().result()
 
-    def wait_for_response(self, get_response: Callable):
+    def wait_for_response(self, get_response: Callable, timeout: float | None = None):
+        """Set this future's result to ``aggregate(get_response())``.
+
+        ``timeout`` is forwarded as a keyword argument only when it is not
+        ``None``, so zero-argument callables keep working for untimed waits.
+        """
         try:
-            response = self.aggregate(get_response())
+            # Probing with try/except TypeError would also swallow TypeErrors
+            # raised inside get_response()/aggregate() and call get_response()
+            # a second time, so branch on the argument instead.
+            if timeout is None:
+                response = self.aggregate(get_response())
+            else:
+                response = self.aggregate(get_response(timeout=timeout))
             with suppress(InvalidStateError):
                 self.set_result(response)
         except Exception as e:
@@ -328,12 +351,19 @@ def collective_rpc( # type: ignore[override]
 
         shutdown_event = self.shutdown_event
 
-        def get_response():
+        def get_response(timeout: float | None = None):
             responses = []
+            # A per-call timeout may only tighten the RPC-level deadline, never
+            # extend it; the cutoff bounds the whole collection, not each MQ.
+            cutoff = deadline
+            if timeout is not None:
+                timeout_cutoff = time.monotonic() + timeout
+                if cutoff is None or timeout_cutoff < cutoff:
+                    cutoff = timeout_cutoff
             for mq in response_mqs:
                 dequeue_timeout = (
-                    None if deadline is None else (deadline - time.monotonic())
+                    None if cutoff is None else (cutoff - time.monotonic())
                 )
                 try:
                     status, result = mq.dequeue(
                         timeout=dequeue_timeout, cancel=shutdown_event