From 86c38640d4a4fc21179997ba613305be78cf8744 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 28 Nov 2025 23:31:58 +0000 Subject: [PATCH] remove v0-related components such as kv pipe and kv lookup buffer Signed-off-by: KuntaiDu --- tests/kv_transfer/test_lookup_buffer.py | 160 ---------- tests/kv_transfer/test_lookup_buffer.sh | 8 - tests/kv_transfer/test_module.py | 62 ---- tests/kv_transfer/test_send_recv.py | 154 --------- tests/kv_transfer/test_send_recv.sh | 9 - .../kv_transfer/kv_lookup_buffer/__init__.py | 0 .../kv_transfer/kv_lookup_buffer/base.py | 179 ----------- .../kv_lookup_buffer/mooncake_store.py | 164 ---------- .../kv_lookup_buffer/simple_buffer.py | 242 -------------- .../kv_transfer/kv_pipe/__init__.py | 0 vllm/distributed/kv_transfer/kv_pipe/base.py | 66 ---- .../kv_transfer/kv_pipe/mooncake_pipe.py | 295 ------------------ .../kv_transfer/kv_pipe/pynccl_pipe.py | 285 ----------------- 13 files changed, 1624 deletions(-) delete mode 100644 tests/kv_transfer/test_lookup_buffer.py delete mode 100644 tests/kv_transfer/test_lookup_buffer.sh delete mode 100644 tests/kv_transfer/test_module.py delete mode 100644 tests/kv_transfer/test_send_recv.py delete mode 100644 tests/kv_transfer/test_send_recv.sh delete mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py delete mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/base.py delete mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py delete mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py delete mode 100644 vllm/distributed/kv_transfer/kv_pipe/__init__.py delete mode 100644 vllm/distributed/kv_transfer/kv_pipe/base.py delete mode 100644 vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py delete mode 100644 vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py deleted file mode 100644 index a61ccef70062..000000000000 --- a/tests/kv_transfer/test_lookup_buffer.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import random - -import torch -from tqdm import tqdm - -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer -from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe - -# TODO: the test depends on a lot of fields in the current implementation. 
-# We should have standard interface instead direct field access - - -def test_run(my_rank, buffer, device): - # buffer should be empty in the beginning - if my_rank == 0: - assert buffer.buffer_size == 0 - assert len(buffer.buffer) == 0 - - print(f"My rank: {my_rank}, device: {device}") - - # insert - tokens = torch.tensor([1, 2, 3]).to(device) - roi = tokens > 0 - if my_rank == 0: - key = 2.0 * torch.ones([5, 6]).to(device) - value = 3.0 * torch.ones([5, 6]).to(device) - - placeholder = torch.tensor([1]).to(device) - - buffer.insert(tokens, roi, key, value, placeholder) - - torch.distributed.barrier() - - # drop_select - if my_rank == 1: - tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) - assert torch.allclose(tokens, tok) - assert torch.allclose(roi, roi_) - assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device)) - assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device)) - torch.distributed.barrier() - - if my_rank == 0: - assert buffer.buffer_size == 0 - assert len(buffer.buffer) == 0 - - print(f"My rank: {my_rank}, Test run passed!") - - -def stress_test(my_rank, buf, device): - torch.distributed.barrier() - torch.manual_seed(100) - - reqs = [ - ( - torch.rand(100).to(device), # tokens - torch.ones(100).bool().to(device), # roi - torch.rand(100).to(device), # key - torch.rand(100).to(device), # value - torch.rand(100).to(device), # hidden - ) - for i in tqdm(range(200)) - ] - - random.seed(my_rank) - random.shuffle(reqs) - - torch.distributed.barrier() - - n = 0 - - # the buffer size can only store 100 reqs - # so the sender will occasionally block to wait for the receiver. - for req in tqdm(reqs): - if my_rank == 0: - buf.insert(*req) - else: - tok, roi, k, v, h = req - tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi) - - if tok_ is None: - assert roi_ is None - assert k_ is None - assert v_ is None - assert h_ is None - n += 1 - else: - assert torch.allclose(tok, tok_) - assert torch.allclose(roi, roi_) - assert torch.allclose(k, k_) - assert torch.allclose(v, v_) - assert torch.allclose(h, h_) - print(f"Rank {my_rank} done") - torch.distributed.barrier() - - if my_rank == 0: - x = torch.tensor([0]) - torch.distributed.recv(x, 1) - # the # of None received is the kv that are not selected - assert x.item() == len(buf.buffer) - # and the size of the buffer should be 2000 * buffer len - print(buf.buffer_size) - assert buf.buffer_size == 1700 * len(buf.buffer) - else: - torch.distributed.send(torch.tensor([n]), 0) - - print(f"My rank: {my_rank}, Passed stress test!") - - -if __name__ == "__main__": - my_rank = int(os.environ["RANK"]) - - torch.distributed.init_process_group( - backend="gloo", - init_method="tcp://localhost:12398", - world_size=2, - rank=my_rank, - ) - - print(f"initialized! 
My rank is {my_rank}") - - config = KVTransferConfig( - kv_connector="P2pNcclConnector", - kv_buffer_device="cuda", - kv_buffer_size=1e9, - kv_rank=my_rank, - kv_role="kv_both", # this arg doesn't matter in this test - kv_parallel_size=2, - kv_ip="127.0.0.1", - kv_port=12345, - ) - - data_pipe = PyNcclPipe( - local_rank=my_rank, - config=config, - device="cuda", - port_offset=0, - ) - cpu_pipe = PyNcclPipe( - local_rank=my_rank, - config=config, - device="cpu", - port_offset=1, - ) - - buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000) - - test_run(my_rank, buffer, data_pipe.device) - - stress_test(my_rank, buffer, data_pipe.device) - - buffer.close() - data_pipe.close() - cpu_pipe.close() - print("Done") diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh deleted file mode 100644 index f2aeaee9ca6d..000000000000 --- a/tests/kv_transfer/test_lookup_buffer.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -RANK=0 python3 test_lookup_buffer.py & -PID0=$! -RANK=1 python3 test_lookup_buffer.py & -PID1=$! - -wait $PID0 -wait $PID1 diff --git a/tests/kv_transfer/test_module.py b/tests/kv_transfer/test_module.py deleted file mode 100644 index b9a28e4bceb7..000000000000 --- a/tests/kv_transfer/test_module.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import subprocess -import sys - -import pytest -import torch - - -def run_python_script(script_name, timeout): - script_name = f"kv_transfer/{script_name}" - try: - # Start both processes asynchronously using Popen - process0 = subprocess.Popen( - [sys.executable, script_name], - env={"RANK": "0"}, # Set the RANK environment variable for process 0 - stdout=sys.stdout, # Pipe stdout to current stdout - stderr=sys.stderr, # Pipe stderr to current stderr - ) - - process1 = subprocess.Popen( - [sys.executable, script_name], - env={"RANK": "1"}, # Set the RANK environment variable for process 1 - stdout=sys.stdout, # Pipe stdout to current stdout - stderr=sys.stderr, # Pipe stderr to current stderr - ) - - # Wait for both processes to complete, with a timeout - process0.wait(timeout=timeout) - process1.wait(timeout=timeout) - - # Check the return status of both processes - if process0.returncode != 0: - pytest.fail(f"Test {script_name} failed for RANK=0, {process0.returncode}") - if process1.returncode != 0: - pytest.fail(f"Test {script_name} failed for RANK=1, {process1.returncode}") - - except subprocess.TimeoutExpired: - # If either process times out, terminate both and fail the test - process0.terminate() - process1.terminate() - pytest.fail(f"Test {script_name} timed out") - except Exception as e: - pytest.fail(f"Test {script_name} failed with error: {str(e)}") - - -# Define the test cases using pytest's parametrize -@pytest.mark.parametrize( - "script_name,timeout", - [ - ("test_lookup_buffer.py", 60), # Second test case with a 60-second timeout - ("test_send_recv.py", 120), # First test case with a 120-second timeout - ], -) -def test_run_python_script(script_name, timeout): - # Check the number of GPUs - if torch.cuda.device_count() < 2: - pytest.skip(f"Skipping test {script_name} because <2 GPUs are available") - - # Run the test if there are at least 2 GPUs - run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py deleted file mode 100644 index 5762224eff76..000000000000 --- a/tests/kv_transfer/test_send_recv.py +++ /dev/null @@ -1,154 +0,0 
@@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import time - -import torch -from tqdm import tqdm - -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe - - -def test_run(my_rank, pipe): - print(f"rank {my_rank} test_run starts....") - # test run - x = torch.tensor([1]).to(pipe.device) - y = torch.tensor([[2.0, 3.0, 4.0, 8.0]]).to(pipe.device) - if my_rank == 0: - pipe.send_tensor(x) - print(f"rank {my_rank} sent tensor x") - pipe.send_tensor(y) - print(f"rank {my_rank} sent tensor y") - x2 = pipe.recv_tensor() - print(f"rank {my_rank} received x2 = ", x2) - y2 = pipe.recv_tensor() - print(f"rank {my_rank} received y2 = ", y2) - - else: - x2 = pipe.recv_tensor() - print(f"rank {my_rank} received x2 = ", x2) - y2 = pipe.recv_tensor() - print(f"rank {my_rank} received y2 = ", y2) - pipe.send_tensor(x) - print(f"rank {my_rank} sent tensor x") - pipe.send_tensor(y) - print(f"rank {my_rank} sent tensor y") - - assert torch.allclose(x, x2) - assert torch.allclose(y, y2) - - print(f"rank {my_rank} test_run passed!") - - -def stress_test(my_rank, pipe): - print(f"rank {my_rank} stress_test starts....") - - tensors: list[torch.Tensor] = [] - - torch.distributed.barrier() - torch.manual_seed(0) - - for i in tqdm(range(500)): - mean = torch.rand(1).item() * 100 - std = torch.rand(1).item() * 100 - size = torch.randint(900, 1000, (2,)) - x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device) - - # 5% probability of sending a None - if torch.rand(1).item() < 0.05: - tensors.append(None) - tensors.append(None) - tensors.append(None) - else: - tensors.append(x) - tensors.append(x.mean().unsqueeze(0)) - tensors.append(x.std().unsqueeze(0)) - - torch.distributed.barrier() - - for i in tqdm(range(500)): - if my_rank == int((i % 10) > 3): - pipe.send_tensor(tensors[3 * i]) - pipe.send_tensor(tensors[3 * i + 1]) - pipe.send_tensor(tensors[3 * i + 2]) - else: - x = pipe.recv_tensor() - mean = pipe.recv_tensor() - std = pipe.recv_tensor() - - if x is None: - assert mean is None - assert std is None - else: - assert torch.allclose(x, tensors[3 * i]) - assert x.mean() == mean[0] - assert x.std() == std[0] - - torch.distributed.barrier() - - -def latency_test(my_rank, pipe, nelement, ntensor): - latencies = [] - - torch.distributed.barrier() - - for i in tqdm(range(500)): - tensors = [] - - if my_rank == 0: - # create tensor - tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] - - torch.distributed.barrier() - - if my_rank == 0: - t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device) - for tensor in tensors: - pipe.send_tensor(tensor) - pipe.send_tensor(t) - else: - for _ in range(ntensor): - pipe.recv_tensor() - t = pipe.recv_tensor() - latencies.append(time.time() - t.item()) - - torch.distributed.barrier() - - print("Latency test passed.") - print("Latency:", torch.tensor(latencies).mean().item() * 1000, "ms") - - -if __name__ == "__main__": - my_rank = int(os.environ["RANK"]) - - torch.distributed.init_process_group( - backend="gloo", - init_method="tcp://localhost:12398", - world_size=2, - rank=my_rank, - ) - - config = KVTransferConfig( - kv_connector="P2pNcclConnector", - kv_buffer_device="cuda", - kv_buffer_size=1e9, - kv_rank=my_rank, - kv_role="kv_both", # this arg doesn't matter in this test - kv_parallel_size=2, - kv_ip="127.0.0.1", - kv_port=12345, - ) - - pipe = PyNcclPipe( - local_rank=my_rank, - 
config=config, - ) - - test_run(my_rank, pipe) - - stress_test(my_rank, pipe) - - # Use this function if you want to test the latency of pipe impl. - # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh deleted file mode 100644 index 54e060480684..000000000000 --- a/tests/kv_transfer/test_send_recv.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -RANK=0 python3 test_send_recv.py & -PID0=$! -RANK=1 python3 test_send_recv.py & -PID1=$! - -wait $PID0 -wait $PID1 diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py deleted file mode 100644 index f48d03d0b0cd..000000000000 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This file contains a new class `KVLookupBufferBase` that allows developers to -think of KV cache operations as inserting new KV cache entries (`insert`) -into the lookup buffer and querying existing KV caches (`drop_select`) -from the lookup buffer. - -This file also contains a new class `KVStoreBufferBase` that allows developers -to manage the KVCache buffer as a simple key-value storage buffer with basic -put/get operations. - -These classes above are abstracted behind class `KVCacheBufferBase`. -""" - -from abc import ABC, abstractmethod - -import torch - - -class KVCacheBufferBase(ABC): - """ - Abstract base class for a KVCache buffer. - """ - - @abstractmethod - def close(self) -> None: - """Close the buffer and release resources. - - This method is responsible for cleaning up resources related to the - KVCache buffer when it is no longer needed. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - -class KVLookupBufferBase(KVCacheBufferBase): - """ - Abstract base class for a KVCache lookup buffer. - - This class provides an abstraction for a key-value (KV) cache lookup buffer. - - The key of the lookup buffer: - - input_tokens: token IDs of the request - - roi: a binary mask on top of input_tokens. - - Purpose of roi: Since KV cache may only be available for a subset of - tokens in the input (for example, when vLLM is connected to an external - KV cache service), roi specifies the subset of tokens that the KV cache - is associated with. - - NOTE: roi can be further extended to describe which part of KV the - current process is holding (each process may only hold a part of KV - due to TP and PP). This is not implemented for now. - - The value of the lookup buffer: - - key: the key tensor in the KV cache - - value: the value tensor in the KV cache - - hidden: the final hidden state generated by model forwarding. This allows - vLLM to bypass further model forwarding by transmitting the hidden state. - """ - - @abstractmethod - def insert( - self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - hidden: torch.Tensor, - ) -> None: - """Insert into the lookup buffer. 
- - The functionality is similar to the following python statement - ``` - buffer[input_tokens, roi] = [key, value, hidden] - ``` - - FIXME: in the future, we should only have two arguments, key and value, - where key is a tensor dict and value is a tensor dict. - - FIXME: we should transmit both sampler outputs and the hidden states. - - Args: - input_tokens (torch.Tensor): token IDs. - roi (torch.Tensor): A binary mask on top of the input tokens - key (torch.Tensor): The key tensor in the KV cache. - value (torch.Tensor): The value tensor in the KV cache. - hidden (torch.Tensor): The final hidden state tensor generated - during model forwarding to bypass model - forwarding. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def drop_select( - self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None - ) -> list[torch.Tensor | None]: - """Select and *drop* KV cache entries from the lookup buffer. - - The functionality is similar to the following python statements - ``` - ret = buffer.pop(input_tokens, roi) - return ret - ``` - - If `input_tokens` and `roi` is `None`, it means selecting any of the - KV caches in the buffer, return, and remove it from the buffer, useful - when offloading KV cache to KV cache storage service. - - Args: - input_tokens (torch.Tensor): token IDs. - roi (torch.Tensor): A binary mask on top of the input tokens - - Returns: - list[Optional[torch.Tensor]]: A list of tensors. Can be None. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - -class KVStoreBufferBase(KVCacheBufferBase): - """ - Abstract base class for a KVCache storage buffer with key-value semantics. - This class provides a simple key-value storage buffer abstract with basic - put/get operations, which enables flexible KVCache transfer granular - control. - - The functionality is similar to a distributed key-value store, where: - - Key: A unique string identifier for the cached entry - - Value: - - Tensor to be stored and retrieved - - None (indicating deletion or empty value) - """ - - @abstractmethod - def put( - self, - key: str, - value: torch.Tensor | None, - ) -> None: - """Store a key-value pair in the buffer. - - Args: - key (str): Unique identifier for a tensor, this tensor could be the - key cache tensor, value cache tensor, or hidden state tensor - generated during model forwarding. - - value (Optional[torch.Tensor]): Tensor to be stored. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def get( - self, - key: str, - ) -> torch.Tensor | None: - """Retrieve a value from the buffer by key. - - Args: - key (str): Unique identifier for a tensor, this tensor could be the - key cache tensor, value cache tensor, or hidden state tensor - generated during model forwarding. - - Returns: - Optional[torch.Tensor]: Stored tensor if exists, None otherwise. - - Raises: - NotImplementedError: This method must be implemented in subclasses. 
- """ - raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py deleted file mode 100644 index 7861bea1f9c5..000000000000 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +++ /dev/null @@ -1,164 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This file contains a new class `MooncakeStore` that allows developers to -think of KV cache transfer operations as putting new KV cache entries -into a remote KVStore-based lookup buffer and getting existing KV caches -from this remote lookup buffer. -""" - -import json -import os -from dataclasses import dataclass - -import torch -from safetensors.torch import load as safetensors_load -from safetensors.torch import save as safetensors_save - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase -from vllm.logger import init_logger - -DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB - -logger = init_logger(__name__) - - -@dataclass -class MooncakeStoreConfig: - local_hostname: str - metadata_server: str - global_segment_size: int - local_buffer_size: int - protocol: str - device_name: str - master_server_address: str - - @staticmethod - def from_file(file_path: str) -> "MooncakeStoreConfig": - """Load the config from a JSON file.""" - with open(file_path) as fin: - config = json.load(fin) - return MooncakeStoreConfig( - local_hostname=config.get("local_hostname"), - metadata_server=config.get("metadata_server"), - global_segment_size=config.get( - "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE - ), - local_buffer_size=config.get( - "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE - ), - protocol=config.get("protocol", "tcp"), - device_name=config.get("device_name", ""), - master_server_address=config.get("master_server_address"), - ) - - @staticmethod - def load_from_env() -> "MooncakeStoreConfig": - """Load config from a file specified in the environment variable.""" - config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") - if config_file_path is None: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." - ) - return MooncakeStoreConfig.from_file(config_file_path) - - -class MooncakeStore(KVStoreBufferBase): - def __init__( - self, - config: VllmConfig, - ): - try: - from mooncake.store import MooncakeDistributedStore - except ImportError as e: - raise ImportError( - "Please install mooncake by following the instructions at " - "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector." - ) from e - - try: - self.store = MooncakeDistributedStore() - self.config = MooncakeStoreConfig.load_from_env() - logger.info("Mooncake Configuration loaded successfully.") - - self.store.setup( - self.config.local_hostname, - self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address, - ) - - except ValueError as e: - logger.error("Configuration loading failed: %s", e) - raise - except Exception as exc: - logger.error("An error occurred while loading the configuration: %s", exc) - raise - - def close(self): - # MooncakeDistributedStore will automatically call the destructor, so - # it is unnecessary to close it manually. 
- pass - - def put( - self, - key: str, - value: torch.Tensor | None, - ) -> None: - # A message queue needs to be introduced before making it asynchronous. - if value is not None: - self._put_impl(key, value) - - def get( - self, - key: str, - ) -> torch.Tensor | None: - # A message queue needs to be introduced before making it asynchronous. - value = self._get_impl(key) - return value - - def _put_impl( - self, - key: str, - value: torch.Tensor, - ) -> None: - """Put KVCache to Mooncake Store""" - device_id = value.device.index if value.device.type == "cuda" else -1 - device_tensor = torch.tensor(device_id, dtype=torch.int32) - value_bytes = safetensors_save({"tensor": value, "device_id": device_tensor}) - try: - self.store.put(key, value_bytes) - except TypeError as err: - logger.error("Failed to put value into Mooncake Store: %s", err) - raise TypeError("Mooncake Store Put Type Error.") from err - - def _get_impl( - self, - key: str, - ) -> torch.Tensor | None: - """Get KVCache from Mooncake Store""" - try: - data = self.store.get(key) - except TypeError as err: - logger.error("Failed to get value from Mooncake Store: %s", err) - raise TypeError("Mooncake Store Get Type Error.") from err - - if data: - loaded_tensors = safetensors_load(data) - tensor = loaded_tensors["tensor"] - device_id_tensor = loaded_tensors["device_id"] - device_id = int(device_id_tensor.item()) - device = ( - torch.device("cuda", device_id) - if device_id >= 0 - else torch.device("cpu") - ) - return tensor.to(device) - - return None diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py deleted file mode 100644 index f046a349874e..000000000000 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ /dev/null @@ -1,242 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Implements a distributed key-value (KV) cache transfer mechanism. - -Key Features: -- Distributed KV cache transmission using PyNccl pipes. -- Non-blocking `insert`, blocking `drop_select`. -- Use CPU signal pipe to avoid racing condition -- Handles buffer size constraints and provide backpressure mechanism to - stop the prefill instance when the decode instance is slow. -""" - -import threading -from collections import deque - -import torch - -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class SimpleBuffer(KVLookupBufferBase): - def __init__( - self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float - ): - """ - signal_pipe: on CPU - - NOTE: on-device recv will block all threads in the process, making the - KV cache producer unable to listen to new request while transmitting - KV cache. Luckily CPU recv only blocks the current thread so we use - CPU recv to listen to new request. - - data_pipe: on device (e.g. 
GPU) - """ - - self.buffer: deque[list[torch.Tensor]] = deque() - - self.buffer_size = 0 - self.buffer_size_threshold = buffer_size_thresh - self.buffer_cv = threading.Condition() - self.signal_pipe = signal_pipe - self.data_pipe = data_pipe - self.request_handling_thread: threading.Thread | None = None - - self.normal_signal = torch.tensor([0], device="cpu") - self.end_signal = None - - def _matches( - self, - tokens_roi_sender: list[torch.Tensor], - tokens_roi_recver: list[torch.Tensor], - ): - # tokens_roi_sender: tokens and roi of the producer (in the buffer) - # tokens_roi_recver: tokens and roi of the consumer (query) - - tokens_sender = tokens_roi_sender[0] - tokens_recver = tokens_roi_recver[0] - roi_sender = tokens_roi_sender[1] - roi_recver = tokens_roi_recver[1] - - if tokens_recver is None: - # consumer sends an empty request - # semantics: DROP SELECT * LIMIT 1 - # so any of the data in the buffer can be drop-selected - return True - - # Assuming that roi is a binary mask on tokens - tokens_sender = tokens_sender[roi_sender] - tokens_recver = tokens_recver[roi_recver] - - # simple common prefix matching - min_length = min(len(tokens_sender), len(tokens_recver)) - if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]): - return min_length - - return 0 - - def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None: - assert tensor is not None, "Use self.data_pipe.send(None) instead" - self.buffer_size -= tensor.element_size() * tensor.numel() - if tensor.dtype == torch.bool: - tensor = tensor.float() - self.data_pipe.send_tensor(tensor) - - def _get_element_size(self, data: list | torch.Tensor | None): - if isinstance(data, torch.Tensor): - return data.element_size() * data.numel() - if not data: - # cannot perform `not data` on a tensor - # so this check needs to go after the check above - return 0 - - raise AssertionError(f"Unknown data type {type(data)}") - - def _add_to_buffer( - self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - hidden: torch.Tensor, - ): - if isinstance(input_tokens, torch.Tensor): - input_tokens = input_tokens.clone() - if isinstance(roi, torch.Tensor): - roi = roi.clone() - if isinstance(key, torch.Tensor): - key = key.clone() - if isinstance(value, torch.Tensor): - value = value.clone() - if isinstance(hidden, torch.Tensor): - hidden = hidden.clone() - - buffer_item = [input_tokens, roi, key, value, hidden] - data_size = sum([self._get_element_size(data) for data in buffer_item]) - - with self.buffer_cv: - if self.buffer_size + data_size > self.buffer_size_threshold: - # log outside the while loop to avoid this message being logged - # repeatedly. - logger.debug("KV transfer buffer is full. 
Handling...") - while self.buffer_size + data_size > self.buffer_size_threshold: - self.buffer_cv.wait() - - self.buffer_size += data_size - self.buffer.append(buffer_item) - self.buffer_cv.notify() - - def _is_end_signal(self, signal): - return signal is None - - def drop_select_handler(self): - try: - while True: - signal = self.signal_pipe.recv_tensor() - if self._is_end_signal(signal): - logger.info("Received end signal!") - break - - input_tokens = self.data_pipe.recv_tensor() - - roi = self.data_pipe.recv_tensor() - assert roi is not None, ( - "Please provide the roi when sending drop-select request" - ) - roi = roi > 0.5 - tokens_roi_recver = [input_tokens, roi] - - def is_buffer_available( - tokens_roi_recver: list[torch.Tensor], - ) -> bool: - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - for _ in range(len(self.buffer)): - if self._matches(self.buffer[0], tokens_roi_recver) > 0: - return True - # rotate the element we just accessed to the end - self.buffer.rotate(-1) - return False - - with self.buffer_cv: - while not is_buffer_available(tokens_roi_recver): - logger.debug("KV transfer buffer is not available. Waiting...") - self.buffer_cv.wait() - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - self.buffer_cv.notify() - - except RuntimeError as e: - if "Connection closed by peer" not in str(e): - raise e - - logger.debug("Closing drop_select_handler") - - def drop_select( - self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None - ) -> list[torch.Tensor | None]: - assert self.request_handling_thread is None, ( - "drop_select should be called by the KV cache consumer " - "(e.g. the decode vLLM instance)" - ) - - if isinstance(input_tokens, torch.Tensor): - input_tokens = input_tokens.clone() - if isinstance(roi, torch.Tensor): - roi = roi.clone().float() - - self.signal_pipe.send_tensor(self.normal_signal) - self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi) - - input_tokens = self.data_pipe.recv_tensor() - roi = self.data_pipe.recv_tensor() - if roi is not None: - # convert from float tensor to bool tensor - # as PyNccl does not support sending bool tensor - roi = roi > 0.5 - key = self.data_pipe.recv_tensor() - value = self.data_pipe.recv_tensor() - hidden = self.data_pipe.recv_tensor() - - return [input_tokens, roi, key, value, hidden] - - def insert( - self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - hidden: torch.Tensor, - ) -> None: - self._add_to_buffer(input_tokens, roi, key, value, hidden) - - # when calling the insert, the current process is a sender - # need to launch the request handler and start listening to request. 
- if self.request_handling_thread is None: - self.request_handling_thread = threading.Thread( - target=self.drop_select_handler - ) - self.request_handling_thread.start() - - def close(self): - if ( - hasattr(self, "request_handling_thread") - and self.request_handling_thread is not None - ): - self.request_handling_thread.join() - - else: - # TODO: have a explicit close signal and have a explicit way to - # check if it's requester - self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/vllm/distributed/kv_transfer/kv_pipe/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py deleted file mode 100644 index 1fe7a90e9a71..000000000000 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This file defines an interface `KVPipeBase` -that provides an abstraction for sending and receiving tensors, or None, via -distributed communications. - -All classes instantiated from this interface are assumed to be a FIFO pipe. - -If your distributed communication platform already supports key-value lookup, -you can bypass this interface and directly start from `kv_lookup_buffer`. -""" - -from abc import ABC, abstractmethod - -import torch - - -class KVPipeBase(ABC): - """ - This class provides an interface for sending and receiving tensors, or - None, by distributed communications. - """ - - @abstractmethod - def send_tensor(self, tensor: torch.Tensor | None) -> None: - """Send a tensor, or None, via the pipe. - - Need to support sending None -- important for error handling. - - TODO: add a `key` argument so that we can use traditional - key-value database as the distributed communication mechanism behind - the pipe. - - Args: - tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def recv_tensor(self) -> torch.Tensor | None: - """Receive a tensor (can be None) from the pipeline. - - Returns: - Optional[torch.Tensor]: The tensor received from the pipeline. Can - be None. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def close(self) -> None: - """Close the pipeline and release resources. - - This method is responsible for closing the communication pipeline - and releasing any resources associated with it. - - Raises: - NotImplementedError: This method must be implemented in subclasses. 
- """ - raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py deleted file mode 100644 index 542dde09abad..000000000000 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ /dev/null @@ -1,295 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import json -import os -import struct -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass - -import torch -import zmq -from safetensors.torch import load as safetensors_load -from safetensors.torch import save as safetensors_save - -from vllm.config.kv_transfer import KVTransferConfig -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from vllm.logger import init_logger -from vllm.utils.network_utils import join_host_port, make_zmq_path, split_host_port - -logger = init_logger(__name__) -NONE_INT = -150886311 - - -@dataclass -class MooncakeTransferEngineConfig: - prefill_url: str - decode_url: str - metadata_backend: str | None - metadata_server: str - protocol: str - device_name: str - - @staticmethod - def from_file(file_path: str) -> "MooncakeTransferEngineConfig": - """Load the config from a JSON file.""" - with open(file_path) as fin: - config = json.load(fin) - return MooncakeTransferEngineConfig( - prefill_url=config.get("prefill_url"), - decode_url=config.get("decode_url"), - metadata_backend=config.get("metadata_backend", None), - metadata_server=config.get("metadata_server"), - protocol=config.get("protocol", "tcp"), - device_name=config.get("device_name", ""), - ) - - @staticmethod - def load_from_env() -> "MooncakeTransferEngineConfig": - """Load config from a file specified in the environment variable.""" - config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") - if config_file_path is None: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." - ) - return MooncakeTransferEngineConfig.from_file(config_file_path) - - -class MooncakeTransferEngine: - """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ.""" - - def __init__(self, kv_rank: int, local_rank: int): - try: - from mooncake.engine import TransferEngine - except ImportError as e: - raise ImportError( - "Please install mooncake by following the instructions at " - "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector." 
- ) from e - - self.engine = TransferEngine() - self.local_rank = local_rank - - try: - self.config = MooncakeTransferEngineConfig.load_from_env() - logger.info("Mooncake Configuration loaded successfully.") - except ValueError as e: - logger.error(e) - raise - except Exception as exc: - logger.error("An error occurred while loading the configuration: %s", exc) - raise - prefill_host, base_prefill_port = split_host_port(self.config.prefill_url) - decode_host, base_decode_port = split_host_port(self.config.decode_url) - - # Avoid ports conflict when running prefill and decode on the same node - if prefill_host == decode_host and base_prefill_port == base_decode_port: - base_decode_port = base_decode_port + 100 - - prefill_port = base_prefill_port + self.local_rank - decode_port = base_decode_port + self.local_rank - self.prefill_url = join_host_port(prefill_host, prefill_port) - self.decode_url = join_host_port(decode_host, decode_port) - - self.initialize( - self.prefill_url if kv_rank == 0 else self.decode_url, - self.config.metadata_server, - self.config.protocol, - self.config.device_name, - self.config.metadata_backend, - ) - - self.remote_url = self.decode_url if kv_rank == 0 else self.prefill_url - - # Initialize ZeroMQ context and sockets - self.context = zmq.Context() # type: ignore[attr-defined] - self.sender_socket = self.context.socket(zmq.constants.PUSH) - self.receiver_socket = self.context.socket(zmq.constants.PULL) - self.sender_ack = self.context.socket(zmq.constants.PULL) - self.receiver_ack = self.context.socket(zmq.constants.PUSH) - - self.buffer_cleaner = ThreadPoolExecutor(max_workers=1) - self._setup_metadata_sockets( - kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port - ) - - def _setup_metadata_sockets( - self, kv_rank: int, p_host: str, p_port: int, d_host: str, d_port: int - ) -> None: - """Set up ZeroMQ sockets for sending and receiving data.""" - # Offsets < 8 are left for initialization in case tp and pp are enabled - p_rank_offset = p_port + 8 + self.local_rank * 2 - d_rank_offset = d_port + 8 + self.local_rank * 2 - if kv_rank == 0: - self.sender_socket.bind(make_zmq_path("tcp", p_host, p_rank_offset + 1)) - self.receiver_socket.connect( - make_zmq_path("tcp", d_host, d_rank_offset + 1) - ) - self.sender_ack.connect(make_zmq_path("tcp", d_host, d_rank_offset + 2)) - self.receiver_ack.bind(make_zmq_path("tcp", p_host, p_rank_offset + 2)) - else: - self.receiver_socket.connect( - make_zmq_path("tcp", p_host, p_rank_offset + 1) - ) - self.sender_socket.bind(make_zmq_path("tcp", d_host, d_rank_offset + 1)) - self.receiver_ack.bind(make_zmq_path("tcp", d_host, d_rank_offset + 2)) - self.sender_ack.connect(make_zmq_path("tcp", p_host, p_rank_offset + 2)) - - def initialize( - self, - local_hostname: str, - metadata_server: str, - protocol: str, - device_name: str, - metadata_backend: str | None, - ) -> None: - """Initialize the mooncake instance.""" - if metadata_backend is None: - self.engine.initialize( - local_hostname, metadata_server, protocol, device_name - ) - else: - supported_backend = ["etcd", "redis"] - metadata_backend = metadata_backend.lower() - if metadata_backend not in supported_backend: - raise ValueError( - "Mooncake Configuration error. `metadata_backend`" - f" should be one of {supported_backend}." 
- ) - - self.engine.initialize_ext( - local_hostname, metadata_server, protocol, device_name, metadata_backend - ) - - def allocate_managed_buffer(self, length: int) -> int: - """Allocate a managed buffer of the specified length.""" - ret = self.engine.allocate_managed_buffer(length) - if ret <= 0: - logger.error("Allocation Return Error") - raise Exception("Allocation Return Error") - return ret - - def free_managed_buffer(self, buffer: int, length: int) -> int: - """Free a previously allocated managed buffer.""" - return self.engine.free_managed_buffer(buffer, length) - - def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int: - """Synchronously transfer data to the specified address.""" - ret = self.engine.transfer_sync_read( - self.remote_url, buffer, peer_buffer_address, length - ) - if ret < 0: - logger.error("Transfer Return Error") - raise Exception("Transfer Return Error") - return ret - - def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int: - """Write bytes to the allocated buffer.""" - return self.engine.write_bytes_to_buffer(buffer, user_data, length) - - def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: - """Read bytes from the allocated buffer.""" - return self.engine.read_bytes_from_buffer(buffer, length) - - def wait_for_ack(self, src_ptr: int, length: int) -> None: - """Asynchronously wait for ACK from the receiver.""" - ack = self.sender_ack.recv() - if ack != b"ACK": - logger.error("Failed to receive ACK from the receiver") - - self.free_managed_buffer(src_ptr, length) - - def send_bytes(self, user_data: bytes) -> None: - """Send bytes to the remote process.""" - length = len(user_data) - src_ptr = self.allocate_managed_buffer(length) - self.write_bytes_to_buffer(src_ptr, user_data, length) - self.sender_socket.send_multipart( - [struct.pack("!Q", src_ptr), struct.pack("!Q", length)] - ) - self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) - - def recv_bytes(self) -> bytes: - """Receive bytes from the remote process.""" - data = self.receiver_socket.recv_multipart() - src_ptr = struct.unpack("!Q", data[0])[0] - length = struct.unpack("!Q", data[1])[0] - dst_ptr = self.allocate_managed_buffer(length) - self.transfer_sync(dst_ptr, src_ptr, length) - ret = self.read_bytes_from_buffer(dst_ptr, length) - - # Buffer cleanup - self.receiver_ack.send(b"ACK") - self.free_managed_buffer(dst_ptr, length) - - return ret - - -class MooncakePipe(KVPipeBase): - """MooncakeTransferEngine based Pipe implementation.""" - - def __init__( - self, local_rank: int, config: KVTransferConfig, device: str | None = None - ): - """Initialize the mooncake pipe and set related parameters.""" - self.config = config - self.local_rank = local_rank - self.kv_rank = self.config.kv_rank - assert self.kv_rank is not None - if device is None: - self.device = self._select_device(self.config.kv_buffer_device) - else: - self.device = self._select_device(device) - - self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank) - self.transport_thread: ThreadPoolExecutor | None = None - self.none_tensor = torch.tensor([NONE_INT], device=self.device) - - def _select_device(self, device: str) -> torch.device: - """Select available device (CUDA or CPU).""" - logger.info("Selecting device: %s", device) - if device == "cuda": - return torch.device(f"cuda:{self.local_rank}") - else: - return torch.device("cpu") - - def tensor_hash(self, tensor: torch.Tensor) -> int: - """Calculate the hash value of the tensor.""" - 
return hash(tensor.data_ptr()) - - def _send_impl(self, tensor: torch.Tensor) -> None: - """Implement the tensor sending logic using safetensors.""" - self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor})) - - def _recv_impl(self) -> torch.Tensor: - """Implement the tensor receiving logic using safetensors.""" - data = self.transfer_engine.recv_bytes() - return safetensors_load(data)["tensor"].to(self.device) - - def send_tensor(self, tensor: torch.Tensor | None) -> None: - """Send tensor to the target process.""" - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - tensor = tensor if tensor is not None else self.none_tensor - assert len(tensor.shape) > 0 - self.transport_thread.submit(self._send_impl, tensor) - - def recv_tensor(self) -> torch.Tensor | None: - """Receive tensor from other processes.""" - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - tensor = self.transport_thread.submit(self._recv_impl).result() - if tensor.numel() == 1 and tensor.item() == NONE_INT: - return None - else: - return tensor - - def close(self) -> None: - """Cleanup logic when closing the pipe.""" - self.transfer_engine.sender_socket.close() - self.transfer_engine.receiver_socket.close() - self.transfer_engine.sender_ack.close() - self.transfer_engine.receiver_ack.close() - self.transfer_engine.context.term() # Terminate the ZMQ context - logger.info("Closed the transfer engine and cleaned up resources.") diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py deleted file mode 100644 index 526c5cd1d527..000000000000 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ /dev/null @@ -1,285 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This module implements a PyNccl pipe for sending and receiving -Optional[torch.Tensor] between distributed ranks with advanced -communication features. 
- -Key Features: -- Supports sending and receiving tensors with metadata -- Handles both CUDA and CPU device communications -- Implements a non-blocking tensor transfer mechanism -- Manages buffer size and provides backpressure control -- Supports distributed process groups with configurable parameters -""" - -import threading -import time -from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor - -import torch - -from vllm.config.kv_transfer import KVTransferConfig -from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from vllm.distributed.utils import StatelessProcessGroup -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class BrokenPipeException(Exception): - def __init__(self, message): - self.message = message - super().__init__(self.message) - - -Metadata = dict[str, torch.Tensor | None] - - -class PyNcclPipe(KVPipeBase): - METADATA_LENGTH = 16 - MAX_TENSOR_DIMENSIONS = 14 - METADATA_DTYPE = torch.int64 - - def __init__( - self, - local_rank: int, - config: KVTransferConfig, - device: str | None = None, - port_offset: int = 0, - ): - self.config = config - self.local_rank = local_rank - self.kv_rank = self.config.kv_rank - assert self.kv_rank is not None - self.kv_parallel_size = self.config.kv_parallel_size - if device is None: - self.device = self._select_device(self.config.kv_buffer_device) - else: - self.device = self._select_device(device) - - # build distributed connection and send/recv implementation - store_timeout = self.config.get_from_extra_config("store_timeout", 300) - self.group = StatelessProcessGroup.create( - host=self.config.kv_ip, - port=self.config.kv_port + port_offset, - rank=self.kv_rank, - world_size=self.kv_parallel_size, - store_timeout=store_timeout, - ) - # add a barrier to make sure the connection is initiated properly - self.group.barrier() - impl = self._get_device_send_recv_impl(self.group) - self.device_send_func, self.device_recv_func = impl - # set target rank - self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size - self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size - - # transportation-related variables - self.transport_thread: ThreadPoolExecutor | None = None - self.buffer_size = 0 - self.buffer_size_lock = threading.Lock() - self.buffer_size_thresh = self.config.kv_buffer_size - - def _get_device_send_recv_impl( - self, group: StatelessProcessGroup - ) -> tuple[ - Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None] - ]: - send: Callable[[torch.Tensor, int], None] - recv: Callable[[torch.Tensor, int], None] - if self.device.type == "cuda": - # use PyNCCL for send / recv - comm = PyNcclCommunicator(group, device=self.local_rank) - comm.disabled = False - send, recv = comm.send, comm.recv # type: ignore - else: - # This send / recv implementation here is NOT intended to transfer - # KV caches (and should NOT be repurposed to transfer KV caches). - # Currently it is only used to transmit control-plane messages - # for PyNcclBuffer. - send = group.send_obj - - def my_recv(x, src): - x[...] 
= group.recv_obj(src) - - recv = my_recv - - return send, recv - - def _select_device(self, device: str): - logger.info("Selecting device: %s", device) - if device == "cuda": - return torch.device(f"cuda:{self.local_rank}") - else: - return torch.device("cpu") - - def _make_metadata(self, tensor: torch.Tensor | None) -> Metadata: - """ - Create the metadata as a dictionary based on the input tensor. - - Args: - tensor: The input tensor or None if no tensor is provided. - - Returns: - metadata: A dictionary with the following keys: - - "dtype": The data type of the tensor or None. - - "shape": The shape of the tensor or None. - """ - if tensor is None: - return {"dtype": None, "shape": None} - else: - return {"dtype": tensor.dtype, "shape": tensor.shape} - - def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: - """ - Create a buffer to receive the tensor based on the provided metadata. - - Args: - metadata: A dictionary with keys "dtype" and "shape", - describing the tensor's data type and shape. - - Returns: - buffer: A tensor of the specified type and shape, - allocated on `self.device`. - """ - return torch.empty( - metadata["shape"], dtype=metadata["dtype"], device=self.device - ) - - def _send_metadata(self, metadata: Metadata): - """ - Send the metadata dictionary to the target rank. - - Args: - metadata: A dictionary with keys "dtype" and "shape". - """ - self.group.send_obj(metadata, self.target_rank_for_send) - - def _recv_metadata(self) -> Metadata: - """ - Receive the metadata dictionary from the target rank. - - Returns: - metadata: A dictionary with keys "dtype" and "shape" - describing the tensor. - """ - return self.group.recv_obj(self.target_rank_for_recv) - - def _send_impl(self, tensor: torch.Tensor | None) -> None: - """ - The actual implementation of sending the tensor and its metadata to the - target rank. - - Args: - tensor: The input tensor to be sent, or `None` if no tensor is - being sent. - """ - metadata = self._make_metadata(tensor) - self._send_metadata(metadata) - if tensor is not None: - self.device_send_func(tensor.to(self.device), self.target_rank_for_send) - - def _recv_impl(self) -> torch.Tensor | None: - """ - The actual implementation of receiving a tensor and its metadata from - the target rank. - - Returns: - buffer: The received tensor, or `None` if no tensor is received. - """ - metadata = self._recv_metadata() - if metadata["dtype"] is None: - return None - buffer = self._prepare_recv_buffer(metadata) - self.device_recv_func(buffer, self.target_rank_for_recv) - - return buffer - - def send_tensor_wrapper( - self, tensor: torch.Tensor | None, tensor_size: int - ) -> None: - """ - Wrapper for _send_impl to handle exceptions and update buffer size. - """ - try: - self._send_impl(tensor) - - with self.buffer_size_lock: - self.buffer_size -= tensor_size - except Exception as e: - logger.error( - "[rank%d]: Exception when trying to send %s, msg: %s", - torch.distributed.get_rank(), - str(tensor), - str(e), - ) - import traceback - - traceback.print_exc() - - def block_if_full(self): - """ - Block the current thread if the buffer size is larger than the - threshold. - """ - while self.buffer_size > self.buffer_size_thresh: - logger.debug("KV cache transfer pipe is full. Waiting...") - time.sleep(0.05) - - def send_tensor(self, tensor: torch.Tensor | None) -> None: - """ - Sends a tensor and its metadata to the destination rank in a - non-blocking way. - - Args: - tensor: The tensor to send, or `None` if no tensor is being sent. 
- """ - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - if tensor is not None: - tensor_size = tensor.element_size() * tensor.numel() - else: - tensor_size = 0 - - self.block_if_full() - - with self.buffer_size_lock: - self.buffer_size += tensor_size - - self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size) - - def recv_tensor(self) -> torch.Tensor | None: - """ - Receives a tensor and its metadata from the source rank. Blocking call. - - Returns: - The received tensor, or `None` if no tensor is received. - """ - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - future = self.transport_thread.submit(self._recv_impl) - - try: - tensor = future.result() - except Exception as e: - logger.error("Encountering exception in KV receiving thread") - logger.error("%s", e) - logger.error("My device: %s", self.device) - import traceback - - traceback.print_exc() - raise e - - return tensor - - def close(self): - """ - Close the pipe and release associated resources. - """ - if hasattr(self, "transport_thread") and self.transport_thread is not None: - self.transport_thread.shutdown()
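
For reference, the essential behavior removed by this patch reduces to a few small patterns. The first is the lookup-buffer contract: `insert` appends a `[tokens, roi, key, value, hidden]` entry and blocks the producer on a condition variable once the buffer exceeds its size threshold (backpressure), while `drop_select` pops the first entry whose token prefix matches the query. Below is a minimal single-process sketch of those semantics; `TinyLookupBuffer` is a hypothetical name, and the PyNccl transport, CPU signal pipe, and handler thread of the deleted `SimpleBuffer` are deliberately omitted.

```
import threading
from collections import deque

import torch


class TinyLookupBuffer:
    """Hypothetical, single-process stand-in for the removed SimpleBuffer."""

    def __init__(self, buffer_size_thresh: int):
        self.buffer: deque[list[torch.Tensor]] = deque()
        self.buffer_size = 0
        self.buffer_size_thresh = buffer_size_thresh
        self.cv = threading.Condition()

    @staticmethod
    def _nbytes(t: torch.Tensor) -> int:
        return t.element_size() * t.numel()

    def insert(self, tokens, roi, key, value, hidden) -> None:
        item = [tokens, roi, key, value, hidden]
        size = sum(self._nbytes(t) for t in item)
        with self.cv:
            # Backpressure: block the producer while the buffer is full,
            # mirroring the threshold check in SimpleBuffer._add_to_buffer.
            while self.buffer_size + size > self.buffer_size_thresh:
                self.cv.wait()
            self.buffer.append(item)
            self.buffer_size += size
            self.cv.notify_all()

    def drop_select(self, tokens, roi) -> list[torch.Tensor]:
        with self.cv:
            while True:
                # O(n) prefix match over buffered entries, as in the
                # removed implementation's _matches helper.
                for i in range(len(self.buffer)):
                    tok, roi_ = self.buffer[i][0], self.buffer[i][1]
                    n = min(len(tok[roi_]), len(tokens[roi]))
                    if torch.equal(tok[roi_][:n], tokens[roi][:n]):
                        item = self.buffer[i]
                        del self.buffer[i]
                        self.buffer_size -= sum(self._nbytes(t) for t in item)
                        self.cv.notify_all()
                        return item
                self.cv.wait()


buf = TinyLookupBuffer(buffer_size_thresh=1 << 20)
tokens = torch.tensor([1, 2, 3])
roi = tokens > 0
buf.insert(tokens, roi, 2 * torch.ones(5, 6), 3 * torch.ones(5, 6), torch.ones(5))
_, _, key, value, _ = buf.drop_select(tokens, roi)
assert torch.allclose(value, 3 * torch.ones(5, 6))
```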
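The second pattern is the pipe's framing: every transfer is metadata (dtype and shape) followed by the payload, and `None` is encoded as metadata with a null dtype so that "no KV available" can be signaled in-band — the property the deleted `test_send_recv.py` stress test relies on. A sketch under the same caveats, with a `queue.Queue` (a stand-in introduced here) replacing the NCCL/gloo transport of the removed `PyNcclPipe`:

```
import queue

import torch

wire: queue.Queue = queue.Queue()  # stands in for the NCCL / gloo channel


def send_tensor(tensor: torch.Tensor | None) -> None:
    if tensor is None:
        # None is encoded purely as metadata, with no payload frame.
        wire.put({"dtype": None, "shape": None})
        return
    # Metadata first, so the receiver can allocate the buffer up front.
    wire.put({"dtype": tensor.dtype, "shape": tensor.shape})
    wire.put(tensor)


def recv_tensor() -> torch.Tensor | None:
    metadata = wire.get()
    if metadata["dtype"] is None:
        return None
    buffer = torch.empty(metadata["shape"], dtype=metadata["dtype"])
    buffer.copy_(wire.get())
    return buffer


send_tensor(torch.arange(4, dtype=torch.float32))
send_tensor(None)
assert torch.equal(recv_tensor(), torch.arange(4, dtype=torch.float32))
assert recv_tensor() is None
```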
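The third is the Mooncake store's serialization scheme: each value is packed together with its CUDA device index into a single safetensors blob so that `get` can restore the tensor onto the device it came from. A sketch with a plain dict (an assumption of this example) standing in for the remote `MooncakeDistributedStore`:

```
import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

store: dict[str, bytes] = {}  # stands in for MooncakeDistributedStore


def put(key: str, value: torch.Tensor) -> None:
    # Pack the tensor together with its device index (-1 means CPU), as
    # the removed MooncakeStore._put_impl did.
    device_id = value.device.index if value.device.type == "cuda" else -1
    store[key] = safetensors_save(
        {"tensor": value, "device_id": torch.tensor(device_id, dtype=torch.int32)}
    )


def get(key: str) -> torch.Tensor | None:
    data = store.get(key)
    if not data:
        return None
    loaded = safetensors_load(data)
    device_id = int(loaded["device_id"].item())
    device = torch.device("cuda", device_id) if device_id >= 0 else torch.device("cpu")
    return loaded["tensor"].to(device)


put("layer0.key", torch.ones(2, 2))
assert torch.equal(get("layer0.key"), torch.ones(2, 2))
```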