From 4e212956e3a2f62eb94bedda73130614fb9eeda0 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Tue, 18 Nov 2025 05:51:11 +0000 Subject: [PATCH 01/19] separate cpu_backend; scheduler has a offload_manager who makes all the decisions; worker has a regular kv store Signed-off-by: Juncheng Gu --- .../distributed/cpu_chunk_manager.py | 223 ++++++++ .../distributed/local_cpu_backend.py | 154 ++---- .../distributed/tpu_connector_local.py | 480 ++++++++++-------- 3 files changed, 521 insertions(+), 336 deletions(-) create mode 100644 tpu_inference/distributed/cpu_chunk_manager.py diff --git a/tpu_inference/distributed/cpu_chunk_manager.py b/tpu_inference/distributed/cpu_chunk_manager.py new file mode 100644 index 000000000..ad36d20c8 --- /dev/null +++ b/tpu_inference/distributed/cpu_chunk_manager.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import OrderedDict +from dataclasses import dataclass +from typing import Literal, Tuple + +from vllm.v1.core.kv_cache_utils import BlockHash + +from tpu_inference.logger import init_logger + +logger = init_logger(__name__) + +GB = 1024**3 +DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB + +ChunkHash = BlockHash + + +@dataclass +class CPUChunk: + chunk_id: int + ref_cnt: int = -1 + _chunk_hash: ChunkHash | None = None + + @property + def is_ready_to_load(self): + return self.ref_cnt >= 0 + + @property + def is_ready_to_evict(self): + return self.ref_cnt <= 0 + + @property + def is_in_use(self): + return self.ref_cnt >= 1 + + @property + def chunk_hash(self): + return self._chunk_hash + + def touch(self): + self.ref_cnt += 1 + + def untouch(self): + self.ref_cnt -= 1 + + def reset(self): + self._chunk_hash = None + self.ref_cnt = -1 + + +class CPUChunkPool: + + def __init__(self, num_chunks: int): + self.num_chunks: int = num_chunks + self._num_allocated_chunks: int = 0 + self.free_chunk_list: list[CPUChunk] = [ + CPUChunk(idx) for idx in range(num_chunks - 1, -1, -1) + ] + # {allocated_chunk_id: chunk_hash} + self.allocated_id_to_hash_map: dict[int, ChunkHash] = {} + + @property + def num_free_chunks(self): + return self.num_chunks - self._num_allocated_chunks + + @property + def num_allocated_chunks(self): + return self._num_allocated_chunks + + def allocate_chunks(self, chunk_hashes: list[ChunkHash]) -> list[CPUChunk]: + num_required_chunks = len(chunk_hashes) + if num_required_chunks > self.num_free_chunks: + raise ValueError( + f"Cannot get {num_required_chunks} free chunks from the pool") + + ret: list[CPUChunk] = [ + self.free_chunk_list.pop() for _ in range(num_required_chunks) + ] + for chunk, chunk_hash in zip(ret, chunk_hashes): + chunk._chunk_hash = chunk_hash + assert chunk.chunk_id not in self.allocated_id_to_hash_map + self.allocated_id_to_hash_map[chunk.chunk_id] = chunk_hash + + return ret + + def release_chunks(self, chunks: list[CPUChunk]): + for chunk in chunks: + if not chunk.is_ready_to_evict: + logger.warning(f" Chunk[{chunk.chunk_id}] is still in use.") + assert chunk.chunk_id in self.allocated_id_to_hash_map + self.allocated_id_to_hash_map.pop(chunk.chunk_id) + self.free_chunk_list.append(chunk) + chunk.reset() + self._num_allocated_chunks -= len(chunks) + + +class LRUOffloadingManager: + + def __init__(self, num_cpu_chunks: int): + self.num_chunks = num_cpu_chunks + self.chunk_pool = CPUChunkPool(self.num_chunks) + + self.cpu_cache: OrderedDict[ChunkHash, CPUChunk] = OrderedDict() + + # The cache is an OrderedDict for LRU behavior. 
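For reference, a minimal sketch (editorial illustration, not part of the patch) of the ref-count lifecycle that the `CPUChunk` properties above encode: a chunk sits at -1 while its save is in flight, reaches 0 once the save completes, and is held at >= 1 while loads are reading it.

```python
# Illustrative only: exercising the CPUChunk dataclass defined above.
chunk = CPUChunk(chunk_id=0)   # ref_cnt == -1: allocated, save still in flight
assert not chunk.is_ready_to_load

chunk.touch()                  # save completed -> ref_cnt == 0
assert chunk.is_ready_to_load and chunk.is_ready_to_evict

chunk.touch()                  # a load starts -> ref_cnt == 1
assert chunk.is_in_use and not chunk.is_ready_to_evict

chunk.untouch()                # load completed -> ref_cnt == 0, evictable again
assert chunk.is_ready_to_evict

chunk.reset()                  # returned to the pool -> ref_cnt == -1, hash cleared
```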
+ def lookup(self, chunk_hashes: list[ChunkHash]) -> int: + """_summary_ + return the number of cache hit starting from the first chunk + """ + hit_count = 0 + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache.get(chunk_hash) + if chunk is None or not chunk.is_ready_to_load: + break + hit_count += 1 + return hit_count + + def touch(self, chunk_hashes: list[ChunkHash]) -> int: + """ access chunks for both save / load; and move them to the end.""" + for chunk_hash in reversed(chunk_hashes): + if self.cpu_cache.get(chunk_hash): + self.cpu_cache.move_to_end(chunk_hash) + + def allocate_for_save( + self, chunk_hashes: list[ChunkHash] + ) -> Tuple[list[CPUChunk], list[int]] | None: + # filter out chunks that are already stored + num_chunks = len(chunk_hashes) + new_chunk_idxs = [ + i for i in range(num_chunks) + if chunk_hashes[i] not in self.cpu_cache + ] + + num_new_chunks = len(new_chunk_idxs) + if num_new_chunks == 0: + logger.info("No new chunks to allocate") + return None + num_chunks_to_evict = max( + 0, num_new_chunks - self.chunk_pool.num_free_chunks) + + # build list of chunks to evict / reuse + to_evict = [] + if num_chunks_to_evict > 0: + for chunk_hash, chunk in self.cpu_cache.items(): + if chunk.is_ready_to_evict: + to_evict.append(chunk_hash) + num_chunks_to_evict -= 1 + if num_chunks_to_evict == 0: + break + else: + # we could not evict enough chunks + return None + + # evict chunks + self.chunk_pool.release_chunks([ + self.cpu_cache.pop(evicting_chunk_hash) + for evicting_chunk_hash in to_evict + ]) + + new_chunk_hashes = [chunk_hashes[i] for i in new_chunk_idxs] + # allocate + try: + new_chunks = self.chunk_pool.allocate_chunks(new_chunk_hashes) + assert len(new_chunks) == len(new_chunk_hashes) + except Exception as e: + logger.warning(f" Failed to allocate {len(new_chunk_hashes)}: {e}") + # NOTE(jcgu): should we return None or something else? + return None + for chunk_hash, chunk in zip(new_chunk_hashes, new_chunks): + self.cpu_cache[chunk_hash] = chunk + # newly-allocated chunks, chunk-idx in the given chunk_hashes list + return new_chunks, new_chunk_idxs + + def prepare_load(self, chunk_hashes: list[ChunkHash]) -> list[CPUChunk]: + chunks = [] + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache[chunk_hash] + assert chunk.is_ready_to_load + chunk.touch() + chunks.append(chunk) + return chunks + + def complete_save(self, chunk_hashes: list[ChunkHash]) -> None: + """ After store completion, mark the chunk to be ready to load.""" + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache[chunk_hash] + assert not chunk.is_ready_to_load + # mark ready to load + chunk.touch() + assert chunk.is_ready_to_load + + def complete_load(self, chunk_hashes: list[ChunkHash]) -> None: + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache[chunk_hash] + assert chunk.is_in_use + chunk.untouch() + + def mark_completion(self, chunk_ids, operation: Literal['save', + 'load']) -> None: + chunk_hashes = [ + self.chunk_pool.allocated_id_to_hash_map[chunk_id] + for chunk_id in chunk_ids + ] + chunk_hashes = [] + unknown_chunk_ids = [] + for chunk_id in chunk_ids: + if chunk_id in self.chunk_pool.allocated_id_to_hash_map: + chunk_hashes.append( + self.chunk_pool.allocated_id_to_hash_map[chunk_id]) + else: + unknown_chunk_ids.append(chunk_id) + logger.warning( + f" Chunks[{unknown_chunk_ids}] are not found as allocated chunks in the pool." 
+ ) + + if operation == 'save': + self.complete_save(chunk_hashes) + elif operation == 'load': + self.complete_load(chunk_hashes) + else: + raise ValueError(f"Unknown operation: {operation}") diff --git a/tpu_inference/distributed/local_cpu_backend.py b/tpu_inference/distributed/local_cpu_backend.py index 199382777..49a74623a 100644 --- a/tpu_inference/distributed/local_cpu_backend.py +++ b/tpu_inference/distributed/local_cpu_backend.py @@ -3,18 +3,16 @@ import os import sys -import threading from collections import OrderedDict -from typing import Any, List, Optional, Tuple +from typing import Any, Optional from tpu_inference.logger import init_logger -from .cache_util import CacheKey - logger = init_logger(__name__) GB = 1024**3 DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB +CpuChunkId = int # TODO(jcgu): creating independent cpu backends since scheduler & worker could be in different processes. @@ -30,31 +28,17 @@ class LocalCPUBackend: It implements an LRU (Least Recently Used) eviction policy with a maximum size limit and support for pinning cache entries to prevent eviction. """ - _instance: Optional["LocalCPUBackend"] = None - _initialized: bool = False - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super(LocalCPUBackend, cls).__new__(cls) - return cls._instance def __init__(self, max_cpu_cache_size_bytes: int = DEFAULT_CPU_CACHE_SIZE_BYTES): - if self._initialized: - return - - self.lock = threading.Lock() env_cache_size_gb = os.getenv("TPU_OFFLOAD_CPU_CACHE_SIZE_GB") self.max_cpu_cache_size_bytes = (int(env_cache_size_gb) * GB if env_cache_size_gb is not None else max_cpu_cache_size_bytes) # The cache is an OrderedDict for LRU behavior. - self.cache: OrderedDict[CacheKey, Any] = OrderedDict() + self.cache: OrderedDict[CpuChunkId, Any] = OrderedDict() self.current_size_bytes = 0 - # Use a dictionary for reference counting of pinned keys. - self.pin_counts: dict[CacheKey, int] = {} - self._initialized = True logger.info("Singleton LocalCPUBackend initialized." f"CPU cache size: {self.max_cpu_cache_size_bytes} bytes") @@ -71,108 +55,48 @@ def _get_value_size(self, value: Any) -> int: size_in_bytes = sys.getsizeof(value) return size_in_bytes - def add(self, key: CacheKey, value: Any): + def add(self, key: CpuChunkId, value: Any) -> bool: """ Adds a key-value pair to the cache. If the cache is full, it evicts the least recently used, unpinned entries until there is enough space. """ - with self.lock: - value_size = self._get_value_size(value) - # Do not add if the item itself is larger than the cache capacity. - if value_size > self.max_cpu_cache_size_bytes: - logger.warning( - f"Cannot add item of size {value_size} bytes to " - f"cache with capacity {self.max_cpu_cache_size_bytes} bytes." - ) - return - - # If key already exists, remove it to update its size and position. - if key in self.cache: - old_value = self.cache.pop(key) - self.current_size_bytes -= self._get_value_size(old_value) - - # Evict old, unpinned entries until there is enough space for the new item. - while self.current_size_bytes + value_size > self.max_cpu_cache_size_bytes: - evicted_key = None - # Find the first unpinned key from the LRU end of the cache. - for k in self.cache: - if k not in self.pin_counts: - evicted_key = k - break - - # If no unpinned key can be evicted, we cannot make space. - if evicted_key is None: - logger.warning( - "Cache is full of pinned items. 
Cannot add new key " - f"({key.chunk_hash}) until some are unpinned.") - # If we popped the key before, we need to decide what to do. - # For simplicity, we just won't add the new value. - return - - # Evict the found key. - evicted_value = self.cache.pop(evicted_key) - self.current_size_bytes -= self._get_value_size(evicted_value) - logger.info( - f"Evicted key {evicted_key.chunk_hash} to make space.") - - # Add the new item. - self.cache[key] = value - self.current_size_bytes += value_size - logger.info(f"Added key to CPU backend. Hash: {key.chunk_hash}") - logger.info(f"Cache size: {self.current_size_bytes} bytes / " - f"{self.max_cpu_cache_size_bytes} bytes") - - def get(self, key: CacheKey) -> Optional[Any]: + # Add the new item. + if key in self.cache: + old_value = self.cache.pop(key) + self.current_size_bytes -= self._get_value_size(old_value) + del old_value + + self.cache[key] = value + value_size = self._get_value_size(value) + self.current_size_bytes += value_size + logger.info(f"Added key: {key} (size:{value_size}) to CPU backend.") + logger.info(f"Cache size: {self.current_size_bytes} bytes / " + f"{self.max_cpu_cache_size_bytes} bytes") + return True + + def get(self, key: CpuChunkId) -> Optional[Any]: """ Gets the value for a given key and marks it as recently used. """ - with self.lock: - if key in self.cache: - # Mark as most recently used. - self.cache.move_to_end(key) - return self.cache[key] - return None - - def contains(self, key: CacheKey, pin_on_hit: bool = False) -> bool: - """ - Checks if a key exists in the cache. - - If the key is found, it's marked as recently used. If `pin_on_hit` - is True, the key is also pinned to prevent eviction. - """ - with self.lock: - if key in self.cache: - # Mark as most recently used, since this is an access. - self.cache.move_to_end(key) - if pin_on_hit: - self.pin_counts[key] = self.pin_counts.get(key, 0) + 1 - logger.info(f"Pinned key on hit. Hash: {key.chunk_hash}, " - f"New count: {self.pin_counts[key]}") - return True - return False - - def maybe_unpin_keys(self, keys: List[CacheKey]) -> Tuple[int, int]: - """ - Unpins a list of keys. - - Decrements the pin count for each key. If a key's count reaches zero, - it is fully unpinned and becomes eligible for eviction. - """ - with self.lock: - unpinned_count = 0 - found_count = 0 - for key in keys: - if key in self.pin_counts: - found_count += 1 - self.pin_counts[key] -= 1 - logger.info( - f"Decremented pin count for key. Hash: {key.chunk_hash}, " - f"New count: {self.pin_counts[key]}") - if self.pin_counts[key] == 0: - del self.pin_counts[key] - unpinned_count += 1 - logger.info( - f"Unpinned key completely. 
Hash: {key.chunk_hash}") - return unpinned_count, found_count + if key in self.cache: + return self.cache[key] + return None + + def reclaim_unoccupied_chunks(self, occupied_chunk_ids: list[CpuChunkId]): + chunk_ids = list(self.cache.keys()) + unoccupied_chunk_ids = [ + chunk_id for chunk_id in chunk_ids + if chunk_id not in occupied_chunk_ids + ] + reclaimed_size_bytes = 0 + for chunk_id in unoccupied_chunk_ids: + dummy_value = self.cache.pop(chunk_id) + reclaimed_size_bytes += self._get_value_size(dummy_value) + del dummy_value + self.current_size_bytes -= reclaimed_size_bytes + + logger.info( + f" Reclaimed {len(unoccupied_chunk_ids)} unoccupied chunks, " + f"with {reclaimed_size_bytes} bytes.") diff --git a/tpu_inference/distributed/tpu_connector_local.py b/tpu_inference/distributed/tpu_connector_local.py index 4c4aa63d3..20e662daa 100644 --- a/tpu_inference/distributed/tpu_connector_local.py +++ b/tpu_inference/distributed/tpu_connector_local.py @@ -88,6 +88,7 @@ import copy import os import time +from collections import defaultdict from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, Optional, get_args @@ -100,6 +101,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ KVConnectorStats +from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import KVConnectorOutput @@ -116,6 +118,7 @@ TokenProcessor, cdiv, get_default_kv_connector_staging_buffer_tokens, get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) +from .cpu_chunk_manager import LRUOffloadingManager from .local_cpu_backend import LocalCPUBackend EngineId = str @@ -131,6 +134,8 @@ BLOCK_SIZE_BUCKETS = [1, 2, 4, 8, 16] +CPU_CHUNK_ID = int + # we keep our operations at vllm's block granularity, # and provide the following three preferences when handling # the last partial block during save: @@ -147,6 +152,7 @@ class SaveSpec: # total processed tokens for matching / saving num_total_tokens: int src_blocks: list[int] + dst_chunks: list[int] # final save for the (newly) finished request is_final_save: bool = False # A direct signal to the worker to skip the data transfer but still @@ -158,6 +164,7 @@ class SaveSpec: class LoadSpec: """Internal scheduler state for a potential load operation.""" num_matched_tokens: int + src_chunks: list[int] dst_blocks: list[int] can_load: bool = False num_skip_leading_tokens: int = 0 @@ -188,6 +195,7 @@ def __repr__(self) -> str: f", num_matched_tokens={self.load_spec.num_matched_tokens}, " f"can_load={self.load_spec.can_load}, " f"num_skip_leading_tokens={self.load_spec.num_skip_leading_tokens}, " + f"src_chunks={self.load_spec.src_chunks}, " f"dst_blocks={self.load_spec.dst_blocks}") save_info = f"save_spec_exists={self.save_spec is not None}" if self.save_spec: @@ -196,6 +204,7 @@ def __repr__(self) -> str: f"num_total_tokens={self.save_spec.num_total_tokens}, " f"is_final_save={self.save_spec.is_final_save}, " f"skip_save={self.save_spec.skip_save}, " + f"dst_chunks={self.save_spec.dst_chunks}, " f"src_blocks={self.save_spec.src_blocks}") return (f"TPUReqMeta(req_id={self.req_id}, " @@ -270,20 +279,22 @@ def __post_init__(self): def reset(self): # Must be serializable - self.data: dict[str, dict[str, int]] = { + self.data: dict[str, dict[str, list[int]]] = { "finished_save_blocks": dict(), "finished_load_blocks": dict(), } - def record_save(self, req: ReqId, 
num_finished_blocks: int): + def record_save(self, req: ReqId, saved_chunk_ids: list[int]): if req not in self.data["finished_save_blocks"]: - self.data["finished_save_blocks"][req] = 0 - self.data["finished_save_blocks"][req] += num_finished_blocks + self.data["finished_save_blocks"][req] = [] + self.data["finished_save_blocks"][req].extend( + copy.deepcopy(saved_chunk_ids)) - def record_load(self, req: ReqId, num_finished_blocks: int): + def record_load(self, req: ReqId, loaded_chunk_ids: list[int]): if req not in self.data["finished_load_blocks"]: - self.data["finished_load_blocks"][req] = 0 - self.data["finished_load_blocks"][req] += num_finished_blocks + self.data["finished_load_blocks"][req] = [] + self.data["finished_load_blocks"][req].extend( + copy.deepcopy(loaded_chunk_ids)) def clone_and_reset(self) -> "KVOffloadConnectorStats": old = copy.copy(self) @@ -475,6 +486,12 @@ def get_num_free_staging_blocks(self) -> int: def get_num_used_staging_blocks(self) -> int: return self._num_blocks_for_load + self._num_blocks_for_save + def get_num_used_save_staging_blocks(self, req_id: ReqId) -> int: + return self._blocks_for_save.get(req_id, 0) + + def get_num_used_load_staging_blocks(self, req_id: ReqId) -> int: + return self._blocks_for_load.get(req_id, 0) + def allocate(self, req_id: ReqId, num_blocks: int, usage: Literal["load", "save"]) -> int: if num_blocks < 0: @@ -598,6 +615,9 @@ def __init__(self, vllm_config: "VllmConfig"): self.config = vllm_config.kv_transfer_config self.block_size = vllm_config.cache_config.block_size + # offloading manager + self.offload_manager = LRUOffloadingManager(num_cpu_chunks=1024) + self._request_trackers: dict[ReqId, RequestTracker] = {} # This dictionary holds the full vLLM Request object for all requests # that are currently in a running state (i.e., have been scheduled but @@ -609,10 +629,15 @@ def __init__(self, vllm_config: "VllmConfig"): # {reqid: total_num_matched_tokens_in_cpu_backend} self._external_cache_hits: dict[ReqId, int] = {} - self.cpu_backend = LocalCPUBackend() + + # request ID -> set(block hashes being saved/loaded) + self._reqs_being_saved = defaultdict[ReqId, set[CPU_CHUNK_ID]](set) + self._reqs_being_loaded = defaultdict[ReqId, set[CPU_CHUNK_ID]](set) + model_name = self.vllm_config.model_config.model self.token_processor = TokenProcessor(model_name=model_name, chunk_size=self.block_size) + self.decode_save = os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0") == "1" # NOTE(jcgu): currently, let's nail on chunk_size == block_size # chunk_size == n * block_size lead to @@ -656,6 +681,17 @@ def __init__(self, vllm_config: "VllmConfig"): f"partial_block_dynamic_pad_lower_limit={self.partial_block_dynamic_pad_lower_limit}, " f"num_staging_blocks={self.num_staging_blocks}.") + def _get_request_block_hashes(self, req: "Request") -> list[BlockHash]: + # request's original block_hashes do not include the last partial block + + # TODO(jcgu): switch back to self-hash function + # prompt_token_ids = req.prompt_token_ids + # request_keys = self.token_processor.process_tokens(prompt_token_ids) + # hashes = [hash for _, _, hash in request_keys] + # return hashes + + return req.block_hashes + def get_num_new_matched_tokens( self, request: "Request", @@ -665,61 +701,57 @@ def get_num_new_matched_tokens( Checks for external KV cache hit against the local CPU backend. 
""" assert num_computed_tokens % self.block_size == 0, f"{num_computed_tokens} % {self.block_size} != 0" + # get block_hash + block_hashes = self._get_request_block_hashes(request) + num_total_blocks = len(block_hashes) prompt_token_ids = request.prompt_token_ids logger.info(f"Request {request.request_id}: Checking for cache hit. " f"Prompt length: {len(prompt_token_ids)}, " + f"Block_hashes ({num_total_blocks})," f"Already computed tokens: {num_computed_tokens}. ") - # Generate keys for the incoming request's tokens - request_keys = self.token_processor.process_tokens(prompt_token_ids) - num_matched_blocks = 0 - num_matched_tokens = 0 - # The generator needs to be consumed to count. - keys = list(request_keys) - for start_token_idx, end_token_idx, key in keys: - logger.info( - f" Processing chunk {start_token_idx}-{end_token_idx} with hash {key.chunk_hash}" - ) - if self.cpu_backend.contains(key, pin_on_hit=True): - # NOTE: each key maps to a cpu_chunk which equals to a block - num_matched_tokens = end_token_idx - num_matched_blocks += 1 - logger.info( - f" -> HIT. Total matched tokens so far: {num_matched_tokens}" - ) - else: - # Stop at the first cache miss - logger.info(" -> MISS. Stopping search.") - break - + # look for blocks in the cache + num_hits = self.offload_manager.lookup(block_hashes) + matched_block_hashes = block_hashes[:num_hits] + self.offload_manager.touch(block_hashes) + num_matched_blocks = len(matched_block_hashes) + num_matched_tokens = min(num_matched_blocks * self.block_size, + len(prompt_token_ids)) + num_computed_blocks = num_computed_tokens // self.block_size + num_blocks_to_load = num_matched_blocks - num_computed_blocks logger.info( - f"Request {request.request_id}: Found {num_matched_tokens} (out of {len(prompt_token_ids)} prompt tokens) matched tokens in CPU backend." + f"Request {request.request_id}: Found {num_matched_tokens} (out of {len(prompt_token_ids)} prompt tokens) matched tokens ({num_matched_blocks} blocks) in CPU backend (computed_blocks: {num_computed_blocks}, blocks_to_load: {num_blocks_to_load})." ) - assert num_matched_blocks == cdiv(num_matched_tokens, self.block_size) - - if num_matched_tokens > num_computed_tokens: + if num_blocks_to_load > 0: # planning staging blocks for load # NOTE(jcgu): do not worry about the inconsistency of the staging buffer status; # there is only one connector scheduler who is operating on it. num_avail_staging_blocks = self.staging_buffer_manager.get_num_free_staging_blocks( ) - num_skip_blocks = num_computed_tokens // self.block_size - num_blocks_to_load = num_matched_blocks - num_skip_blocks if num_blocks_to_load > num_avail_staging_blocks: # reduce blocks_to_load (and matched tokens) when there are insufficient staging blocks. logger.info( f" Req({request.request_id}) found {num_matched_blocks} blocks ({num_matched_tokens} tokens), but only {num_avail_staging_blocks} staging blocks available." ) num_blocks_to_load = num_avail_staging_blocks - num_matched_tokens = num_blocks_to_load * self.block_size + num_computed_tokens + num_matched_tokens = (num_blocks_to_load + + num_computed_blocks) * self.block_size # still have something to load if num_blocks_to_load > 0: + # get the src chunk ids to load + block_hashes_to_load = block_hashes[num_computed_blocks:( + num_computed_blocks + num_blocks_to_load)] + chunks_to_load = self.offload_manager.prepare_load( + block_hashes_to_load) + src_chunk_ids = [chunk.chunk_id for chunk in chunks_to_load] + # NOTE(jcgu): fill real dst_blocks later when blocks get allocated. 
dummy_dst_blocks = [-1] * num_blocks_to_load self.load_specs[request.request_id] = LoadSpec( num_matched_tokens=num_matched_tokens, + src_chunks=src_chunk_ids, dst_blocks=dummy_dst_blocks, num_skip_leading_tokens=num_computed_tokens, ) @@ -729,6 +761,8 @@ def get_num_new_matched_tokens( usage="load") assert num_allocated_blocks == num_blocks_to_load >= 0, f" failed to allocate {num_allocated_blocks} (load) staging blocks for request {request.request_id}, expected {num_blocks_to_load}." + # record the matched tokens in the cache, it will be needed in + # init save_spec self._external_cache_hits[ request.request_id] = num_matched_tokens @@ -768,6 +802,7 @@ def get_num_new_matched_tokens( f"Request {request.request_id}: After accounting for {num_computed_tokens} computed tokens, reporting {num_to_load} tokens to load." ) + # external_computed_tokens, load_kv_async return num_to_load, False def _adjust_last_partial_block(self, @@ -795,37 +830,38 @@ def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): """ - This hook is not used for the save logic. Trackers are created - and managed within `build_connector_meta`. + This hook is not used for the save logic. + Update the dst_blocks in the load_spec """ logger.info( f"TPUConnectorScheduler: Entering update_state_after_alloc Request {request.request_id}: Scheduler allocated " f"{num_external_tokens} external tokens.") self._unfinished_requests[request.request_id] = request - if num_external_tokens > 0: - if request.request_id in self.load_specs: - load_spec = self.load_specs[request.request_id] - # Get the real loading block ids . - all_blocks = blocks.get_block_ids()[0] - assert load_spec.num_skip_leading_tokens % self.block_size == 0 - skip_leading_blocks = load_spec.num_skip_leading_tokens // self.block_size - - # NOTE(jcgu): I think we do not need the adjustment here; - # for load, we should load all the (matched_tokens - skipping_tokens), - # since it's been reported to vllm scheduler - - total_matched_blocks = len( - load_spec.dst_blocks) + skip_leading_blocks - assert total_matched_blocks == cdiv( - load_spec.num_matched_tokens, self.block_size - ), f"{total_matched_blocks} != {load_spec.num_matched_tokens}" - dst_blocks = all_blocks[ - skip_leading_blocks:total_matched_blocks] - load_spec.dst_blocks = dst_blocks - load_spec.can_load = True - logger.info( - f"Request {request.request_id} ({len(dst_blocks)} dst_blocks) is ready to load." - ) + if num_external_tokens == 0: + return + if request.request_id in self.load_specs: + block_hashes = self._get_request_block_hashes(request) + all_blocks = blocks.get_block_ids()[0] + logger.info( + f" Request: {request.request_id} has {len(all_blocks)} blocks / {len(block_hashes)} block hashes.)" + ) + load_spec = self.load_specs[request.request_id] + assert load_spec.num_skip_leading_tokens % self.block_size == 0 + skip_leading_blocks = load_spec.num_skip_leading_tokens // self.block_size + + total_matched_blocks = len( + load_spec.dst_blocks) + skip_leading_blocks + assert total_matched_blocks == cdiv( + load_spec.num_matched_tokens, self.block_size + ), f"{total_matched_blocks} != {load_spec.num_matched_tokens}" + dst_blocks = all_blocks[skip_leading_blocks:total_matched_blocks] + load_spec.dst_blocks = dst_blocks + load_spec.can_load = True + self._reqs_being_loaded[request.request_id] |= set( + load_spec.src_chunks) + logger.info( + f"Request {request.request_id} ({len(dst_blocks)} dst_blocks) is ready to load." 
+ ) def _prepare_req_meta( self, @@ -838,7 +874,15 @@ def _prepare_req_meta( needed and prepares the metadata. Also performs the transactional update of the tracker's save state. """ - num_total_tokens = len(tracker.token_ids) + req_id = tracker.req_id + _request = self._unfinished_requests[req_id] + block_hashes = self._get_request_block_hashes(_request) + self.offload_manager.touch(block_hashes) + + # only consider the tokens covered by block_hashes + num_total_blocks = len(block_hashes) + num_total_tokens = min(num_total_blocks * self.block_size, + len(tracker.token_ids)) num_full_blocks = num_total_tokens // self.block_size num_full_blocks_tokens = num_full_blocks * self.block_size # adjust last partial block @@ -920,9 +964,9 @@ def _prepare_req_meta( # will no such an issue. num_skip_leading_blocks = tracker.save_watermark // self.block_size num_skip_leading_tokens = num_skip_leading_blocks * self.block_size + num_blocks_to_save = adjusted_num_total_blocks - num_skip_leading_blocks # planning staging blocks for save - num_blocks_to_save = adjusted_num_total_blocks - num_skip_leading_blocks num_avail_staging_blocks = self.staging_buffer_manager.get_num_free_staging_blocks( ) if num_blocks_to_save > num_avail_staging_blocks: @@ -935,36 +979,54 @@ def _prepare_req_meta( adjusted_num_total_tokens = adjusted_num_total_blocks * self.block_size if num_blocks_to_save > 0: - src_block_ids = tracker.block_ids[ + block_hashes_to_save = block_hashes[ num_skip_leading_blocks:adjusted_num_total_blocks] - # This is a real save operation. - save_spec = SaveSpec( - num_skip_leading_tokens=num_skip_leading_tokens, - num_total_tokens=adjusted_num_total_tokens, - is_final_save=is_finished, - skip_save=False, - src_blocks=src_block_ids, - ) - num_allocated_blocks = self.staging_buffer_manager.allocate( - tracker.req_id, - num_blocks=num_blocks_to_save, - usage="save") - assert num_allocated_blocks == num_blocks_to_save >= 0, f" failed to allocate {num_allocated_blocks} (save) staging blocks for request {tracker.req_id}, expected {num_blocks_to_save}." - if adjusted_num_total_tokens > tracker.save_watermark: - logger.info( - f" -> Old watermark {tracker.save_watermark}, new save_watermark count: {adjusted_num_total_tokens}" + allocate_output = self.offload_manager.allocate_for_save( + block_hashes_to_save) + if allocate_output is not None: + # there are enough chunks to save + chunks_for_save, chunk_idxs = allocate_output + assert num_blocks_to_save == len(chunks_for_save) + src_block_ids = tracker.block_ids[ + num_skip_leading_blocks:adjusted_num_total_blocks] + + dst_chunks = [chunk.chunk_id for chunk in chunks_for_save] + src_blocks = [src_block_ids[idx] for idx in chunk_idxs] + + # This is a real save operation. + save_spec = SaveSpec( + num_skip_leading_tokens=num_skip_leading_tokens, + num_total_tokens=adjusted_num_total_tokens, + is_final_save=is_finished, + skip_save=False, + src_blocks=src_blocks, + dst_chunks=dst_chunks, ) - tracker.save_watermark = adjusted_num_total_tokens + self._reqs_being_saved[req_id] |= set(dst_chunks) + num_allocated_blocks = self.staging_buffer_manager.allocate( + tracker.req_id, + num_blocks=num_blocks_to_save, + usage="save") + assert num_allocated_blocks == num_blocks_to_save >= 0, f" failed to allocate {num_allocated_blocks} (save) staging blocks for request {tracker.req_id}, expected {num_blocks_to_save}." 
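As an editorial aside (not part of the patch), the watermark arithmetic used above to decide how many blocks a save covers:

```python
# Illustrative arithmetic only, assuming block_size = 16.
block_size = 16
save_watermark = 32               # tokens already saved for this request
adjusted_num_total_tokens = 80    # hash-covered tokens after the partial-block adjustment

num_skip_leading_blocks = save_watermark // block_size                     # 2
adjusted_num_total_blocks = adjusted_num_total_tokens // block_size        # 5
num_blocks_to_save = adjusted_num_total_blocks - num_skip_leading_blocks   # 3
# allocate_for_save() is then asked for 3 CPU chunks, and 3 staging blocks are reserved.
```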
+ + if adjusted_num_total_tokens > tracker.save_watermark: + logger.info( + f" -> Old watermark {tracker.save_watermark}, new save_watermark count: {adjusted_num_total_tokens}" + ) + tracker.save_watermark = adjusted_num_total_tokens if is_finished and save_spec is None: # For finished requests, there must be a no-op save to update the state in the worker side. # This is a "completion-only" signal because should_save is False. # NOTE(jcgu): num_total_tokens will be used to unpin tokens; # apply the number of saved tokens; + # TODO(jcgu): rm the no-op save, since save status has been updated + # through kv_connector_output.kv_connector_stats save_spec = SaveSpec( num_skip_leading_tokens=tracker.save_watermark, num_total_tokens=tracker.save_watermark, src_blocks=[], + dst_chunks=[], is_final_save=True, skip_save=True, ) @@ -996,10 +1058,7 @@ def build_connector_meta( ) for finished_req_id in scheduler_output.finished_req_ids: logger.info(f" - Processing finished req: {finished_req_id}") - # Pop tracker and other state first. - tracker = self._request_trackers.pop(finished_req_id, None) - self._unfinished_requests.pop(finished_req_id, None) - self.load_specs.pop(finished_req_id, None) + tracker = self._request_trackers[finished_req_id] if not tracker: logger.warning( @@ -1018,6 +1077,11 @@ def build_connector_meta( ) metadata.requests_meta.append(req_meta) + # Pop tracker and other state first. + self._request_trackers.pop(finished_req_id, None) + self._unfinished_requests.pop(finished_req_id, None) + self.load_specs.pop(finished_req_id, None) + # Phase 2: Process newly scheduled requests # This block handles requests being scheduled for the very first time. # It creates the initial RequestTracker and prepares the first work order. @@ -1026,7 +1090,11 @@ def build_connector_meta( ) for request in scheduler_output.scheduled_new_reqs: req_id = request.req_id - logger.info(f" - Processing new req: {req_id}") + + _request = self._unfinished_requests[req_id] + logger.info( + f" - Processing new req: {req_id}, {len(_request.block_hashes)} block_hashes." + ) num_new_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] @@ -1053,6 +1121,7 @@ def build_connector_meta( # Create and store the tracker, which will maintain the request's # state for its entire lifetime. assert req_id not in self._request_trackers, f"Request {req_id} already has a tracker." + # TODO(jcgu): reduce duplicated info in request tracker tracker = RequestTracker( req_id=req_id, prompt_len=len(request.prompt_token_ids), @@ -1087,9 +1156,12 @@ def build_connector_meta( logger.info( f"Phase 3: Processing {len(cached_reqs.req_ids)} cached requests.") for i, req_id in enumerate(cached_reqs.req_ids): - logger.info(f" - Processing cached req: {req_id}") tracker = self._request_trackers[req_id] full_request = self._unfinished_requests.get(req_id) + _block_hashes = full_request.block_hashes + logger.info( + f" - Processing cached req: {req_id}, {len(_block_hashes)} block_hashes." 
+ ) if full_request is None: logger.warning( @@ -1167,22 +1239,48 @@ def update_connector_output(self, connector_output: KVConnectorOutput): KVOffloadConnectorStats) assert "finished_save_blocks" in connector_output.kv_connector_stats.data assert "finished_load_blocks" in connector_output.kv_connector_stats.data - for req_id, nfb in connector_output.kv_connector_stats.data[ + for req_id, saved_chunk_ids in connector_output.kv_connector_stats.data[ "finished_save_blocks"].items(): - logger.info(f" finished_save_blocks for {req_id}: {nfb}") - self.staging_buffer_manager.free(req_id, - usage="save", - num_finished_blocks=nfb) - for req_id, nfb in connector_output.kv_connector_stats.data[ + num_saved_chunks = len(saved_chunk_ids) + logger.info( + f" finished_save_blocks for {req_id}: {saved_chunk_ids}") + # free staging blocks + self.staging_buffer_manager.free( + req_id, usage="save", num_finished_blocks=num_saved_chunks) + # update in-flight save + for saved_chunk_id in saved_chunk_ids: + assert saved_chunk_id in self._reqs_being_saved[req_id] + self._reqs_being_saved[req_id].remove(saved_chunk_id) + if len(self._reqs_being_saved[req_id]) == 0: + self._reqs_being_saved.pop(req_id, None) + # update the status of occupied cpu chunks + self.offload_manager.mark_completion(saved_chunk_ids, "save") + + for req_id, loaded_chunk_ids in connector_output.kv_connector_stats.data[ "finished_load_blocks"].items(): - logger.info(f" finished_load_blocks for {req_id}: {nfb}") - self.staging_buffer_manager.free(req_id, - usage="load", - num_finished_blocks=nfb) - - # clean up staging blocks for the finished requests + num_loaded_chunks = len(loaded_chunk_ids) + logger.info( + f" finished_load_blocks for {req_id}: {num_loaded_chunks}" + ) + self.staging_buffer_manager.free( + req_id, + usage="load", + num_finished_blocks=num_loaded_chunks) + # update in-flight save + for loaded_chunk_id in loaded_chunk_ids: + assert loaded_chunk_id in self._reqs_being_loaded[req_id] + self._reqs_being_loaded[req_id].remove(loaded_chunk_id) + if len(self._reqs_being_loaded[req_id]) == 0: + self._reqs_being_loaded.pop(req_id, None) + # update the status of occupied cpu chunks + self.offload_manager.mark_completion(loaded_chunk_ids, "load") + + # clean up the status of the finished requests # save for req_id in connector_output.finished_sending or []: + if req_id in self._reqs_being_saved: + assert len(self._reqs_being_saved[req_id]) == 0 + self._reqs_being_saved.pop(req_id) num_freed_blocks = self.staging_buffer_manager.free(req_id, usage="save") logger.info( @@ -1191,6 +1289,9 @@ def update_connector_output(self, connector_output: KVConnectorOutput): # load for req_id in connector_output.finished_recving or []: + if req_id in self._reqs_being_loaded: + assert len(self._reqs_being_loaded[req_id]) == 0 + self._reqs_being_loaded.pop(req_id) num_freed_blocks = self.staging_buffer_manager.free(req_id, usage="load") logger.info( @@ -1203,14 +1304,33 @@ def request_finished( block_ids: list[int], ) -> tuple[bool, Optional[dict[str, Any]]]: """ - Signals to the scheduler that the connector is handling the finished - request asynchronously. The actual logic to prepare the save operation - occurs in `build_connector_meta`. + Called when a request has finished, before its blocks are freed. + + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. 
+ return: + delay_free_blocks, kv_xfer_params """ logger.info("TPUConnectorScheduler: Entering request_finished") # Return True to indicate the request is being saved asynchronously # and its blocks should not be freed yet. - return True, None + + req_id = request.request_id + if req_id in self._reqs_being_saved and len( + self._reqs_being_saved[req_id]) > 0: + return True, None + if req_id in self._reqs_being_loaded and len( + self._reqs_being_loaded[req_id]) > 0: + return True, None + + logger.info(f"TPUConnectorScheduler: finished request: {req_id}") + self._reqs_being_saved.pop(req_id, None) + self._reqs_being_loaded.pop(req_id, None) + + return False, None class TPUConnectorWorker: @@ -1240,10 +1360,7 @@ def __init__(self, vllm_config: VllmConfig, connector: "TPUConnector"): self.swap_in_fn: KVCacheSwapFn = None self.swap_out_fn: KVCacheSwapFn = None - self.host = self.config.kv_ip - self.kv_transfer_port = self.config.kv_port - - # Get the singleton instance of the CPU backend. + # cpu cache self.cpu_backend = LocalCPUBackend() # The worker needs its own token processor to generate keys. model_name = self.vllm_config.model_config.model @@ -1258,7 +1375,6 @@ def __init__(self, vllm_config: VllmConfig, connector: "TPUConnector"): thread_name_prefix="tpu_saver") self.finished_save_reqs: set[ReqId] = set() self.finished_load_reqs: set[ReqId] = set() - self._tokens_to_unpin: dict[ReqId, list[int]] = {} # Tracks if wait_for_save has been called for the current step's metadata. self._processed_save_for_step = False @@ -1313,7 +1429,8 @@ def register_runner(self, runner: TPUModelRunner): "TPUConnectorWorker registered with no KV caches.") # Pre-compile the JIT functions for KV cache swapping. - self._precompile_kv_swap_operations() + if self.use_bucketed_swap_ops: + self._precompile_kv_swap_operations() def _decompose_into_buckets(self, num_blocks: int) -> list[int]: """ @@ -1567,6 +1684,8 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], return req_id blocks_to_save = save_spec.src_blocks + dst_chunks = save_spec.dst_chunks + num_total_tokens = save_spec.num_total_tokens num_skip_leading_tokens = save_spec.num_skip_leading_tokens num_blocks_to_save = len(blocks_to_save) @@ -1587,7 +1706,8 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], f"num_skip_leading_tokens={num_skip_leading_tokens}, " f"num_total_tokens={num_total_tokens}, " f"num_tokens_to_save={num_tokens_to_save}, " - f"blocks_to_save len={len(blocks_to_save)}") + f"blocks_to_save({len(blocks_to_save)}: {blocks_to_save}, " + f"dst_chunks({len(dst_chunks)}: {dst_chunks} ") if not blocks_to_save and tokens_to_save: logger.warning( @@ -1663,43 +1783,24 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], ) post_transfer_start_time = time.time() - # Generate keys for the entire token sequence to get absolute positions. This to ensure that the delta - # tokens that is about to be captured in the cache are correctly mapped. These keys will be recreated - # during get_finished() to unpin the correct keys. - all_keys_generator = self.token_processor.process_tokens( - process_token_ids) - all_keys = list(all_keys_generator) - # Filter for keys that correspond to the new data we are saving. 
- relevant_keys = [] - for abs_start_token_idx, abs_end_token_idx, key in all_keys: - if abs_start_token_idx >= num_skip_leading_tokens: - relevant_keys.append( - (abs_start_token_idx, abs_end_token_idx, key)) - - if relevant_keys: - assert len( - relevant_keys - ) == num_blocks_to_save, f"{len(relevant_keys)} != {num_blocks_to_save}" - for i in range(num_blocks_to_save): - abs_start_token_idx, abs_end_token_idx, key = relevant_keys[ - i] - cur_chunk_cross_layers = [ - chunks_on_cpu[j][i] for j in range(self.num_layers) - ] - self.cpu_backend.add(key, cur_chunk_cross_layers) - logger.info( - f"Request {req_id}: Saving to CPU chunk: " - f"abs_start_token_idx={abs_start_token_idx}, abs_end_token_idx={abs_end_token_idx}, " - f"chunk_hash={key.chunk_hash}, " - f" local_chunk_idx={i}") - logger.info( - f"Request {req_id}: Added {len(relevant_keys)} keys to CPU backend." - ) + for i in range(num_blocks_to_save): + chunk_id = dst_chunks[i] + cur_chunk_cross_layers = [ + chunks_on_cpu[j][i] for j in range(self.num_layers) + ] + self.cpu_backend.add(chunk_id, cur_chunk_cross_layers) + logger.info(f"Request {req_id}: Saving to CPU chunk: " + f"chunk_id={chunk_id}, " + f" local_chunk_idx={i}") + + logger.info( + f"Request {req_id}: Added {num_blocks_to_save} chunks to CPU backend." + ) post_transfer_duration = time.time() - post_transfer_start_time logger.info( - f"Request {req_id}: e2e host processing of {len(relevant_keys)} keys took {post_transfer_duration:.4f} seconds." + f"Request {req_id}: e2e host processing of {num_blocks_to_save} chunks took {post_transfer_duration:.4f} seconds." ) except Exception as e: logger.error(f"Error saving blocks for request {req_id}: {e}", @@ -1742,10 +1843,7 @@ def wait_for_save(self): logger.info( f"Request {meta.req_id}: Final save is a no-op. Marking as finished." ) - self.finished_save_reqs.add(meta.req_id) - self._tokens_to_unpin[ - meta.req_id] = meta.token_ids[:meta.save_spec. - num_total_tokens] + # self.finished_save_reqs.add(meta.req_id) continue # If there are tokens to save, submit the task to the thread pool. @@ -1775,16 +1873,13 @@ def wait_for_save(self): if len(meta.save_spec.src_blocks) > 0: self.offload_stats.record_save( req=finished_req_id, - num_finished_blocks=len(meta.save_spec.src_blocks)) + saved_chunk_ids=meta.save_spec.dst_chunks) if meta.save_spec and meta.save_spec.is_final_save: logger.info( f"Request {finished_req_id}: Final save completed. Marking as finished." ) self.finished_save_reqs.add(finished_req_id) - self._tokens_to_unpin[ - finished_req_id] = meta.token_ids[:meta.save_spec. - num_total_tokens] except Exception as e: logger.error(f"A save operation failed: {e}", exc_info=True) @@ -1825,6 +1920,7 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: request_load_start_time = time.time() logger.info("TPUConnectorWorker: Starting KV cache load process.") dst_blocks = meta.load_spec.dst_blocks + src_chunks = meta.load_spec.src_chunks num_blocks_to_load = len(dst_blocks) num_matched_tokens = meta.load_spec.num_matched_tokens num_skip_leading_tokens = meta.load_spec.num_skip_leading_tokens @@ -1862,48 +1958,20 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: f"Fetching delta of {num_tokens_to_load_delta} tokens from cache for " f"{num_blocks_to_load} blocks.") - # 1. Generate keys for the entire matched prefix to find the right - # chunks in the backend. - keys_generator = self.token_processor.process_tokens( - meta.token_ids[:num_matched_tokens]) - - # 2. 
Assemble the per-layer data for the delta tokens on the CPU. + # Assemble the per-layer data for the delta tokens on the CPU. # We create a list of lists, where the outer list represents layers # and the inner lists will hold the data chunks for that layer. assembled_kv_on_cpu = [[] for _ in range(self.num_layers)] - # Fetch and slice chunks from the backend. - for start_token_idx, end_token_idx, key in keys_generator: - # This chunk is entirely before the delta, so we can skip it. - if end_token_idx <= num_skip_leading_tokens: - continue - # This chunk is entirely after the delta. - if start_token_idx >= num_matched_tokens: - continue - - cached_value = self.cpu_backend.get(key) + # Fetch and chunks from the backend. + for i in range(num_blocks_to_load): + src_chunk_id = src_chunks[i] + cached_value = self.cpu_backend.get(src_chunk_id) if cached_value: - # Calculate the precise slice needed from this specific chunk. - # rel_start_token_idx is the index within this chunk where the delta tokens begin. - if start_token_idx < num_skip_leading_tokens: - assert False, f"start_token_idx {start_token_idx} should not be less than num_skip_leading_tokens {num_skip_leading_tokens}, when cpu_chunk_size == block_size" - rel_start_token_idx = num_skip_leading_tokens - start_token_idx - rel_end_token_idx = end_token_idx - num_skip_leading_tokens - for i in range(self.num_layers): - # NOTE(jcgu): if only one block to load (and it's a padded block), - # then rel_end_token_idx will not be inaccurate (< block_size). - # Slice the jax array fetched from the backend. - sliced_chunk = jax.lax.slice_in_dim( - cached_value[i], - rel_start_token_idx, - rel_end_token_idx, - axis=0) - assembled_kv_on_cpu[i].append(sliced_chunk) - else: - for i in range(self.num_layers): - assembled_kv_on_cpu[i].append(cached_value[i]) + for j in range(self.num_layers): + assembled_kv_on_cpu[j].append(cached_value[j]) else: logger.error( - f"Cache key {key.chunk_hash} not found in CPU backend for request {meta.req_id}. Inconsistent state detected." + f"Chunk[{src_chunk_id}] not found in CPU backend for request {meta.req_id}. Inconsistent state detected." ) return @@ -1938,8 +2006,8 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: load_times.append(time.time() - request_load_start_time) self.finished_load_reqs.add(meta.req_id) if num_blocks_to_load > 0: - self.offload_stats.record_load( - req=meta.req_id, num_finished_blocks=num_blocks_to_load) + self.offload_stats.record_load(req=meta.req_id, + loaded_chunk_ids=src_chunks) if load_times: aggregate_load_time = sum(load_times) @@ -1974,37 +2042,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.wait_for_save() finished_saves = self.finished_save_reqs - logger.info(f"Finished saves to report: {finished_saves}") - - # Unpinning logic: - # A finished request consists of N prompt tokens and M generated tokens. - # The N prompt tokens were pinned during the initial lookup in - # `get_num_new_matched_tokens`. The M generated tokens were never - # pinned, as they were directly added to the cache. - # Here, we generate keys for the full N+M sequence. The call to - # `unpin_keys` will correctly unpin the N prompt keys and perform - # a harmless no-op for the M generated keys, which were never in - # the pinned set to begin with. 
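The pin/unpin bookkeeping removed below is superseded by the scheduler-side ref counts in `LRUOffloadingManager`. A minimal sketch (editorial illustration, not part of the patch; plain strings stand in for real `BlockHash` values) of that replacement flow, using the classes from `cpu_chunk_manager.py` above:

```python
# Illustrative only: the scheduler-side lifecycle that replaces worker-side unpinning.
manager = LRUOffloadingManager(num_cpu_chunks=8)

chunks, _ = manager.allocate_for_save(["h0", "h1"])     # ref_cnt == -1 while saving
saved_ids = [c.chunk_id for c in chunks]
manager.mark_completion(saved_ids, "save")              # ref_cnt -1 -> 0: ready to load / evict

loading = manager.prepare_load(["h0", "h1"])            # ref_cnt 0 -> 1: held during the load
manager.mark_completion([c.chunk_id for c in loading], "load")  # back to 0: evictable again
```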
- keys_to_unpin = [] - for req_id in finished_saves: - if req_id in self._tokens_to_unpin: - tokens = self._tokens_to_unpin.pop(req_id) - keys_generator = self.token_processor.process_tokens(tokens) - unpin_keys = [key for _, _, key in keys_generator] - keys_to_unpin.extend(unpin_keys) - logger.info( - f"Generated {len(unpin_keys)} keys to unpin for request {req_id}." - ) - - if keys_to_unpin: - unpinned_count, found_count = self.cpu_backend.maybe_unpin_keys( - keys_to_unpin) - logger.info( - f"Unpinned {unpinned_count} out of {found_count} existing keys (Request to unpin {len(keys_to_unpin)} keys)." - ) - self.finished_save_reqs = set() - finished_loads = self.finished_load_reqs self.finished_load_reqs = set() logger.info(f"Finished saves: {finished_saves}, " From 345ce3641c20d76d80ab11795a68acf2882a0439 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Tue, 18 Nov 2025 07:39:15 +0000 Subject: [PATCH 02/19] rename offload files Signed-off-by: Juncheng Gu --- examples/gke/benchmarks/README.md | 2 +- .../gke/benchmarks/deploy-cpu-offload.yaml | 2 +- examples/gke/pod_tpu_commons_cpu_offload.yaml | 2 +- ..._tpu_commons_cpu_offload_verification.yaml | 6 +- .../gke/pod_tpu_host_offload_unit_tests.yaml | 2 +- ...offline_inference_kv_cache_verification.py | 6 +- .../cpu_offloading_cache_util_test.py | 3 +- .../cpu_offloading_scheduler_test.py | 12 +- .../distributed/cpu_offloading_worker_test.py | 38 +-- .../host_offloading_accuracy_test.py | 4 +- .../host_offloading_precompile_test.py | 6 +- tests/distributed/local_cpu_backend_test.py | 4 +- tpu_inference/distributed/offload/__init__.py | 0 .../cpu_backend.py} | 2 +- .../offload_manager.py} | 153 ++++++++++- .../tpu_offload_connector.py} | 260 ++++-------------- .../{cache_util.py => offload/utils.py} | 4 + tpu_inference/platforms/tpu_jax.py | 7 +- tpu_inference/runner/kv_cache_manager.py | 3 +- tpu_inference/worker/tpu_worker_jax.py | 4 +- 20 files changed, 268 insertions(+), 252 deletions(-) create mode 100644 tpu_inference/distributed/offload/__init__.py rename tpu_inference/distributed/{local_cpu_backend.py => offload/cpu_backend.py} (98%) rename tpu_inference/distributed/{cpu_chunk_manager.py => offload/offload_manager.py} (53%) rename tpu_inference/distributed/{tpu_connector_local.py => offload/tpu_offload_connector.py} (90%) rename tpu_inference/distributed/{cache_util.py => offload/utils.py} (99%) diff --git a/examples/gke/benchmarks/README.md b/examples/gke/benchmarks/README.md index 29b16b5f0..9d1136637 100644 --- a/examples/gke/benchmarks/README.md +++ b/examples/gke/benchmarks/README.md @@ -33,7 +33,7 @@ kubectl apply -f deploy-baseline.yaml ### Option B: vLLM with TPU Host Offload -This deployment configures vLLM to use a `TPUConnector` for KV cache offload to the host CPU memory. This is specified by the `--kv-transfer-config` argument. +This deployment configures vLLM to use a `TPUOffloadConnector` for KV cache offload to the host CPU memory. This is specified by the `--kv-transfer-config` argument. 
```bash kubectl apply -f deploy-cpu-offload.yaml diff --git a/examples/gke/benchmarks/deploy-cpu-offload.yaml b/examples/gke/benchmarks/deploy-cpu-offload.yaml index f7b7647c6..a93ccd3fe 100644 --- a/examples/gke/benchmarks/deploy-cpu-offload.yaml +++ b/examples/gke/benchmarks/deploy-cpu-offload.yaml @@ -21,7 +21,7 @@ spec: imagePullPolicy: Always command: ["/bin/sh", "-c"] args: - - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config '{\"kv_connector\":\"TPUConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector_local\"}' --port 8000 --max_num_batched_tokens 2048 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" + - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config '{\"kv_connector\":\"TPUOffloadConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector_local\"}' --port 8000 --max_num_batched_tokens 2048 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/examples/gke/pod_tpu_commons_cpu_offload.yaml b/examples/gke/pod_tpu_commons_cpu_offload.yaml index fcc4fe147..2cef3ccfa 100644 --- a/examples/gke/pod_tpu_commons_cpu_offload.yaml +++ b/examples/gke/pod_tpu_commons_cpu_offload.yaml @@ -18,7 +18,7 @@ spec: - --tensor_parallel_size=8 - --max_model_len=1024 - --kv-transfer-config - - '{"kv_connector":"TPUConnector","kv_connector_module_path":"tpu_inference.distributed.tpu_connector_local","kv_role":"kv_both"}' + - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.tpu_connector_local","kv_role":"kv_both"}' env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml b/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml index e22baf9b6..b2eb566c6 100644 --- a/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml +++ b/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml @@ -2,10 +2,10 @@ apiVersion: v1 kind: Pod metadata: name: tpu-job-offline-inference - # This pod verifies the correctness of the TPUConnector implementation. + # This pod verifies the correctness of the TPUOffloadConnector implementation. # It runs a script that internally performs two text generations: # 1. A baseline run with a standard vLLM engine. - # 2. A test run with the TPUConnector enabled. + # 2. A test run with the TPUOffloadConnector enabled. # The pod succeeds only if the outputs from both runs are identical, # ensuring that the connector does not alter the model's output. 
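As an editorial aside (not part of the patch), the check this pod performs amounts to an exact-match comparison; `outputs_match` is a hypothetical helper sketching it, with `baseline_texts` from a plain vLLM run and `test_texts` from runs with the connector enabled (the second run exercising the prefix cache):

```python
# Illustrative sketch only; names are hypothetical, not from the verification script.
def outputs_match(baseline_texts: list[str], test_texts: list[str]) -> bool:
    # The pod succeeds only if every connector-enabled output equals its baseline.
    return len(baseline_texts) == len(test_texts) and all(
        t == b for t, b in zip(test_texts, baseline_texts))
```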
spec: @@ -25,7 +25,7 @@ spec: - --max_model_len=1024 - --seed=42 - --kv-transfer-config - - '{"kv_connector":"TPUConnector","kv_connector_module_path":"tpu_inference.distributed.tpu_connector_local","kv_role":"kv_both"}' + - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.tpu_connector_local","kv_role":"kv_both"}' env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/examples/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/gke/pod_tpu_host_offload_unit_tests.yaml index 4a33fce54..ba81c654c 100644 --- a/examples/gke/pod_tpu_host_offload_unit_tests.yaml +++ b/examples/gke/pod_tpu_host_offload_unit_tests.yaml @@ -2,7 +2,7 @@ apiVersion: v1 kind: Pod metadata: name: tpu-job-host-offload-unit-tests - # This pod runs the distributed unit tests for the TPUConnector + # This pod runs the distributed unit tests for the TPUOffloadConnector # and other related functionalities. It executes all tests found in the # tests/distributed/ directory using pytest. spec: diff --git a/examples/offline_inference_kv_cache_verification.py b/examples/offline_inference_kv_cache_verification.py index bcbc87c7c..b93dce149 100644 --- a/examples/offline_inference_kv_cache_verification.py +++ b/examples/offline_inference_kv_cache_verification.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ -This script performs an automated correctness verification for the TPUConnector. +This script performs an automated correctness verification for the TPUOffloadConnector. The verification works by performing a two-stage experiment for multiple prompts: 1. Baseline Run: For each prompt, it first runs a text generation using a @@ -8,7 +8,7 @@ output from this run is considered the "source of truth". 2. Test Run: It then runs the exact same text generation, but this time - with the TPUConnector enabled via the `--kv-transfer-config` argument. + with the TPUOffloadConnector enabled via the `--kv-transfer-config` argument. It runs the generation twice to verify prefix caching. 3. Comparison: The script compares the output from each test run against the @@ -131,7 +131,7 @@ def main(args: dict): time.sleep(10) # 2. 
Run the test with the local tpu kv connector enabled - print("\n--- Running Test (with TPUConnector) ---") + print("\n--- Running Test (with TPUOffloadConnector) ---") # With the connector, we run generation twice to test the prefix cache test_llm, test_params = setup_llm(args) test_outputs = run_invocations(test_llm, diff --git a/tests/distributed/cpu_offloading_cache_util_test.py b/tests/distributed/cpu_offloading_cache_util_test.py index 911b687c7..175dd0da9 100644 --- a/tests/distributed/cpu_offloading_cache_util_test.py +++ b/tests/distributed/cpu_offloading_cache_util_test.py @@ -4,7 +4,8 @@ import jax.numpy as jnp import numpy as np -from tpu_inference.distributed.cache_util import jitted_insert_kv_cache_slices +from tpu_inference.distributed.offload.utils import \ + jitted_insert_kv_cache_slices def original_jitted_insert_kv_cache_slices( diff --git a/tests/distributed/cpu_offloading_scheduler_test.py b/tests/distributed/cpu_offloading_scheduler_test.py index 03b7e14fc..fda3cb7f8 100644 --- a/tests/distributed/cpu_offloading_scheduler_test.py +++ b/tests/distributed/cpu_offloading_scheduler_test.py @@ -8,9 +8,9 @@ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.request import Request -from tpu_inference.distributed.local_cpu_backend import LocalCPUBackend -from tpu_inference.distributed.tpu_connector_local import ( - ReqId, RequestTracker, TPUConnectorScheduler, _StagingBufferManager) +from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend +from tpu_inference.distributed.offload.tpu_offload_connector import ( + ReqId, RequestTracker, Scheduler, _StagingBufferManager) from tpu_inference.logger import init_logger from .cpu_offloading_worker_test import MockVllmConfig @@ -52,7 +52,7 @@ def clean_backend_instance(): @pytest.fixture def scheduler_factory(): - """Provides a factory function for TPUConnectorScheduler instances.""" + """Provides a factory function for Scheduler instances.""" def _scheduler( block_size: int = _DEFAULT_BLOCK_SIZE, @@ -71,7 +71,7 @@ def _scheduler( if offload_staging_buffer_tokens >= 0: os.environ["TPU_OFFLOAD_STAGING_BUFFER_TOKENS"] = str( offload_staging_buffer_tokens) - return TPUConnectorScheduler(vllm_config) + return Scheduler(vllm_config) return _scheduler @@ -225,7 +225,7 @@ def test_complex_scenario(self): assert manager.get_num_used_staging_blocks() == 0 -class TestTPUConnectorScheduler: +class TestScheduler: def _add_prompt_to_scheduler_cpu_backend(self, scheduler, prompt_tokens): """ add """ diff --git a/tests/distributed/cpu_offloading_worker_test.py b/tests/distributed/cpu_offloading_worker_test.py index afa203f19..8ed20f12b 100644 --- a/tests/distributed/cpu_offloading_worker_test.py +++ b/tests/distributed/cpu_offloading_worker_test.py @@ -14,12 +14,13 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole -from tpu_inference.distributed.local_cpu_backend import LocalCPUBackend -from tpu_inference.distributed.tpu_connector_local import LoadSpec, SaveSpec -from tpu_inference.distributed.tpu_connector_local import \ - TPUConnector as CPUOffloadingConnector -from tpu_inference.distributed.tpu_connector_local import ( - TPUConnectorMetadata, TPUReqMeta) +from tpu_inference.distributed.cpu_backend import LocalCPUBackend +from tpu_inference.distributed.offload.tpu_connector_local import (LoadSpec, + SaveSpec) +from tpu_inference.distributed.offload.tpu_connector_local import \ + TPUOffloadConnector as 
CPUOffloadingConnector +from tpu_inference.distributed.offload.tpu_connector_local import ( + TPUOffloadConnectorMetadata, TPUReqMeta) from tpu_inference.logger import init_logger from tpu_inference.runner.tpu_jax_runner import TPUModelRunner @@ -62,7 +63,7 @@ class KVTransfer: class TestCpuOffloadingSave(jtu.JaxTestCase): - """Test the save functionality of the TPUConnectorWorker.""" + """Test the save functionality of the TPUOffloadConnectorWorker.""" def setUp(self): super().setUp() @@ -473,7 +474,8 @@ def test_tpu_connector_save( save_spec=save_spec, ) - connector_metadata = TPUConnectorMetadata(requests_meta=[req_meta]) + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) connector.bind_connector_metadata(connector_metadata) logger.info( "Connector metadata bound, calling worker.wait_for_save().") @@ -565,7 +567,7 @@ def test_tpu_connector_multi_step_save( num_blocks_step2: int, ): """ - Tests that the TPUConnectorWorker correctly saves the KV cache in multiple + Tests that the TPUOffloadConnectorWorker correctly saves the KV cache in multiple steps, respecting the save watermark (skip_leading_tokens). """ os.environ[ @@ -621,7 +623,7 @@ def test_tpu_connector_multi_step_save( logger.info( f"Step 1: req_meta_step1.token_ids={req_meta_step1.token_ids}, req_meta_step1.local_block_ids={req_meta_step1.local_block_ids}, req_meta_step1.save_spec.skip_leading_tokens={req_meta_step1.save_spec.num_skip_leading_tokens}" ) - connector_metadata_step1 = TPUConnectorMetadata( + connector_metadata_step1 = TPUOffloadConnectorMetadata( requests_meta=[req_meta_step1]) connector.bind_connector_metadata(connector_metadata_step1) worker.wait_for_save() @@ -681,7 +683,7 @@ def test_tpu_connector_multi_step_save( logger.info( f"Step 2: req_meta_step2.token_ids={req_meta_step2.token_ids}, req_meta_step2.local_block_ids={req_meta_step2.local_block_ids}, req_meta_step2.save_spec.skip_leading_tokens={req_meta_step2.save_spec.num_skip_leading_tokens}" ) - connector_metadata_step2 = TPUConnectorMetadata( + connector_metadata_step2 = TPUOffloadConnectorMetadata( requests_meta=[req_meta_step2]) # Manually reset worker state to simulate a new scheduler step @@ -781,17 +783,17 @@ def test_tpu_connector_load( num_computed_blocks: int = 0, ): """ - Tests that the TPUConnectorWorker correctly loads only the delta of + Tests that the TPUOffloadConnectorWorker correctly loads only the delta of the KV cache when a prefix is already computed by vLLM. This test simulates a scenario where vLLM has already computed a certain - number of tokens (prefix) and the TPUConnectorWorker needs to load + number of tokens (prefix) and the TPUOffloadConnectorWorker needs to load only the remaining "delta" of the KV cache from the CPU backend. Steps: 1. Setup: - Create a device mesh and sharding configurations. - - Instantiate a TPUConnector with a worker role. + - Instantiate a TPUOffloadConnector with a worker role. - Create mock source (ground truth) and destination KV caches on the TPU. - Register a mock TPUModelRunner with the worker. @@ -802,7 +804,7 @@ def test_tpu_connector_load( 3. Prepare and Execute Delta Load: - Calculate the number of tokens to load (the delta). - - Construct the necessary metadata (`TPUConnectorMetadata`) and `LoadSpec` + - Construct the necessary metadata (`TPUOffloadConnectorMetadata`) and `LoadSpec` to trigger a delta load operation, skipping the already computed tokens. 
- Bind this metadata to the connector and call the worker's `start_load_kv` method to perform the host-to-device (h2d) load for the delta tokens. @@ -873,7 +875,8 @@ def test_tpu_connector_load( local_block_ids=local_block_ids, save_spec=save_spec, ) - connector_metadata = TPUConnectorMetadata(requests_meta=[req_meta]) + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) connector.bind_connector_metadata(connector_metadata) worker.wait_for_save() logger.info( @@ -911,7 +914,8 @@ def test_tpu_connector_load( local_block_ids=local_block_ids, load_spec=load_spec, ) - connector_metadata = TPUConnectorMetadata(requests_meta=[req_meta]) + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) connector.bind_connector_metadata(connector_metadata) logger.info("Connector metadata bound, calling start_load_kv.") worker.start_load_kv(fwd_ctx=None) diff --git a/tests/distributed/host_offloading_accuracy_test.py b/tests/distributed/host_offloading_accuracy_test.py index cb7ace20f..1793dfe74 100644 --- a/tests/distributed/host_offloading_accuracy_test.py +++ b/tests/distributed/host_offloading_accuracy_test.py @@ -35,9 +35,9 @@ def sampling_config(): @pytest.fixture def kv_transfer_config(): - """use TPUConnector from tpu_connector_local""" + """use TPUOffloadConnector from tpu_connector_local""" return KVTransferConfig( - kv_connector="TPUConnector", + kv_connector="TPUOffloadConnector", kv_role="kv_both", kv_connector_module_path= "tpu_inference.distributed.tpu_connector_local", diff --git a/tests/distributed/host_offloading_precompile_test.py b/tests/distributed/host_offloading_precompile_test.py index 08552b4b5..aa3faaf75 100644 --- a/tests/distributed/host_offloading_precompile_test.py +++ b/tests/distributed/host_offloading_precompile_test.py @@ -13,9 +13,9 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole -from tpu_inference.distributed.local_cpu_backend import LocalCPUBackend -from tpu_inference.distributed.tpu_connector_local import \ - TPUConnector as CPUOffloadingConnector +from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend +from tpu_inference.distributed.offload.tpu_offload_connector import \ + TPUOffloadConnector as CPUOffloadingConnector from tpu_inference.logger import init_logger from tpu_inference.runner.tpu_jax_runner import TPUModelRunner diff --git a/tests/distributed/local_cpu_backend_test.py b/tests/distributed/local_cpu_backend_test.py index d656e949b..4d6323528 100644 --- a/tests/distributed/local_cpu_backend_test.py +++ b/tests/distributed/local_cpu_backend_test.py @@ -4,8 +4,8 @@ import pytest -from tpu_inference.distributed.cache_util import CacheKey -from tpu_inference.distributed.local_cpu_backend import LocalCPUBackend +from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend +from tpu_inference.distributed.offload.utils import CacheKey # Helper to create a mock value with a specific size diff --git a/tpu_inference/distributed/offload/__init__.py b/tpu_inference/distributed/offload/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tpu_inference/distributed/local_cpu_backend.py b/tpu_inference/distributed/offload/cpu_backend.py similarity index 98% rename from tpu_inference/distributed/local_cpu_backend.py rename to tpu_inference/distributed/offload/cpu_backend.py index 49a74623a..05479004f 100644 --- a/tpu_inference/distributed/local_cpu_backend.py +++ 
b/tpu_inference/distributed/offload/cpu_backend.py @@ -6,13 +6,13 @@ from collections import OrderedDict from typing import Any, Optional +from tpu_inference.distributed.offload.utils import CpuChunkId from tpu_inference.logger import init_logger logger = init_logger(__name__) GB = 1024**3 DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB -CpuChunkId = int # TODO(jcgu): creating independent cpu backends since scheduler & worker could be in different processes. diff --git a/tpu_inference/distributed/cpu_chunk_manager.py b/tpu_inference/distributed/offload/offload_manager.py similarity index 53% rename from tpu_inference/distributed/cpu_chunk_manager.py rename to tpu_inference/distributed/offload/offload_manager.py index ad36d20c8..cd7260f3d 100644 --- a/tpu_inference/distributed/cpu_chunk_manager.py +++ b/tpu_inference/distributed/offload/offload_manager.py @@ -3,10 +3,11 @@ from collections import OrderedDict from dataclasses import dataclass -from typing import Literal, Tuple +from typing import Literal, Optional, Tuple from vllm.v1.core.kv_cache_utils import BlockHash +from tpu_inference.distributed.offload.utils import CpuChunkId, ReqId from tpu_inference.logger import init_logger logger = init_logger(__name__) @@ -19,7 +20,7 @@ @dataclass class CPUChunk: - chunk_id: int + chunk_id: CpuChunkId ref_cnt: int = -1 _chunk_hash: ChunkHash | None = None @@ -59,7 +60,7 @@ def __init__(self, num_chunks: int): CPUChunk(idx) for idx in range(num_chunks - 1, -1, -1) ] # {allocated_chunk_id: chunk_hash} - self.allocated_id_to_hash_map: dict[int, ChunkHash] = {} + self.allocated_id_to_hash_map: dict[CpuChunkId, ChunkHash] = {} @property def num_free_chunks(self): @@ -96,7 +97,7 @@ def release_chunks(self, chunks: list[CPUChunk]): self._num_allocated_chunks -= len(chunks) -class LRUOffloadingManager: +class LRUCacheManager: def __init__(self, num_cpu_chunks: int): self.num_chunks = num_cpu_chunks @@ -221,3 +222,147 @@ def mark_completion(self, chunk_ids, operation: Literal['save', self.complete_load(chunk_hashes) else: raise ValueError(f"Unknown operation: {operation}") + + +class StagingBufferManager(): + """ Bookkeeping the staging buffer inside the connector scheduler. + NOTE(jcgu): the operations (e.g., allocate, free, get) to staging buffer / blocks are NOT thread-safe. + But it's okay since there is only one connector scheduler instance. + """ + + def __init__(self, num_blocks: int): + self.num_blocks = num_blocks + # {req_id: list(num_occupied_staging_blocks)} + self._blocks_for_save: dict[ReqId, int] = {} + self._blocks_for_load: dict[ReqId, int] = {} + + self._num_free_blocks: int = self.num_blocks + # keep track of the total occupied staging blocks for save and load respectively + self._num_blocks_for_save: int = 0 + self._num_blocks_for_load: int = 0 + + def get_num_free_staging_blocks(self) -> int: + return self._num_free_blocks + + def get_num_used_staging_blocks(self) -> int: + return self._num_blocks_for_load + self._num_blocks_for_save + + def get_num_used_save_staging_blocks(self, req_id: ReqId) -> int: + return self._blocks_for_save.get(req_id, 0) + + def get_num_used_load_staging_blocks(self, req_id: ReqId) -> int: + return self._blocks_for_load.get(req_id, 0) + + def allocate(self, req_id: ReqId, num_blocks: int, + usage: Literal["load", "save"]) -> int: + if num_blocks < 0: + logger.warning( + f" get {num_blocks} staging blocks to allocate for Req:{req_id}." 
+ ) + return num_blocks + if num_blocks > self._num_free_blocks: + # do not have enough capacity, return 0 + return 0 + + if usage == "load": + if req_id in self._blocks_for_load: + # NOTE(jcgu): before completing the previous load, new load + # should not be triggered for the same request (is this correct?) + raise ValueError( + f" Req({req_id}) already has {self._blocks_for_load[req_id]}, and should not have new loads." + ) + else: + self._blocks_for_load[req_id] = num_blocks + self._num_blocks_for_load += num_blocks + elif usage == "save": + if req_id in self._blocks_for_save: + self._blocks_for_save[req_id] += num_blocks + else: + self._blocks_for_save[req_id] = num_blocks + self._num_blocks_for_save += num_blocks + else: + raise ValueError( + f" Staging buffer manager should not get usage: {usage}") + self._num_free_blocks -= num_blocks + + logger.info( + f" allocate {num_blocks} staging blocks to Req:{req_id} for {usage}." + ) + return num_blocks + + def free(self, + req_id: ReqId, + usage: Literal["load", "save"], + num_finished_blocks: Optional[int] = None) -> int: + """ + when num_finished_blocks is not given, we will assume the request is finished and should be removed. + """ + num_freed_blocks = 0 + # NOTE(jcgu): assuming FIFO execution order for a single request's save and + # load operations respectively + if usage == "load": + if req_id not in self._blocks_for_load: + logger.warning( + f" there is no record of staging buffer (usage: {usage}) for Req:{req_id}" + ) + return 0 + if num_finished_blocks is None: + num_freed_blocks = self._blocks_for_load[req_id] + else: + num_freed_blocks = num_finished_blocks + if self._blocks_for_load[req_id] < num_freed_blocks: + logger.warning( + f" Req({req_id}) has {num_finished_blocks} load staging buffer to free, but only has {self._blocks_for_load[req_id]} on record." + ) + + self._blocks_for_load[req_id] -= num_freed_blocks + if self._blocks_for_load[req_id] <= 0: + del self._blocks_for_load[req_id] + self._num_blocks_for_load -= num_freed_blocks + elif usage == "save": + if req_id not in self._blocks_for_save: + logger.warning( + f" there is no record of staging buffer (usage: {usage}) for Req:{req_id}" + ) + return 0 + if num_finished_blocks is None: + num_freed_blocks = self._blocks_for_save[req_id] + else: + num_freed_blocks = num_finished_blocks + if self._blocks_for_save[req_id] < num_freed_blocks: + logger.warning( + f" Req({req_id}) has {num_finished_blocks} save staging buffer to free, but only has {self._blocks_for_save[req_id]} on record." 
+ ) + + self._blocks_for_save[req_id] -= num_freed_blocks + if self._blocks_for_save[req_id] <= 0: + del self._blocks_for_save[req_id] + self._num_blocks_for_save -= num_freed_blocks + else: + raise ValueError( + f" Staging buffer manager should not get usage: {usage}") + self._num_free_blocks += num_freed_blocks + + logger.info( + f" free {num_freed_blocks} staging blocks (usage: {usage}) from Req:{req_id}" + ) + return num_freed_blocks + + def get_usage(self, with_details: bool = False): + usage_str = (f"Staging Buffer: total={self.num_blocks}, " + f"free={self._num_free_blocks}, " + f"used_for_load={self._num_blocks_for_load}, " + f"used_for_save={self._num_blocks_for_save};") + if with_details: + blocks_for_save_str = " save_details:{" + for req, bn in self._blocks_for_save.items(): + blocks_for_save_str += f"{req}:{bn}," + blocks_for_save_str += "} " + + blocks_for_load_str = " load_details:{" + for req, bn in self._blocks_for_load.items(): + blocks_for_load_str += f"{req}:{bn}," + blocks_for_load_str += "}." + usage_str += blocks_for_save_str + blocks_for_load_str + + return usage_str diff --git a/tpu_inference/distributed/tpu_connector_local.py b/tpu_inference/distributed/offload/tpu_offload_connector.py similarity index 90% rename from tpu_inference/distributed/tpu_connector_local.py rename to tpu_inference/distributed/offload/tpu_offload_connector.py index 20e662daa..078b7f0fa 100644 --- a/tpu_inference/distributed/tpu_connector_local.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Scheduler side execution: -TPUConnectorScheduler manages the state of KV cache loading and saving for +TPUOffloadConnectorScheduler manages the state of KV cache loading and saving for each request. It acts as a state machine, tracking the progress of requests across multiple scheduling steps and generating work orders (TPUReqMeta) for -the TPUConnectorWorker. +the TPUOffloadConnectorWorker. Core Components: - RequestTracker: The primary state object for a request. It tracks the @@ -75,7 +75,7 @@ can be safely freed. Worker Side Execution: -- The TPUConnectorWorker receives the `TPUConnectorMetadata` containing the list of +- The TPUOffloadConnectorWorker receives the `TPUOffloadConnectorMetadata` containing the list of `TPUReqMeta` objects. - `start_load_kv`: Iterates through the metadata. 
If a `meta.load_spec` exists, it reads the corresponding data from the CPU backend and copies it @@ -110,20 +110,17 @@ from vllm.v1.request import Request from vllm.forward_context import ForwardContext +from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend +from tpu_inference.distributed.offload.offload_manager import ( + LRUCacheManager, StagingBufferManager) +from tpu_inference.distributed.offload.utils import ( + CPU_OFFLOADING_SWAP_OP_TYPE, CpuChunkId, KVCacheSwapFn, ReqId, + TokenProcessor, cdiv, get_default_kv_connector_staging_buffer_tokens, + get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) from tpu_inference.logger import init_logger from tpu_inference.runner.kv_cache_manager import KVCacheManager from tpu_inference.runner.tpu_jax_runner import TPUModelRunner -from .cache_util import (CPU_OFFLOADING_SWAP_OP_TYPE, KVCacheSwapFn, - TokenProcessor, cdiv, - get_default_kv_connector_staging_buffer_tokens, - get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) -from .cpu_chunk_manager import LRUOffloadingManager -from .local_cpu_backend import LocalCPUBackend - -EngineId = str -ReqId = str - logger = init_logger(__name__) # kv cache layout needed by cpu offloading mechanism @@ -134,8 +131,6 @@ BLOCK_SIZE_BUCKETS = [1, 2, 4, 8, 16] -CPU_CHUNK_ID = int - # we keep our operations at vllm's block granularity, # and provide the following three preferences when handling # the last partial block during save: @@ -331,25 +326,26 @@ def num_finished_blocks(self) -> int: # The metadata used for communicating between scheduler and worker connectors. @dataclass -class TPUConnectorMetadata(KVConnectorMetadata): +class TPUOffloadConnectorMetadata(KVConnectorMetadata): requests_meta: list[TPUReqMeta] = field(default_factory=list) -class TPUConnector(KVConnectorBase_V1): +class TPUOffloadConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - logger.info("TPUConnector: Entering __init__") + logger.info("TPUOffloadConnector: Entering __init__") assert vllm_config.kv_transfer_config is not None if role == KVConnectorRole.SCHEDULER: self.connector_scheduler = \ - TPUConnectorScheduler(vllm_config) + TPUOffloadConnectorScheduler(vllm_config) self.connector_worker = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None # The worker needs a reference to the base connector to access # the metadata object set by the engine. 
- self.connector_worker = TPUConnectorWorker(vllm_config, self) + self.connector_worker = TPUOffloadConnectorWorker( + vllm_config, self) ############################################################ # Class Methods @@ -368,7 +364,7 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig): return None logger.info_once( - "TPUConnector currently only supports %s KV cache layout.", + "TPUOffloadConnector currently only supports %s KV cache layout.", REQUIRED_KV_CACHE_LAYOUT) return REQUIRED_KV_CACHE_LAYOUT @@ -400,7 +396,7 @@ def update_state_after_alloc(self, request: "Request", def build_connector_meta( self, scheduler_output: SchedulerOutput, - ) -> TPUConnectorMetadata: + ) -> TPUOffloadConnectorMetadata: assert self.connector_scheduler is not None return self.connector_scheduler.build_connector_meta(scheduler_output) @@ -416,7 +412,7 @@ def request_finished( # Worker Side Methods ############################################################ def register_kv_caches(self, kv_caches: list[jax.Array]): - logger.info("TPUConnector: Entering register_kv_caches") + logger.info("TPUOffloadConnector: Entering register_kv_caches") """ We don't register kv_caches in connector, we call `register_runner` and use runner.kv_caches directly instead because the ref of runner.kv_caches @@ -425,7 +421,7 @@ def register_kv_caches(self, kv_caches: list[jax.Array]): pass def register_runner(self, runner: TPUModelRunner) -> None: - logger.info("TPUConnector: Entering register_runner") + logger.info("TPUOffloadConnector: Entering register_runner") assert self.connector_worker is not None self.connector_worker.register_runner(runner) @@ -435,17 +431,18 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: self.connector_worker.start_load_kv(fwd_ctx) def wait_for_layer_load(self, layer_name: str) -> None: - logger.info("TPUConnector: Entering wait_for_layer_load") + logger.info("TPUOffloadConnector: Entering wait_for_layer_load") """TPU connector doesn't support layer wise load.""" pass def save_kv_layer(self, **kwargs) -> None: - logger.info("TPUConnector: Entering save_kv_layer") + logger.info("TPUOffloadConnector: Entering save_kv_layer") """TPU connector doesn't support layer wise save.""" pass def wait_for_save(self): - assert isinstance(self._connector_metadata, TPUConnectorMetadata) + assert isinstance(self._connector_metadata, + TPUOffloadConnectorMetadata) self.connector_worker.wait_for_save() def get_finished(self, @@ -463,160 +460,16 @@ def get_kv_connector_stats(self) -> KVConnectorStats | None: return self.connector_worker.get_kv_connector_stats() -class _StagingBufferManager(): - """ Bookkeeping the staging buffer inside the connector scheduler. - NOTE(jcgu): the operations (e.g., allocate, free, get) to staging buffer / blocks are NOT thread-safe. - But it's okay since there is only one connector scheduler instance. 
- """ - - def __init__(self, num_blocks: int): - self.num_blocks = num_blocks - # {req_id: list(num_occupied_staging_blocks)} - self._blocks_for_save: dict[ReqId, int] = {} - self._blocks_for_load: dict[ReqId, int] = {} - - self._num_free_blocks: int = self.num_blocks - # keep track of the total occupied staging blocks for save and load respectively - self._num_blocks_for_save: int = 0 - self._num_blocks_for_load: int = 0 - - def get_num_free_staging_blocks(self) -> int: - return self._num_free_blocks - - def get_num_used_staging_blocks(self) -> int: - return self._num_blocks_for_load + self._num_blocks_for_save - - def get_num_used_save_staging_blocks(self, req_id: ReqId) -> int: - return self._blocks_for_save.get(req_id, 0) - - def get_num_used_load_staging_blocks(self, req_id: ReqId) -> int: - return self._blocks_for_load.get(req_id, 0) - - def allocate(self, req_id: ReqId, num_blocks: int, - usage: Literal["load", "save"]) -> int: - if num_blocks < 0: - logger.warning( - f" get {num_blocks} staging blocks to allocate for Req:{req_id}." - ) - return num_blocks - if num_blocks > self._num_free_blocks: - # do not have enough capacity, return 0 - return 0 - - if usage == "load": - if req_id in self._blocks_for_load: - # NOTE(jcgu): before completing the previous load, new load - # should not be triggered for the same request (is this correct?) - raise ValueError( - f" Req({req_id}) already has {self._blocks_for_load[req_id]}, and should not have new loads." - ) - else: - self._blocks_for_load[req_id] = num_blocks - self._num_blocks_for_load += num_blocks - elif usage == "save": - if req_id in self._blocks_for_save: - self._blocks_for_save[req_id] += num_blocks - else: - self._blocks_for_save[req_id] = num_blocks - self._num_blocks_for_save += num_blocks - else: - raise ValueError( - f" Staging buffer manager should not get usage: {usage}") - self._num_free_blocks -= num_blocks - - logger.info( - f" allocate {num_blocks} staging blocks to Req:{req_id} for {usage}." - ) - return num_blocks - - def free(self, - req_id: ReqId, - usage: Literal["load", "save"], - num_finished_blocks: Optional[int] = None) -> int: - """ - when num_finished_blocks is not given, we will assume the request is finished and should be removed. - """ - num_freed_blocks = 0 - # NOTE(jcgu): assuming FIFO execution order for a single request's save and - # load operations respectively - if usage == "load": - if req_id not in self._blocks_for_load: - logger.warning( - f" there is no record of staging buffer (usage: {usage}) for Req:{req_id}" - ) - return 0 - if num_finished_blocks is None: - num_freed_blocks = self._blocks_for_load[req_id] - else: - num_freed_blocks = num_finished_blocks - if self._blocks_for_load[req_id] < num_freed_blocks: - logger.warning( - f" Req({req_id}) has {num_finished_blocks} load staging buffer to free, but only has {self._blocks_for_load[req_id]} on record." 
- ) - - self._blocks_for_load[req_id] -= num_freed_blocks - if self._blocks_for_load[req_id] <= 0: - del self._blocks_for_load[req_id] - self._num_blocks_for_load -= num_freed_blocks - elif usage == "save": - if req_id not in self._blocks_for_save: - logger.warning( - f" there is no record of staging buffer (usage: {usage}) for Req:{req_id}" - ) - return 0 - if num_finished_blocks is None: - num_freed_blocks = self._blocks_for_save[req_id] - else: - num_freed_blocks = num_finished_blocks - if self._blocks_for_save[req_id] < num_freed_blocks: - logger.warning( - f" Req({req_id}) has {num_finished_blocks} save staging buffer to free, but only has {self._blocks_for_save[req_id]} on record." - ) - - self._blocks_for_save[req_id] -= num_freed_blocks - if self._blocks_for_save[req_id] <= 0: - del self._blocks_for_save[req_id] - self._num_blocks_for_save -= num_freed_blocks - else: - raise ValueError( - f" Staging buffer manager should not get usage: {usage}") - self._num_free_blocks += num_freed_blocks - - logger.info( - f" free {num_freed_blocks} staging blocks (usage: {usage}) from Req:{req_id}" - ) - return num_freed_blocks - - def get_usage(self, with_details: bool = False): - usage_str = (f"Staging Buffer: total={self.num_blocks}, " - f"free={self._num_free_blocks}, " - f"used_for_load={self._num_blocks_for_load}, " - f"used_for_save={self._num_blocks_for_save};") - if with_details: - blocks_for_save_str = " save_details:{" - for req, bn in self._blocks_for_save.items(): - blocks_for_save_str += f"{req}:{bn}," - blocks_for_save_str += "} " - - blocks_for_load_str = " load_details:{" - for req, bn in self._blocks_for_load.items(): - blocks_for_load_str += f"{req}:{bn}," - blocks_for_load_str += "}." - usage_str += blocks_for_save_str + blocks_for_load_str - - return usage_str - - -class TPUConnectorScheduler(): +class TPUOffloadConnectorScheduler(): def __init__(self, vllm_config: "VllmConfig"): - logger.info("TPUConnectorScheduler: Entering __init__") + logger.info("TPUOffloadConnectorScheduler: Entering __init__") self.vllm_config = vllm_config self.config = vllm_config.kv_transfer_config self.block_size = vllm_config.cache_config.block_size # offloading manager - self.offload_manager = LRUOffloadingManager(num_cpu_chunks=1024) + self.offload_manager = LRUCacheManager(num_cpu_chunks=1024) self._request_trackers: dict[ReqId, RequestTracker] = {} # This dictionary holds the full vLLM Request object for all requests @@ -631,8 +484,8 @@ def __init__(self, vllm_config: "VllmConfig"): self._external_cache_hits: dict[ReqId, int] = {} # request ID -> set(block hashes being saved/loaded) - self._reqs_being_saved = defaultdict[ReqId, set[CPU_CHUNK_ID]](set) - self._reqs_being_loaded = defaultdict[ReqId, set[CPU_CHUNK_ID]](set) + self._reqs_being_saved = defaultdict[ReqId, set[CpuChunkId]](set) + self._reqs_being_loaded = defaultdict[ReqId, set[CpuChunkId]](set) model_name = self.vllm_config.model_config.model self.token_processor = TokenProcessor(model_name=model_name, @@ -668,11 +521,11 @@ def __init__(self, vllm_config: "VllmConfig"): os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS", str(_default_staging_buffer_tokens))) self.num_staging_blocks = num_staging_buffer_tokens // self.block_size - self.staging_buffer_manager = _StagingBufferManager( + self.staging_buffer_manager = StagingBufferManager( num_blocks=self.num_staging_blocks) logger.info( - f"TPUConnectorScheduler initialized with: " + f"TPUOffloadConnectorScheduler initialized with: " f"block_size={self.block_size}, " 
f"cpu_chunk_size={self.cpu_chunk_size}, " f"model_name={model_name}, " @@ -834,7 +687,7 @@ def update_state_after_alloc(self, request: "Request", Update the dst_blocks in the load_spec """ logger.info( - f"TPUConnectorScheduler: Entering update_state_after_alloc Request {request.request_id}: Scheduler allocated " + f"TPUOffloadConnectorScheduler: Entering update_state_after_alloc Request {request.request_id}: Scheduler allocated " f"{num_external_tokens} external tokens.") self._unfinished_requests[request.request_id] = request if num_external_tokens == 0: @@ -934,7 +787,7 @@ def _prepare_req_meta( f"should_save={should_save}") # A SaveSpec is always prepared for a finished request to signal completion, - # even if we don't save the underlying KV data. This is to ensure the TPUConnectorWorker + # even if we don't save the underlying KV data. This is to ensure the TPUOffloadConnectorWorker # can correctly report finished request. save_spec = None if should_save: @@ -1045,8 +898,9 @@ def _prepare_req_meta( ) def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> TPUConnectorMetadata: - metadata = TPUConnectorMetadata() + self, + scheduler_output: SchedulerOutput) -> TPUOffloadConnectorMetadata: + metadata = TPUOffloadConnectorMetadata() # Phase 1: Handle and clean up finished requests # This block handles requests that have completed their generation. @@ -1230,7 +1084,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput): connectors output. """ logger.info( - f"TPUConnectorScheduler: getting workers' output: finished_sending: {connector_output.finished_sending}, finished_recving: {connector_output.finished_recving}" + f"TPUOffloadConnectorScheduler: getting workers' output: finished_sending: {connector_output.finished_sending}, finished_recving: {connector_output.finished_recving}" ) # per iteration, update the finished staging blocks @@ -1314,7 +1168,7 @@ def request_finished( return: delay_free_blocks, kv_xfer_params """ - logger.info("TPUConnectorScheduler: Entering request_finished") + logger.info("TPUOffloadConnectorScheduler: Entering request_finished") # Return True to indicate the request is being saved asynchronously # and its blocks should not be freed yet. 
@@ -1326,17 +1180,19 @@ def request_finished( self._reqs_being_loaded[req_id]) > 0: return True, None - logger.info(f"TPUConnectorScheduler: finished request: {req_id}") + logger.info( + f"TPUOffloadConnectorScheduler: finished request: {req_id}") self._reqs_being_saved.pop(req_id, None) self._reqs_being_loaded.pop(req_id, None) return False, None -class TPUConnectorWorker: +class TPUOffloadConnectorWorker: - def __init__(self, vllm_config: VllmConfig, connector: "TPUConnector"): - logger.info("TPUConnectorWorker: Entering __init__") + def __init__(self, vllm_config: VllmConfig, + connector: "TPUOffloadConnector"): + logger.info("TPUOffloadConnectorWorker: Entering __init__") self.vllm_config = vllm_config self.config = vllm_config.kv_transfer_config self.connector = connector @@ -1382,11 +1238,11 @@ def __init__(self, vllm_config: VllmConfig, connector: "TPUConnector"): self.offload_stats = KVOffloadConnectorStats() def __del__(self): - logger.info("TPUConnectorWorker: Entering __del__") + logger.info("TPUOffloadConnectorWorker: Entering __del__") self.save_executor.shutdown(wait=True) def register_runner(self, runner: TPUModelRunner): - logger.info("TPUConnectorWorker: Entering register_runner") + logger.info("TPUOffloadConnectorWorker: Entering register_runner") self.runner = runner self.devices = runner.devices self.mesh = runner.mesh @@ -1416,7 +1272,8 @@ def register_runner(self, runner: TPUModelRunner): host_sharding=self.flatten_host_sharding, device_sharding=self.flatten_device_sharding) - logger.info("KV Cache details registered in TPUConnectorWorker:") + logger.info( + "KV Cache details registered in TPUOffloadConnectorWorker:") logger.info(f" - Num layers: {self.num_layers}") logger.info(f" - Shape per layer: {self.shape}") logger.info(f" - DType: {self.dtype}") @@ -1426,7 +1283,7 @@ def register_runner(self, runner: TPUModelRunner): logger.info(f" - Layout: {self.kv_cache_layout}") else: raise ValueError( - "TPUConnectorWorker registered with no KV caches.") + "TPUOffloadConnectorWorker registered with no KV caches.") # Pre-compile the JIT functions for KV cache swapping. if self.use_bucketed_swap_ops: @@ -1818,11 +1675,12 @@ def wait_for_save(self): if self._processed_save_for_step: return - # logger.info("TPUConnectorWorker: Entering wait_for_save") + # logger.info("TPUOffloadConnectorWorker: Entering wait_for_save") metadata = self.connector._get_connector_metadata() - if not isinstance(metadata, TPUConnectorMetadata): + if not isinstance(metadata, TPUOffloadConnectorMetadata): logger.info( - "wait_for_save:not an instances of TPUConnectorMetadata") + "wait_for_save:not an instances of TPUOffloadConnectorMetadata" + ) self._processed_save_for_step = True return @@ -1899,8 +1757,9 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: # Reset the save processing flag at the start of a new step. 
self._processed_save_for_step = False metadata = self.connector._get_connector_metadata() - if not isinstance(metadata, - TPUConnectorMetadata) or not metadata.requests_meta: + if not isinstance( + metadata, + TPUOffloadConnectorMetadata) or not metadata.requests_meta: logger.info("No load operations scheduled for this step.") return @@ -1918,7 +1777,8 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: continue request_load_start_time = time.time() - logger.info("TPUConnectorWorker: Starting KV cache load process.") + logger.info( + "TPUOffloadConnectorWorker: Starting KV cache load process.") dst_blocks = meta.load_spec.dst_blocks src_chunks = meta.load_spec.src_chunks num_blocks_to_load = len(dst_blocks) @@ -2012,7 +1872,7 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: if load_times: aggregate_load_time = sum(load_times) logger.info( - f"TPUConnectorWorker: Aggregate KV cache load time for {len(load_times)} requests: {aggregate_load_time:.4f} seconds" + f"TPUOffloadConnectorWorker: Aggregate KV cache load time for {len(load_times)} requests: {aggregate_load_time:.4f} seconds" ) def get_kv_connector_stats(self) -> KVConnectorStats | None: @@ -2038,7 +1898,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: # request IDs are correctly identified and reported back to the engine # for resource cleanup. The `wait_for_save` method is idempotent, # so this call is a no-op in the normal execution path. - logger.info("TPUConnectorWorker: Entering get_finished") + logger.info("TPUOffloadConnectorWorker: Entering get_finished") self.wait_for_save() finished_saves = self.finished_save_reqs diff --git a/tpu_inference/distributed/cache_util.py b/tpu_inference/distributed/offload/utils.py similarity index 99% rename from tpu_inference/distributed/cache_util.py rename to tpu_inference/distributed/offload/utils.py index 46d0111ef..5f28a26f3 100644 --- a/tpu_inference/distributed/cache_util.py +++ b/tpu_inference/distributed/offload/utils.py @@ -14,6 +14,10 @@ from tpu_inference.kernels.dma.host_dma import d2h_dma, h2d_dma from tpu_inference.logger import init_logger +ReqId = str + +CpuChunkId = int + # Corresponds to the initial hash value NONE_HASH = 0 diff --git a/tpu_inference/platforms/tpu_jax.py b/tpu_inference/platforms/tpu_jax.py index b939d8eed..2d3095579 100644 --- a/tpu_inference/platforms/tpu_jax.py +++ b/tpu_inference/platforms/tpu_jax.py @@ -204,9 +204,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "Forcing --disable_chunked_mm_input.") scheduler_config.disable_chunked_mm_input = True - kv_transfer_config = vllm_config.kv_transfer_config - if kv_transfer_config is not None: - assert kv_transfer_config.kv_connector == "TPUConnector" + # NOTE(jcgu): not needed + # kv_transfer_config = vllm_config.kv_transfer_config + # if kv_transfer_config is not None: + # assert kv_transfer_config.kv_connector == "TPUConnector" update_vllm_config_for_qwix_quantization(vllm_config) diff --git a/tpu_inference/runner/kv_cache_manager.py b/tpu_inference/runner/kv_cache_manager.py index 1af7861f5..247ac44dc 100644 --- a/tpu_inference/runner/kv_cache_manager.py +++ b/tpu_inference/runner/kv_cache_manager.py @@ -16,7 +16,8 @@ from tpu_inference import utils from tpu_inference import utils as common_utils -from tpu_inference.distributed.cache_util import get_kv_connector_cache_layout +from tpu_inference.distributed.offload.utils import \ + get_kv_connector_cache_layout from tpu_inference.logger import init_logger from tpu_inference.runner import utils as 
runner_utils from tpu_inference.runner.input_batch_jax import CachedRequestState, InputBatch diff --git a/tpu_inference/worker/tpu_worker_jax.py b/tpu_inference/worker/tpu_worker_jax.py index 22d70212d..b3ae10fbd 100644 --- a/tpu_inference/worker/tpu_worker_jax.py +++ b/tpu_inference/worker/tpu_worker_jax.py @@ -26,7 +26,7 @@ AbstractLoRARequest, AbstractSchedulerOutput) from tpu_inference.di.interfaces import HostInterface -from tpu_inference.distributed.cache_util import \ +from tpu_inference.distributed.offload.utils import \ get_default_kv_connector_staging_buffer_tokens from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port, get_node_id) @@ -171,7 +171,7 @@ def determine_available_memory(self) -> int: if self.vllm_config.kv_transfer_config is not None: kv_transfer_config = self.vllm_config.kv_transfer_config - if kv_transfer_config.kv_connector == "TPUConnector" and kv_transfer_config.kv_connector_module_path == "tpu_inference.distributed.tpu_connector_local": + if kv_transfer_config.kv_connector == "TPUOffloadConnector" and kv_transfer_config.kv_connector_module_path == "tpu_inference.distributed.offload.tpu_offload_connector": # If kv offloading is enabled, we need to account for the memory used by the KV transfer buffer. _default_staging_buffer_tokens = get_default_kv_connector_staging_buffer_tokens( ) From 5773e8ca5dcfc9bb780e44d3f46121859db08702 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Tue, 18 Nov 2025 07:50:35 +0000 Subject: [PATCH 03/19] move offload tests Signed-off-by: Juncheng Gu --- examples/gke/pod_tpu_host_offload_unit_tests.yaml | 12 ++++++------ .../{ => offload}/cpu_offloading_cache_util_test.py | 0 .../{ => offload}/cpu_offloading_scheduler_test.py | 0 .../{ => offload}/cpu_offloading_worker_test.py | 0 .../{ => offload}/host_offloading_accuracy_test.py | 0 .../{ => offload}/host_offloading_precompile_test.py | 0 .../{ => offload}/local_cpu_backend_test.py | 0 7 files changed, 6 insertions(+), 6 deletions(-) rename tests/distributed/{ => offload}/cpu_offloading_cache_util_test.py (100%) rename tests/distributed/{ => offload}/cpu_offloading_scheduler_test.py (100%) rename tests/distributed/{ => offload}/cpu_offloading_worker_test.py (100%) rename tests/distributed/{ => offload}/host_offloading_accuracy_test.py (100%) rename tests/distributed/{ => offload}/host_offloading_precompile_test.py (100%) rename tests/distributed/{ => offload}/local_cpu_backend_test.py (100%) diff --git a/examples/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/gke/pod_tpu_host_offload_unit_tests.yaml index ba81c654c..5a9fb73db 100644 --- a/examples/gke/pod_tpu_host_offload_unit_tests.yaml +++ b/examples/gke/pod_tpu_host_offload_unit_tests.yaml @@ -17,12 +17,12 @@ spec: command: - /bin/bash - -c - - "pytest -sv tests/distributed/host_offloading_precompile_test.py" - # - "pytest -sv tests/distributed/cpu_offloading_worker_test.py" - # - "pytest -sv tests/distributed/cpu_offloading_cache_util_test.py" - # - "pytest -sv tests/distributed/host_offloading_accuracy_test.py" - # - "pytest -sv tests/distributed/local_cpu_backend_test.py" - # - "pytest -sv tests/distributed/host_offloading_precompile_test.py" + - "pytest -sv tests/distributed/offload/host_offloading_precompile_test.py" + # - "pytest -sv tests/distributed/offload/cpu_offloading_worker_test.py" + # - "pytest -sv tests/distributed/offload/cpu_offloading_cache_util_test.py" + # - "pytest -sv tests/distributed/offload/host_offloading_accuracy_test.py" + # - "pytest -sv 
tests/distributed/offload/local_cpu_backend_test.py" + # - "pytest -sv tests/distributed/offload/host_offloading_precompile_test.py" env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/tests/distributed/cpu_offloading_cache_util_test.py b/tests/distributed/offload/cpu_offloading_cache_util_test.py similarity index 100% rename from tests/distributed/cpu_offloading_cache_util_test.py rename to tests/distributed/offload/cpu_offloading_cache_util_test.py diff --git a/tests/distributed/cpu_offloading_scheduler_test.py b/tests/distributed/offload/cpu_offloading_scheduler_test.py similarity index 100% rename from tests/distributed/cpu_offloading_scheduler_test.py rename to tests/distributed/offload/cpu_offloading_scheduler_test.py diff --git a/tests/distributed/cpu_offloading_worker_test.py b/tests/distributed/offload/cpu_offloading_worker_test.py similarity index 100% rename from tests/distributed/cpu_offloading_worker_test.py rename to tests/distributed/offload/cpu_offloading_worker_test.py diff --git a/tests/distributed/host_offloading_accuracy_test.py b/tests/distributed/offload/host_offloading_accuracy_test.py similarity index 100% rename from tests/distributed/host_offloading_accuracy_test.py rename to tests/distributed/offload/host_offloading_accuracy_test.py diff --git a/tests/distributed/host_offloading_precompile_test.py b/tests/distributed/offload/host_offloading_precompile_test.py similarity index 100% rename from tests/distributed/host_offloading_precompile_test.py rename to tests/distributed/offload/host_offloading_precompile_test.py diff --git a/tests/distributed/local_cpu_backend_test.py b/tests/distributed/offload/local_cpu_backend_test.py similarity index 100% rename from tests/distributed/local_cpu_backend_test.py rename to tests/distributed/offload/local_cpu_backend_test.py From b1ccae8a111f9b9b3c1b7dd25b04872535c8ed13 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 19 Nov 2025 22:54:06 +0000 Subject: [PATCH 04/19] update cpu_backend_test Signed-off-by: Juncheng Gu --- tests/distributed/offload/cpu_backend_test.py | 83 +++++ .../offload/local_cpu_backend_test.py | 300 ------------------ .../distributed/offload/cpu_backend.py | 49 +-- .../offload/tpu_offload_connector.py | 14 +- 4 files changed, 121 insertions(+), 325 deletions(-) create mode 100644 tests/distributed/offload/cpu_backend_test.py delete mode 100644 tests/distributed/offload/local_cpu_backend_test.py diff --git a/tests/distributed/offload/cpu_backend_test.py b/tests/distributed/offload/cpu_backend_test.py new file mode 100644 index 000000000..d74e4d2e1 --- /dev/null +++ b/tests/distributed/offload/cpu_backend_test.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +import pytest + +from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend +from tpu_inference.distributed.offload.utils import CpuChunkId + + +# Helper to create a mock jax array with a specific size in bytes +def create_mock_jax_array(size_in_bytes: int) -> MagicMock: + """Creates a mock object with an 'nbytes' attribute.""" + mock_value = MagicMock() + mock_value.nbytes = size_in_bytes + return mock_value + + +class TestLocalCPUBackend: + """Test suite for the LocalCPUBackend.""" + + def test_add_and_get(self): + """Verifies that a value can be added and then retrieved successfully.""" + backend = LocalCPUBackend(num_cpu_chunks=10) + key = CpuChunkId(0) + value = create_mock_jax_array(50) + + backend.add(key, value) + retrieved_value = backend.get(key) + + assert 
retrieved_value == value + assert backend.current_size_bytes == 50 + + # Test with a list of JAX arrays (mocked) + key_list = CpuChunkId(1) + value_list = [create_mock_jax_array(20), create_mock_jax_array(30)] + backend.add(key_list, value_list) + retrieved_list_value = backend.get(key_list) + + assert retrieved_list_value == value_list + assert backend.current_size_bytes == 50 + 20 + 30 + + assert backend.num_occupied_cpu_chunks == 2 + + def test_add_invalid_chunk_id(self): + """Verifies that adding a value with an invalid chunk_id raises a ValueError.""" + backend = LocalCPUBackend(num_cpu_chunks=10) + value = create_mock_jax_array(50) + + with pytest.raises(ValueError): + backend.add(CpuChunkId(-1), value) + + assert backend.num_occupied_cpu_chunks == 0 + + def test_reclaim_unoccupied_chunks(self): + """Tests that unoccupied chunks are reclaimed correctly.""" + backend = LocalCPUBackend(num_cpu_chunks=10) + key1 = CpuChunkId(0) + key2 = CpuChunkId(1) + key3 = CpuChunkId(2) + value = create_mock_jax_array(10) + + backend.add(key1, value) + backend.add(key2, value) + backend.add(key3, value) + + assert backend.current_size_bytes == 30 + assert len(backend.cache) == 3 + + # Reclaim one chunk + backend.reclaim_unoccupied_chunks(occupied_chunk_ids=[key1, key3]) + + assert backend.current_size_bytes == 20 + assert len(backend.cache) == 2 + assert key1 in backend.cache + assert key2 not in backend.cache + assert key3 in backend.cache + + # Reclaim all chunks + backend.reclaim_unoccupied_chunks(occupied_chunk_ids=[]) + + assert backend.current_size_bytes == 0 + assert len(backend.cache) == 0 diff --git a/tests/distributed/offload/local_cpu_backend_test.py b/tests/distributed/offload/local_cpu_backend_test.py deleted file mode 100644 index 4d6323528..000000000 --- a/tests/distributed/offload/local_cpu_backend_test.py +++ /dev/null @@ -1,300 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from unittest.mock import MagicMock - -import pytest - -from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend -from tpu_inference.distributed.offload.utils import CacheKey - - -# Helper to create a mock value with a specific size -def create_mock_value(size_in_bytes: int) -> MagicMock: - """Creates a mock object with an 'nbytes' attribute.""" - mock_value = MagicMock() - mock_value.nbytes = size_in_bytes - return mock_value - - -@pytest.fixture -def clean_backend_instance(): - """ - Provides a clean instance of the LocalCPUBackend for each test. - This is crucial because LocalCPUBackend is a singleton, and without - resetting its internal state, tests would interfere with each other. - """ - # Reset the singleton instance before each test. - # By setting LocalCPUBackend._instance to None, it forces the __new__ method - # to create a fresh, new object for every single test case, ensuring test isolation. 
- LocalCPUBackend._instance = None - LocalCPUBackend._initialized = False - yield - # Clean up after the test - LocalCPUBackend._instance = None - LocalCPUBackend._initialized = False - - -class TestLocalCPUBackend: - """Test suite for the LocalCPUBackend.""" - - def test_add_and_get(self, clean_backend_instance): - """Verifies that a value can be added and then retrieved successfully.""" - # Increased size to accommodate the list test without eviction - backend = LocalCPUBackend(max_cpu_cache_size_bytes=150) - key = CacheKey(model_name="test_model", chunk_hash="A") - value = create_mock_value(50) - - backend.add(key, value) - retrieved_value = backend.get(key) - - assert retrieved_value == value - assert backend.current_size_bytes == 50 - - # Test with a list of JAX arrays (mocked) - key_list = CacheKey(model_name="test_model", - chunk_hash="list_item_hash") - value_list = [create_mock_value(20), create_mock_value(30)] - backend.add(key_list, value_list) - retrieved_list_value = backend.get(key_list) - - assert retrieved_list_value == value_list - assert backend.current_size_bytes == 50 + 20 + 30 - - def test_get_updates_lru_order(self, clean_backend_instance): - """Tests that get() moves the accessed item to the end (most recent).""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key_a = CacheKey(model_name="test_model", chunk_hash="A") - key_b = CacheKey(model_name="test_model", chunk_hash="B") - value = create_mock_value(10) - - backend.add(key_a, value) - backend.add(key_b, value) - # Initial order: A, B - assert list(backend.cache.keys()) == [key_a, key_b] - - backend.get(key_a) - # Accessed A, so order should now be: B, A - assert list(backend.cache.keys()) == [key_b, key_a] - - def test_contains_updates_lru_order(self, clean_backend_instance): - """Tests that contains() moves the accessed item to the end.""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key_a = CacheKey(model_name="test_model", chunk_hash="A") - key_b = CacheKey(model_name="test_model", chunk_hash="B") - value = create_mock_value(10) - - backend.add(key_a, value) - backend.add(key_b, value) - # Initial order: A, B - assert list(backend.cache.keys()) == [key_a, key_b] - - backend.contains(key_a) - # Accessed A, so order should now be: B, A - assert list(backend.cache.keys()) == [key_b, key_a] - - def test_eviction_on_add(self, clean_backend_instance): - """Tests that the least recently used item is evicted when cache is full.""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key_a = CacheKey(model_name="test_model", chunk_hash="A") - key_b = CacheKey(model_name="test_model", chunk_hash="B") - key_c = CacheKey(model_name="test_model", chunk_hash="C") - value = create_mock_value(50) - - backend.add(key_a, value) # LRU - backend.add(key_b, value) # MRU - assert backend.current_size_bytes == 100 - assert key_a in backend.cache - assert key_b in backend.cache - - # This should evict key_a - backend.add(key_c, value) - - assert key_a not in backend.cache - assert key_b in backend.cache - assert key_c in backend.cache - assert backend.current_size_bytes == 100 - - def test_cannot_add_item_larger_than_capacity(self, - clean_backend_instance): - """Tests that an item larger than the cache's capacity is not added.""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key = CacheKey(model_name="test_model", chunk_hash="large_item_hash") - value = create_mock_value(101) - - backend.add(key, value) - - assert key not in backend.cache - assert backend.current_size_bytes == 0 - - 
def test_pin_on_hit(self, clean_backend_instance): - """Tests that using contains() with pin_on_hit=True pins the key.""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key = CacheKey(model_name="test_model", chunk_hash="A") - backend.add(key, create_mock_value(10)) - - assert key not in backend.pin_counts - backend.contains(key, pin_on_hit=True) - assert key in backend.pin_counts - assert backend.pin_counts[key] == 1 - - def test_pinned_item_is_not_evicted(self, clean_backend_instance): - """Tests that a pinned item is protected from eviction.""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key_a = CacheKey(model_name="test_model", chunk_hash="A") - key_b = CacheKey(model_name="test_model", chunk_hash="B") - key_c = CacheKey(model_name="test_model", chunk_hash="C") - value = create_mock_value(50) - - backend.add(key_a, value) - backend.add(key_b, value) - backend.contains(key_a, pin_on_hit=True) - assert key_a in backend.pin_counts - - # This should evict key_b, because key_a is pinned - backend.add(key_c, value) - - assert key_a in backend.cache - assert key_b not in backend.cache - assert key_c in backend.cache - assert backend.current_size_bytes == 100 - - def test_unpin_makes_item_evictable(self, clean_backend_instance): - """Tests that unpinning a key makes it eligible for eviction again.""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key_a = CacheKey(model_name="test_model", chunk_hash="A") - key_b = CacheKey(model_name="test_model", chunk_hash="B") - key_c = CacheKey(model_name="test_model", chunk_hash="C") - value = create_mock_value(50) - - backend.add(key_a, value) - backend.add(key_b, value) - backend.contains(key_a, - pin_on_hit=True) # Pin A, and make A most recent - assert list(backend.cache.keys()) == [key_b, key_a] - - # Unpin A, making it the LRU evictable item - backend.maybe_unpin_keys([key_a]) - assert key_a not in backend.pin_counts - - # This should now evict B - backend.add(key_c, value) - - assert key_b not in backend.cache - assert key_a in backend.cache - assert key_c in backend.cache - - def test_cache_full_of_pinned_items_prevents_add(clean_backend_instance): - """ - Tests that no new items can be added if the cache is full of - pinned items. 
- """ - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key_a = CacheKey(model_name="test_model", chunk_hash="A") - key_b = CacheKey(model_name="test_model", chunk_hash="B") - key_c = CacheKey(model_name="test_model", chunk_hash="C") - value = create_mock_value(50) - - backend.add(key_a, value) - backend.add(key_b, value) - - # Pin all items in the cache - backend.contains(key_a, pin_on_hit=True) - backend.contains(key_b, pin_on_hit=True) - - # Attempt to add a new item - backend.add(key_c, value) - - assert key_c not in backend.cache - assert key_a in backend.cache - assert key_b in backend.cache - assert backend.current_size_bytes == 100 - assert key_a in backend.pin_counts - assert key_b in backend.pin_counts - - def test_pinning_same_key_multiple_times_increments_count( - self, clean_backend_instance): - """Verifies that pinning an already-pinned key increments its count.""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key = CacheKey(model_name="test_model", chunk_hash="A") - backend.add(key, create_mock_value(10)) - - backend.contains(key, pin_on_hit=True) - assert backend.pin_counts[key] == 1 - - backend.contains(key, pin_on_hit=True) - assert backend.pin_counts[key] == 2 - - def test_unpin_decrements_count_and_removes_at_zero( - self, clean_backend_instance): - """Tests the core reference counting logic of the unpin_keys method.""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key = CacheKey(model_name="test_model", chunk_hash="A") - backend.add(key, create_mock_value(10)) - - # Pin twice - backend.contains(key, pin_on_hit=True) - backend.contains(key, pin_on_hit=True) - assert backend.pin_counts[key] == 2 - - # Unpin once - backend.maybe_unpin_keys([key]) - assert key in backend.pin_counts - assert backend.pin_counts[key] == 1 - - # Unpin again - backend.maybe_unpin_keys([key]) - assert key not in backend.pin_counts - - def test_item_with_positive_pin_count_is_not_evicted( - self, clean_backend_instance): - """ - Tests that an item with a pin count > 0 is not evicted, confirming - the race condition fix. - """ - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key_a = CacheKey(model_name="test_model", chunk_hash="A") - key_b = CacheKey(model_name="test_model", chunk_hash="B") - key_c = CacheKey(model_name="test_model", chunk_hash="C") - value = create_mock_value(50) - - backend.add(key_a, value) # Will be LRU - backend.add(key_b, value) - - # Pin key_a twice (simulating two requests) - backend.contains(key_a, pin_on_hit=True) - backend.contains(key_a, pin_on_hit=True) - - # Unpin key_a once (simulating one request finishing) - backend.maybe_unpin_keys([key_a]) - assert backend.pin_counts[key_a] == 1 - - # This add should trigger eviction of key_b, as key_a is still pinned. - backend.add(key_c, value) - - assert key_a in backend.cache - assert key_b not in backend.cache - assert key_c in backend.cache - assert key_a in backend.pin_counts - - def test_unpin_keys_returns_correct_counts(self, clean_backend_instance): - """Validates the meaningful return values of unpin_keys.""" - backend = LocalCPUBackend(max_cpu_cache_size_bytes=100) - key_a = CacheKey(model_name="test_model", chunk_hash="A") - key_b = CacheKey(model_name="test_model", chunk_hash="B") - value = create_mock_value(10) - - backend.add(key_a, value) - backend.add(key_b, value) - - # Pin A twice, B once - backend.contains(key_a, pin_on_hit=True) - backend.contains(key_a, pin_on_hit=True) - backend.contains(key_b, pin_on_hit=True) - - # Unpin both. 
A should be decremented, B should be fully unpinned. - unpinned_count, found_count = backend.maybe_unpin_keys([key_a, key_b]) - - assert found_count == 2 # Both keys were found in pin_counts - assert unpinned_count == 1 # Only key_b's count went to 0 - assert backend.pin_counts[key_a] == 1 - assert key_b not in backend.pin_counts diff --git a/tpu_inference/distributed/offload/cpu_backend.py b/tpu_inference/distributed/offload/cpu_backend.py index 05479004f..e94939d18 100644 --- a/tpu_inference/distributed/offload/cpu_backend.py +++ b/tpu_inference/distributed/offload/cpu_backend.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import sys from collections import OrderedDict from typing import Any, Optional @@ -15,7 +14,6 @@ DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB -# TODO(jcgu): creating independent cpu backends since scheduler & worker could be in different processes. class LocalCPUBackend: """ A singleton in-memory CPU backend for storing KV cache keys and values. @@ -29,18 +27,14 @@ class LocalCPUBackend: size limit and support for pinning cache entries to prevent eviction. """ - def __init__(self, - max_cpu_cache_size_bytes: int = DEFAULT_CPU_CACHE_SIZE_BYTES): - env_cache_size_gb = os.getenv("TPU_OFFLOAD_CPU_CACHE_SIZE_GB") - self.max_cpu_cache_size_bytes = (int(env_cache_size_gb) * - GB if env_cache_size_gb is not None - else max_cpu_cache_size_bytes) - - # The cache is an OrderedDict for LRU behavior. + def __init__(self, num_cpu_chunks: int): + self.max_num_cpu_chunks = num_cpu_chunks self.cache: OrderedDict[CpuChunkId, Any] = OrderedDict() self.current_size_bytes = 0 - logger.info("Singleton LocalCPUBackend initialized." - f"CPU cache size: {self.max_cpu_cache_size_bytes} bytes") + self.num_occupied_cpu_chunks = 0 + logger.info( + "LocalCPUBackend initialized." + f"CPU cache capacity: {self.max_num_cpu_chunks} chunks / pages.") def _get_value_size(self, value: Any) -> int: """Calculates the size of a cache value in bytes.""" @@ -55,33 +49,42 @@ def _get_value_size(self, value: Any) -> int: size_in_bytes = sys.getsizeof(value) return size_in_bytes - def add(self, key: CpuChunkId, value: Any) -> bool: + def add(self, chunk_id: CpuChunkId, value: Any) -> bool: """ Adds a key-value pair to the cache. If the cache is full, it evicts the least recently used, unpinned entries until there is enough space. """ + if chunk_id < 0 or chunk_id >= self.max_num_cpu_chunks: + # TODO(jcgu): report failure when offload scheduler / worker + # can handle failed operations. + raise ValueError(f" get invalid chunk_id: {chunk_id}") + # Add the new item. - if key in self.cache: - old_value = self.cache.pop(key) + if chunk_id in self.cache: + old_value = self.cache.pop(chunk_id) self.current_size_bytes -= self._get_value_size(old_value) del old_value + self.num_occupied_cpu_chunks -= 1 - self.cache[key] = value + self.cache[chunk_id] = value + self.num_occupied_cpu_chunks += 1 value_size = self._get_value_size(value) self.current_size_bytes += value_size - logger.info(f"Added key: {key} (size:{value_size}) to CPU backend.") - logger.info(f"Cache size: {self.current_size_bytes} bytes / " - f"{self.max_cpu_cache_size_bytes} bytes") + logger.info( + f"Added chunk_id: {chunk_id} (size:{value_size}) to CPU backend.") + logger.info( + f"Cache: {self.current_size_bytes} bytes, {self.num_occupied_cpu_chunks} occupied chunks." 
+ ) return True - def get(self, key: CpuChunkId) -> Optional[Any]: + def get(self, chunk_id: CpuChunkId) -> Optional[Any]: """ - Gets the value for a given key and marks it as recently used. + Gets the value for a given chunk_id and marks it as recently used. """ - if key in self.cache: - return self.cache[key] + if chunk_id in self.cache: + return self.cache[chunk_id] return None def reclaim_unoccupied_chunks(self, occupied_chunk_ids: list[CpuChunkId]): diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index 078b7f0fa..097fb2bb7 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -139,6 +139,8 @@ # 3. dynamic: keep the partial block as is. PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal["drop", "pad", "dynamic"] +DEFAULT_TPU_OFFLOAD_CPU_CHUNKS = 1024 + @dataclass class SaveSpec: @@ -469,7 +471,11 @@ def __init__(self, vllm_config: "VllmConfig"): self.block_size = vllm_config.cache_config.block_size # offloading manager - self.offload_manager = LRUCacheManager(num_cpu_chunks=1024) + self.num_cpu_chunks = int( + os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", + str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS))) + self.offload_manager = LRUCacheManager( + num_cpu_chunks=self.num_cpu_chunks) self._request_trackers: dict[ReqId, RequestTracker] = {} # This dictionary holds the full vLLM Request object for all requests @@ -528,6 +534,7 @@ def __init__(self, vllm_config: "VllmConfig"): f"TPUOffloadConnectorScheduler initialized with: " f"block_size={self.block_size}, " f"cpu_chunk_size={self.cpu_chunk_size}, " + f"num_cpu_chunks={self.num_cpu_chunks}, " f"model_name={model_name}, " f"decode_save={self.decode_save}, " f"partial_block_save_behavior={self.partial_block_save_behavior}, " @@ -1217,7 +1224,10 @@ def __init__(self, vllm_config: VllmConfig, self.swap_out_fn: KVCacheSwapFn = None # cpu cache - self.cpu_backend = LocalCPUBackend() + self.num_cpu_chunks = int( + os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", + str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS))) + self.cpu_backend = LocalCPUBackend(num_cpu_chunks=self.num_cpu_chunks) # The worker needs its own token processor to generate keys. 
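+        # NOTE: the scheduler's LRUCacheManager and the worker's LocalCPUBackend both size themselves from TPU_OFFLOAD_NUM_CPU_CHUNKS (same default), so the two processes are assumed to agree on the CPU chunk pool capacity.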
model_name = self.vllm_config.model_config.model logger.info( From c1c5a7b54e433b73f7af3ebaae48558ea2ae8e11 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 19 Nov 2025 23:26:20 +0000 Subject: [PATCH 05/19] offload accuracy test Signed-off-by: Juncheng Gu --- .../offload/host_offloading_accuracy_test.py | 36 ++++++++++++++++--- .../offload/tpu_offload_connector.py | 2 -- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/tests/distributed/offload/host_offloading_accuracy_test.py b/tests/distributed/offload/host_offloading_accuracy_test.py index 1793dfe74..9c7705316 100644 --- a/tests/distributed/offload/host_offloading_accuracy_test.py +++ b/tests/distributed/offload/host_offloading_accuracy_test.py @@ -28,30 +28,34 @@ def parse_outputs(outputs): def sampling_config(): """deterministic sampling config""" return SamplingParams(temperature=0, - max_tokens=10, + max_tokens=20, seed=42, ignore_eos=True) @pytest.fixture def kv_transfer_config(): - """use TPUOffloadConnector from tpu_connector_local""" + """use TPUOffloadConnector from tpu_offload_connector""" return KVTransferConfig( kv_connector="TPUOffloadConnector", kv_role="kv_both", kv_connector_module_path= - "tpu_inference.distributed.tpu_connector_local", + "tpu_inference.distributed.offload.tpu_offload_connector", ) -def test_kv_cache_cpu_offloading_accuracy( +def _test_kv_cache_cpu_offloading_accuracy( monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, kv_transfer_config: KVTransferConfig, + swap_op_type: str, + decode_save: str, ): with monkeypatch.context(): os.environ['SKIP_JAX_PRECOMPILE'] = '1' - os.environ['TPU_OFFLOAD_SWAP_OP_TYPE'] = "pallas" + os.environ['TPU_OFFLOAD_SKIP_JAX_PRECOMPILE'] = '1' + os.environ['TPU_OFFLOAD_SWAP_OP_TYPE'] = swap_op_type + os.environ['TPU_OFFLOAD_DECODE_SAVE'] = decode_save llm = LLM(model="meta-llama/Llama-3.2-3B", max_model_len=1024, tensor_parallel_size=8, @@ -80,3 +84,25 @@ def test_kv_cache_cpu_offloading_accuracy( assert text1 == text2 for tokens1, tokens2 in zip(out_tokens1, out_tokens2): assert tokens1 == tokens2 + + del llm + # Waiting for TPUs to be released.
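+        # `del llm` alone does not guarantee the TPU runtime has freed its devices; the fixed sleep below is a heuristic settling time before the next engine starts.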
+ time.sleep(20) + + +def test_kv_cache_cpu_offloading_accuracy( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, + kv_transfer_config: KVTransferConfig, +): + swap_op_types = ["pallas", "jax"] + decode_saves = ["0", "1"] + for swap_op_type in swap_op_types: + for decode_save in decode_saves: + _test_kv_cache_cpu_offloading_accuracy( + monkeypatch, + sampling_config, + kv_transfer_config, + swap_op_type, + decode_save, + ) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index 097fb2bb7..e6b7f223a 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -467,7 +467,6 @@ class TPUOffloadConnectorScheduler(): def __init__(self, vllm_config: "VllmConfig"): logger.info("TPUOffloadConnectorScheduler: Entering __init__") self.vllm_config = vllm_config - self.config = vllm_config.kv_transfer_config self.block_size = vllm_config.cache_config.block_size # offloading manager @@ -1201,7 +1200,6 @@ def __init__(self, vllm_config: VllmConfig, connector: "TPUOffloadConnector"): logger.info("TPUOffloadConnectorWorker: Entering __init__") self.vllm_config = vllm_config - self.config = vllm_config.kv_transfer_config self.connector = connector self.block_size = vllm_config.cache_config.block_size From 9050d9c600437f9384935b57f330af5fe36cb956 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 19 Nov 2025 23:45:22 +0000 Subject: [PATCH 06/19] offload worker precompile test Signed-off-by: Juncheng Gu --- ...accuracy_test.py => tpu_offload_accuracy_test.py} | 0 ..._test.py => tpu_offload_connector_worker_test.py} | 12 ++---------- .../distributed/offload/tpu_offload_connector.py | 1 - 3 files changed, 2 insertions(+), 11 deletions(-) rename tests/distributed/offload/{host_offloading_accuracy_test.py => tpu_offload_accuracy_test.py} (100%) rename tests/distributed/offload/{host_offloading_precompile_test.py => tpu_offload_connector_worker_test.py} (95%) diff --git a/tests/distributed/offload/host_offloading_accuracy_test.py b/tests/distributed/offload/tpu_offload_accuracy_test.py similarity index 100% rename from tests/distributed/offload/host_offloading_accuracy_test.py rename to tests/distributed/offload/tpu_offload_accuracy_test.py diff --git a/tests/distributed/offload/host_offloading_precompile_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py similarity index 95% rename from tests/distributed/offload/host_offloading_precompile_test.py rename to tests/distributed/offload/tpu_offload_connector_worker_test.py index aa3faaf75..53121a717 100644 --- a/tests/distributed/offload/host_offloading_precompile_test.py +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -13,7 +13,6 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole -from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend from tpu_inference.distributed.offload.tpu_offload_connector import \ TPUOffloadConnector as CPUOffloadingConnector from tpu_inference.logger import init_logger @@ -32,6 +31,7 @@ def __init__(self, kv_caches: List[jax.Array], mesh: Mesh): self.mesh = mesh self.model_config = None self.sampler = None + self.devices = jax.devices() def get_kv_cache_layout(self): return "NHD" @@ -42,7 +42,6 @@ class MockVllmConfig: def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): self.model_config = self.Model() 
self.cache_config = self.Cache(block_size) - self.kv_transfer_config = self.KVTransfer() class Model: model = "test-model" @@ -52,12 +51,8 @@ class Cache: def __init__(self, block_size): self.block_size = block_size - class KVTransfer: - kv_ip = "localhost" - kv_port = 9999 - -class TestHostOffloadingPrecompile(jtu.JaxTestCase): +class TestTPUOffloadWorkerPrecompile(jtu.JaxTestCase): """Test the host offloading precompilation and related functionalities.""" def setUp(self): @@ -105,9 +100,6 @@ def create_mesh(self, axis_shapes, axis_names): def _create_connector(self, swap_op_type: str = "jax"): # Clean the singleton backend instance before each test - LocalCPUBackend._instance = None - LocalCPUBackend._initialized = False - os.environ["TPU_OFFLOAD_SWAP_OP_TYPE"] = swap_op_type connector = CPUOffloadingConnector(self.vllm_config, KVConnectorRole.WORKER) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index e6b7f223a..f0a9985b3 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -336,7 +336,6 @@ class TPUOffloadConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): logger.info("TPUOffloadConnector: Entering __init__") - assert vllm_config.kv_transfer_config is not None if role == KVConnectorRole.SCHEDULER: self.connector_scheduler = \ From 5cdfa3789909dd15634283b723c8d5fa1be02791 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 19 Nov 2025 23:58:17 +0000 Subject: [PATCH 07/19] dma kernel test Signed-off-by: Juncheng Gu --- tests/kernels/host_dma_test.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/kernels/host_dma_test.py b/tests/kernels/host_dma_test.py index 251eb6223..61dbf7386 100644 --- a/tests/kernels/host_dma_test.py +++ b/tests/kernels/host_dma_test.py @@ -95,11 +95,13 @@ def test_d2h_dma(self, model_axis_size: int): res = self.create_sharded_array(model_axis_size, "device") if res is None: return - original_device_data, _, host_sharding = res + original_device_data, device_sharding, host_sharding = res # 2. Test Device-to-Host (d2h) DMA - host_data = d2h_dma(original_device_data, host_sharding) + host_data = d2h_dma(original_device_data, device_sharding, + host_sharding) jax.block_until_ready(host_data) + assert host_data.sharding.memory_kind == "pinned_host" # 3. Verification assert host_data.sharding == host_sharding @@ -115,11 +117,13 @@ def test_h2d_dma(self, model_axis_size: int): res = self.create_sharded_array(model_axis_size, "host") if res is None: return - original_host_data, device_sharding, _ = res + original_host_data, device_sharding, host_sharding = res # 2. Test Host-to-Device (h2d) DMA - device_data = h2d_dma(original_host_data, device_sharding) + device_data = h2d_dma(original_host_data, host_sharding, + device_sharding) jax.block_until_ready(device_data) + assert device_data.sharding.memory_kind == "device" # 3. Verification assert device_data.sharding == device_sharding @@ -147,16 +151,20 @@ def test_d2h_h2d_dma_roundtrip(self, model_axis_size: int): original_device_data, device_sharding, host_sharding = res # 2. Test Device-to-Host (d2h) DMA - host_data = d2h_dma(original_device_data, host_sharding) + host_data = d2h_dma(original_device_data, device_sharding, + host_sharding) jax.block_until_ready(host_data) + assert host_data.sharding.memory_kind == "pinned_host" # 3. 
Verification for d2h assert host_data.sharding == host_sharding self.assertArraysEqual(original_device_data, host_data) # 4. Test Host-to-Device (h2d) DMA - reloaded_device_data = h2d_dma(host_data, device_sharding) + reloaded_device_data = h2d_dma(host_data, host_sharding, + device_sharding) jax.block_until_ready(reloaded_device_data) + assert reloaded_device_data.sharding.memory_kind == "device" # 5. Verification for h2d assert reloaded_device_data.sharding == device_sharding From f04f6c89272a87d5bebbfe4c6409e18f6631a93c Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 20 Nov 2025 03:24:13 +0000 Subject: [PATCH 08/19] utils test Signed-off-by: Juncheng Gu --- .../offload/cpu_offloading_cache_util_test.py | 130 --------------- .../offload/tpu_offload_utils_test.py | 157 ++++++++++++++++++ 2 files changed, 157 insertions(+), 130 deletions(-) delete mode 100644 tests/distributed/offload/cpu_offloading_cache_util_test.py create mode 100644 tests/distributed/offload/tpu_offload_utils_test.py diff --git a/tests/distributed/offload/cpu_offloading_cache_util_test.py b/tests/distributed/offload/cpu_offloading_cache_util_test.py deleted file mode 100644 index 175dd0da9..000000000 --- a/tests/distributed/offload/cpu_offloading_cache_util_test.py +++ /dev/null @@ -1,130 +0,0 @@ -import unittest - -import jax -import jax.numpy as jnp -import numpy as np - -from tpu_inference.distributed.offload.utils import \ - jitted_insert_kv_cache_slices - - -def original_jitted_insert_kv_cache_slices( - block_size, - kv_caches: list[jax.Array], - kv_cache_slices: list[jax.Array], - block_numbers: jax.Array, -) -> list[jax.Array]: - """ - This is the original implementation that expects concatenated slices. - It reshapes the single slice per layer into multiple blocks. - """ - - def _update_layer(cache, slices): - """The function to apply to each layer's cache and slices.""" - # Original method reshapes a large slice into blocks - num_blocks = len(block_numbers) - reshaped_slices = slices.reshape( - (num_blocks, 1, block_size, *slices.shape[1:])) - for i, block_idx in enumerate(block_numbers): - cache = jax.lax.dynamic_update_slice_in_dim(cache, - reshaped_slices[i], - block_idx, - axis=0) - return cache - - return jax.tree.map(_update_layer, kv_caches, kv_cache_slices) - - -class TestCacheInsertion(unittest.TestCase): - - def setUp(self): - """Set up common parameters for the tests.""" - self.num_layers = 2 - self.num_blocks_total = 32 - self.block_size = 16 - self.num_kv_heads = 4 - self.head_dim = 128 - - # We will load 3 new blocks - self.num_blocks_to_load = 3 - - # Shape for a single block in the main KV cache - self.kv_cache_shape = ( - self.num_blocks_total, - self.block_size, - self.num_kv_heads, - 2, - self.head_dim, - ) - - # Shape for one chunk/slice of tokens to be inserted - self.slice_shape = ( - self.block_size, - self.num_kv_heads, - 2, - self.head_dim, - ) - - # Destination block indices in the main KV cache - self.dst_blocks = jnp.array([5, 12, 21]) - - # --- Test Data --- - - # 1. Initial (empty) KV caches for all layers - key = jax.random.PRNGKey(0) - self.initial_kv_caches1 = [ - jnp.zeros(self.kv_cache_shape, dtype=jnp.float32) - for _ in range(self.num_layers) - ] - - self.initial_kv_caches2 = [ - jnp.zeros(self.kv_cache_shape, dtype=jnp.float32) - for _ in range(self.num_layers) - ] - - # 2. 
The raw, chunked KV data (input for the new method) - # This is a list of lists: List[layer -> List[chunk]] - self.raw_chunked_kv = [] - for _ in range(self.num_layers): - key, subkey = jax.random.split(key) - layer_chunks = [ - jax.random.normal(subkey, self.slice_shape) - for _ in range(self.num_blocks_to_load) - ] - self.raw_chunked_kv.append(layer_chunks) - - # 3. The concatenated KV data (input for the original method) - # This is a list of arrays: List[layer -> concatenated_array] - self.concatenated_kv = [ - jax.lax.concatenate(layer_chunks, dimension=0) - for layer_chunks in self.raw_chunked_kv - ] - - def test_jitted_insert_kv_cache_slices_equivalence(self): - """ - Verify that the new and original methods for inserting KV cache slices - produce identical results. - """ - # --- Approach 1: Original Method --- - # This method takes concatenated slices. - original_output = original_jitted_insert_kv_cache_slices( - self.block_size, self.initial_kv_caches1, self.concatenated_kv, - self.dst_blocks) - - # --- Approach 2: New Method --- - # This method takes a list of chunked slices. - new_output = jitted_insert_kv_cache_slices(self.block_size, - self.initial_kv_caches2, - self.raw_chunked_kv, - self.dst_blocks) - - # --- Verification --- - # Check that the outputs for each layer are identical. - for i in range(self.num_layers): - np.testing.assert_array_equal(np.array(original_output[i]), - np.array(new_output[i])) - print("\nTest passed: Both methods produce identical KV caches.") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/distributed/offload/tpu_offload_utils_test.py b/tests/distributed/offload/tpu_offload_utils_test.py new file mode 100644 index 000000000..75af7a3bd --- /dev/null +++ b/tests/distributed/offload/tpu_offload_utils_test.py @@ -0,0 +1,157 @@ +import functools +import itertools +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import NamedSharding, PartitionSpec + +from tpu_inference.distributed.offload.utils import ( + get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) + + +class TestTPUOffloadUtilsFn(unittest.TestCase): + + def setUp(self): + """Set up common parameters for the tests.""" + self.num_layers = 2 + self.num_tokens = 256 + self.num_kv_heads = 8 + self.head_dim = 128 + self.block_size = 16 + self.num_blocks = self.num_tokens // self.block_size + self.cache_shape = ( + self.num_blocks, + self.block_size, + self.num_kv_heads, + 2, + self.head_dim, + ) + self.block_shape = ( + self.block_size, + self.num_kv_heads, + 2, + self.head_dim, + ) + + self.cache_dtype = jnp.bfloat16 + + self.mesh = self.create_mesh((1, 8), ("data", "model")) + partition_spec = PartitionSpec(None, None, "model") + self.device_sharding = NamedSharding(self.mesh, + partition_spec, + memory_kind="device") + self.host_sharding = NamedSharding(self.mesh, + partition_spec, + memory_kind="pinned_host") + flatten_partition_spec = PartitionSpec(None, "model") + self.flatten_device_sharding = NamedSharding(self.mesh, + flatten_partition_spec, + memory_kind="device") + + def create_mesh(self, axis_shapes, axis_names): + """Creates a JAX device mesh with the default device order.""" + try: + num_required_devices = np.prod(axis_shapes) + devices = np.array(jax.devices()) + if len(devices) < num_required_devices: + self.skipTest( + f"Not enough devices to create mesh of shape {axis_shapes}." 
+ ) + device_array = devices[:num_required_devices].reshape(axis_shapes) + return jax.sharding.Mesh(device_array, axis_names) + except RuntimeError: + return None + + def test_jitted_insert_kv_cache_slices_equivalence(self): + """ + Verify inserting scattered kv slices / pages into the large kv cache. + """ + num_blocks_to_insert = 3 + dst_blocks = [3, 5, 7] + dst_blocks_array = jnp.array(dst_blocks) + + initial_kv_caches = [ + jax.device_put(jnp.zeros(self.cache_shape, dtype=self.cache_dtype), + self.device_sharding) + for _ in range(self.num_layers) + ] + + # The raw, chunked KV data (input for the new method) + # This is a list of lists: List[layer -> List[block]] + raw_chunked_kv = [] + for i in range(self.num_layers): + layer_chunks = [ + jax.device_put( + jax.random.normal(jax.random.key(i), + shape=self.block_shape, + dtype=self.cache_dtype), + self.flatten_device_sharding) + for _ in range(num_blocks_to_insert) + ] + raw_chunked_kv.append(layer_chunks) + + output = jitted_insert_kv_cache_slices(self.block_size, + initial_kv_caches, + raw_chunked_kv, + dst_blocks_array) + + # --- Verification --- + # Check that the selected pages for each layer equal to the original ones. + for i in range(self.num_layers): + for j in range(num_blocks_to_insert): + block_id = dst_blocks[j] + np.testing.assert_array_equal(np.array(output[i][block_id]), + raw_chunked_kv[i][j]) + print("\nTest passed: the inserted kv equals to the original one.") + + def test_swap_fn_correctness(self): + """ + Verify that swap-out and swap-in functions work correctly for different + swap_op_types and jitted options. + """ + swap_op_types = ["jax", "pallas"] + jitted_options = [True, False] + + # NOTE(jcgu): we are using the entire kv cache [n_b, bs, nh, 2, hd], + # actually, we will operate on concatenated blocks [nt, nh, 2, hd]; + @functools.partial(jax.jit, out_shardings=self.device_sharding) + def create_on_device(key): + return jax.random.uniform(key, + shape=self.cache_shape, + dtype=self.cache_dtype) + + initial_kv_caches = [ + create_on_device(jax.random.key(i)) for i in range(self.num_layers) + ] + jax.block_until_ready(initial_kv_caches) + + for swap_op_type, jitted in itertools.product(swap_op_types, + jitted_options): + with self.subTest(swap_op_type=swap_op_type, jitted=jitted): + swap_in_fn, swap_out_fn = get_kv_cache_swap_fn( + swap_op_type, self.host_sharding, self.device_sharding, + jitted) + + # Put initial data on device + device_kv_caches = jax.device_put(initial_kv_caches, + self.device_sharding) + jax.block_until_ready(device_kv_caches) + + # Swap out to host + host_kv_caches = swap_out_fn(device_kv_caches) + + # Swap back in to device + final_device_kv_caches = swap_in_fn(host_kv_caches) + jax.block_until_ready(final_device_kv_caches) + + # Verify correctness + for i in range(self.num_layers): + np.testing.assert_array_equal( + np.array(initial_kv_caches[i]), + np.array(final_device_kv_caches[i])) + + +if __name__ == "__main__": + unittest.main() From 0db26d80bab17571abf0dd3400979d9ccc9ca8d3 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 20 Nov 2025 05:32:29 +0000 Subject: [PATCH 09/19] connector worker tests Signed-off-by: Juncheng Gu --- .../offload/cpu_offloading_worker_test.py | 942 +++--------------- .../distributed/offload/cpu_backend.py | 12 +- .../offload/tpu_offload_connector.py | 1 + 3 files changed, 155 insertions(+), 800 deletions(-) diff --git a/tests/distributed/offload/cpu_offloading_worker_test.py b/tests/distributed/offload/cpu_offloading_worker_test.py index 
8ed20f12b..26df906d0 100644 --- a/tests/distributed/offload/cpu_offloading_worker_test.py +++ b/tests/distributed/offload/cpu_offloading_worker_test.py @@ -14,12 +14,11 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole -from tpu_inference.distributed.cpu_backend import LocalCPUBackend -from tpu_inference.distributed.offload.tpu_connector_local import (LoadSpec, - SaveSpec) -from tpu_inference.distributed.offload.tpu_connector_local import \ +from tpu_inference.distributed.offload.tpu_offload_connector import (LoadSpec, + SaveSpec) +from tpu_inference.distributed.offload.tpu_offload_connector import \ TPUOffloadConnector as CPUOffloadingConnector -from tpu_inference.distributed.offload.tpu_connector_local import ( +from tpu_inference.distributed.offload.tpu_offload_connector import ( TPUOffloadConnectorMetadata, TPUReqMeta) from tpu_inference.logger import init_logger from tpu_inference.runner.tpu_jax_runner import TPUModelRunner @@ -37,6 +36,7 @@ def __init__(self, kv_caches: List[jax.Array], mesh: Mesh): self.mesh = mesh self.model_config = None self.sampler = None + self.devices = jax.devices() def get_kv_cache_layout(self): return "NHD" @@ -47,7 +47,6 @@ class MockVllmConfig: def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): self.model_config = self.Model() self.cache_config = self.Cache(block_size) - self.kv_transfer_config = self.KVTransfer() class Model: model = "test-model" @@ -57,10 +56,6 @@ class Cache: def __init__(self, block_size): self.block_size = block_size - class KVTransfer: - kv_ip = "localhost" - kv_port = 9999 - class TestCpuOffloadingSave(jtu.JaxTestCase): """Test the save functionality of the TPUOffloadConnectorWorker.""" @@ -70,6 +65,7 @@ def setUp(self): self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE) self.num_layers = 2 self.num_blocks = 24 + self.num_cpu_chunks = 24 self.block_size = self.vllm_config.cache_config.block_size self.num_heads = 8 self.head_size = 128 @@ -108,12 +104,14 @@ def create_mesh(self, axis_shapes, axis_names): except RuntimeError: return None - def _create_connector(self, swap_op_type: str = "jax"): - # Clean the singleton backend instance before each test - LocalCPUBackend._instance = None - LocalCPUBackend._initialized = False - + def _create_connector(self, + swap_op_type: str = "jax", + use_precompiled_swap_ops: bool = False): os.environ["TPU_OFFLOAD_SWAP_OP_TYPE"] = swap_op_type + os.environ[ + "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1" + os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str(self.num_cpu_chunks) + connector = CPUOffloadingConnector(self.vllm_config, KVConnectorRole.WORKER) worker = connector.connector_worker @@ -135,266 +133,30 @@ def create_on_device(key): worker.register_runner(mock_runner) return connector - def _verify_saved_data( - self, - worker, - source_kv_cache, - process_token_ids, - local_block_ids, - skip_leading_tokens, - num_layers, - block_size, - ): - cpu_backend = LocalCPUBackend._instance - token_processor = worker.token_processor - - # Create a map from token index to its corresponding chunk key info - token_to_chunk_map = {} - all_keys_generator = token_processor.process_tokens(process_token_ids) - for abs_start_idx, abs_end_idx, key in all_keys_generator: - for i in range(abs_start_idx, abs_end_idx): - token_to_chunk_map[i] = { - "key": key, - "start_idx": abs_start_idx, - } - - # Cache fetched chunks to avoid getting the same chunk multiple times - fetched_chunks 
= {} - - num_processed_tokens = len(process_token_ids) - for token_idx in range(skip_leading_tokens, num_processed_tokens): - if token_idx not in token_to_chunk_map: - self.fail(f"Token index {token_idx} not found in any chunk.") - - chunk_info = token_to_chunk_map[token_idx] - chunk_key = chunk_info["key"] - - # Fetch the chunk from backend if not already fetched - if chunk_key not in fetched_chunks: - chunk_data = cpu_backend.get(chunk_key) - self.assertIsNotNone( - chunk_data, - f"Key {chunk_key} for token {token_idx} not found in backend", - ) - fetched_chunks[chunk_key] = chunk_data - - saved_chunk_data = fetched_chunks[chunk_key] - - # Get the original data from the source TPU cache - logical_block_idx = token_idx // block_size - block_offset = token_idx % block_size - physical_block_id = local_block_ids[logical_block_idx] - - # Get the saved data for the specific token from the chunk - offset_in_chunk = token_idx - chunk_info["start_idx"] - - assert offset_in_chunk == block_offset, f"{offset_in_chunk} != {block_offset}" - for layer_idx in range(num_layers): - original_token_data = source_kv_cache[layer_idx][ - physical_block_id, block_offset, ...] - saved_chunk = jax.device_put(saved_chunk_data[layer_idx], - jax.devices("cpu")[0]) - saved_token_data = saved_chunk[offset_in_chunk, ...] - self.assertArraysEqual(np.array(saved_token_data), - np.array(original_token_data)) - @parameterized.named_parameters( dict( - testcase_name="_prefill_no_skip_save_2_drop_jax", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=0, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=2, + testcase_name="_regular_multi_block_save", + num_blocks_to_save=5, ), dict( - testcase_name="_prefill_no_skip_save_2_drop_jax_precompiled", + testcase_name="_regular_multi_block_save_with_compile_jax", + num_blocks_to_save=5, use_precompiled_swap_ops=True, - num_skip_leading_tokens=0, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=2, ), dict( - testcase_name="_prefill_no_skip_save_2_drop_pallas", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=0, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=2, - swap_op_type="pallas", - ), - dict( - testcase_name="_prefill_no_skip_save_2_drop_pallas_precompiled", + testcase_name="_regular_multi_block_save_with_compile_pallas", + num_blocks_to_save=5, use_precompiled_swap_ops=True, - num_skip_leading_tokens=0, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=2, swap_op_type="pallas", ), - # NOTE(jcgu): mimic the scenario of padding the last partial - # block when the preferred saving behavior is pad / dynamic: - # add 10 extra tokens (after 2 full blocks) as the partial - # block and assign 3 blocks to save. 
- dict( - testcase_name="_prefill_no_skip_save_2_pad_jax", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=0, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=3, - ), - dict( - testcase_name="_prefill_no_skip_save_2_pad_jax_precompiled", - use_precompiled_swap_ops=True, - num_skip_leading_tokens=0, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=3, - ), dict( - testcase_name="_prefill_no_skip_save_2_pad_pallas", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=0, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=3, - swap_op_type="pallas", - ), - dict( - testcase_name="_prefill_no_skip_save_2_pad_pallas_precompiled", - use_precompiled_swap_ops=True, - num_skip_leading_tokens=0, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=3, - swap_op_type="pallas", - ), - dict( - testcase_name="_prefill_skip_2_save_2_drop", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10, - num_blocks_to_save=2, - ), - dict( - testcase_name="_prefill_skip_2_save_2_drop_precompiled", - use_precompiled_swap_ops=True, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10, - num_blocks_to_save=2, - ), - dict( - testcase_name="_prefill_skip_2_save_2_pad", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10, - num_blocks_to_save=3, - ), - dict( - testcase_name="_prefill_skip_2_save_2_pad_precompiled", - use_precompiled_swap_ops=True, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 4 + 10, - num_blocks_to_save=3, - ), - dict( - testcase_name="_decode_skip_3_save_1", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 3, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 4, - num_blocks_to_save=1, - ), - dict( - testcase_name="_decode_skip_3_save_1_precompiled", - use_precompiled_swap_ops=True, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 3, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 4, - num_blocks_to_save=1, - ), - dict( - testcase_name="_no_save", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=0, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_blocks_to_save=0, - is_final_save=False, - skip_save=False, - ), - dict( - testcase_name="_no_save_precompiled", - use_precompiled_swap_ops=True, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=0, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_blocks_to_save=0, - is_final_save=False, - skip_save=False, - ), - dict( - testcase_name="_final_save_save_1_drop", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 3 + 10, + testcase_name="_final_save", num_blocks_to_save=1, is_final_save=True, skip_save=False, ), dict( - 
testcase_name="_final_save_save_1_drop_precompiled", - use_precompiled_swap_ops=True, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=_DEFAULT_BLOCK_SIZE, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 3 + 10, - num_blocks_to_save=1, - is_final_save=True, - skip_save=False, - ), - dict( - testcase_name="_final_save_save_1_pad", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=10, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=1, - is_final_save=True, - skip_save=False, - ), - dict( - testcase_name="_final_save_save_1_pad_precompiled", - use_precompiled_swap_ops=True, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=10, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2 + 10, - num_blocks_to_save=1, - is_final_save=True, - skip_save=False, - ), - dict( - testcase_name="_final_save_without_data", - use_precompiled_swap_ops=False, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=0, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_blocks_to_save=0, - is_final_save=True, - skip_save=True, - ), - dict( - testcase_name="_final_save_without_data_precompiled", - use_precompiled_swap_ops=True, - num_skip_leading_tokens=_DEFAULT_BLOCK_SIZE * 2, - num_tokens_to_save=0, - num_total_tokens=_DEFAULT_BLOCK_SIZE * 2, + testcase_name="_final_skip_save", num_blocks_to_save=0, is_final_save=True, skip_save=True, @@ -402,80 +164,57 @@ def _verify_saved_data( ) def test_tpu_connector_save( self, - use_precompiled_swap_ops: bool, - num_skip_leading_tokens: int, - num_tokens_to_save: int, - num_total_tokens: int, num_blocks_to_save: int, is_final_save: bool = False, skip_save: bool = False, + use_precompiled_swap_ops: bool = False, swap_op_type: str = "jax", ): - os.environ[ - "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1" - - # Prepare and Execute Save - total_token_ids = list(range(num_total_tokens)) - num_blocks_for_tokens = (num_total_tokens + self.block_size - - 1) // self.block_size - - if num_blocks_for_tokens > self.num_blocks: - self.skipTest( - f"Not enough blocks to run test, blocks for tokens {num_blocks_for_tokens} > {self.num_blocks}" - ) - if num_blocks_for_tokens < num_blocks_to_save: + if num_blocks_to_save > self.num_blocks or num_blocks_to_save > self.num_cpu_chunks: self.skipTest( - f"Not enough blocks to save, blocks for tokens {num_blocks_for_tokens} < {num_blocks_to_save}" + f"num_blocks_to_save {num_blocks_to_save} exceeds ModelRunner / OffloadConnectorWorker's capacity" ) - if num_skip_leading_tokens % self.block_size != 0: - self.skipTest( - "num_skip_leading_tokens must be a multiple of block_size") - if num_total_tokens < (num_skip_leading_tokens + num_tokens_to_save): - self.skipTest( - f"num_total_tokens {num_total_tokens} must be no less than num_skip_leading_tokens + num_tokens_to_save" - ) - if (num_blocks_to_save - - (num_tokens_to_save // self.block_size)) not in [0, 1]: - self.skipTest( - f"num_blocks_to_save {num_blocks_to_save} does not match with the given num_tokens_to_save {num_tokens_to_save}" - ) - - total_blocks = list(range(self.num_blocks)) - local_block_ids = sorted( - random.sample(total_blocks, num_blocks_for_tokens)) - num_skip_blocks = num_skip_leading_tokens // self.block_size - src_blocks_to_save = local_block_ids[num_skip_blocks:( - num_skip_blocks + num_blocks_to_save)] - - logger.info( - f"Starting test_tpu_connector_save with: " - f"num_blocks_to_save={num_blocks_to_save}, 
skip_leading_tokens={num_skip_leading_tokens}, num_tokens_to_save={num_tokens_to_save}, " - f"num_total_tokens={num_total_tokens}, is_final_save={is_final_save}, skip_save={skip_save}, swap_op_type={swap_op_type}. \n" - f" Prepared for save: total_token_ids={total_token_ids}, num_blocks_for_tokens={num_blocks_for_tokens}, " - f"blocks_to_save={src_blocks_to_save}, local_block_ids={local_block_ids}." - ) - - connector = self._create_connector(swap_op_type) - worker = connector.connector_worker - req_id = "save_req" - num_process_tokens = num_skip_leading_tokens + num_tokens_to_save + # Prepare and Execute Save + all_block_ids = list(range(self.num_blocks)) + all_chunk_ids = list(range(self.num_cpu_chunks)) + src_block_ids = random.sample(all_block_ids, num_blocks_to_save) + dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_save) + num_tokens_to_save = num_blocks_to_save * self.block_size + num_total_tokens = num_tokens_to_save save_spec = SaveSpec( - num_skip_leading_tokens=num_skip_leading_tokens, - num_total_tokens=num_process_tokens, + num_skip_leading_tokens=0, + num_total_tokens=num_total_tokens, is_final_save=is_final_save, skip_save=skip_save, - src_blocks=src_blocks_to_save, + src_blocks=src_block_ids, + dst_chunks=dst_chunk_ids, ) + + logger.info(f"Starting test_tpu_connector_save with: " + f"num_blocks_to_save={num_blocks_to_save}, " + f"is_final_save={is_final_save}, " + f"skip_save={skip_save}, " + f"use_precompiled_swap_ops={use_precompiled_swap_ops}, " + f"swap_op_type={swap_op_type};" + f"Swapspec: {save_spec}") + + total_token_ids = list(range(num_total_tokens)) + + req_id = "save_req" req_meta = TPUReqMeta( req_id=req_id, token_ids=total_token_ids, - local_block_ids=local_block_ids, + local_block_ids=src_block_ids, save_spec=save_spec, ) connector_metadata = TPUOffloadConnectorMetadata( requests_meta=[req_meta]) + + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops) + worker = connector.connector_worker connector.bind_connector_metadata(connector_metadata) logger.info( "Connector metadata bound, calling worker.wait_for_save().") @@ -485,46 +224,27 @@ def test_tpu_connector_save( # Verification logger.info("Starting verification phase.") cpu_backend = worker.cpu_backend - saved_keys = cpu_backend.cache.keys() + kv_caches = worker.runner.kv_caches - if num_tokens_to_save == 0 or skip_save: - logger.info( - f"num_tokens_to_save is 0 or skip_save is True. Asserting no keys saved. " - f"Saved keys: {saved_keys}") - self.assertEmpty(saved_keys) - if is_final_save: - finished_saves, _ = worker.get_finished() - logger.info( - f"is_final_save is True. Finished requests: {finished_saves}" - ) - self.assertIn(req_id, finished_saves) - logger.info("Verification completed for no-save scenario.") + if skip_save or num_tokens_to_save == 0: + logger.info(" no blocks to save") + assert cpu_backend.num_saved_cpu_chunks == 0 + self.assertEmpty(worker.finished_save_reqs) + self.assertEmpty(worker.offload_stats.data["finished_save_blocks"]) return - # Verify that the correct number of chunks were saved - processed_token_ids = total_token_ids[:num_process_tokens] - token_processor = worker.token_processor - all_keys_generator = token_processor.process_tokens( - processed_token_ids) - expected_num_keys = 0 - for start_idx, _, _ in all_keys_generator: - # The logic in _save_blocks_to_cpu filters keys based on the start - # of the chunk. 
- if start_idx >= num_skip_leading_tokens: - expected_num_keys += 1 - logger.info( - f"Expected number of saved keys: {expected_num_keys}, Actual saved keys: {len(saved_keys)}" - ) - self.assertLen(saved_keys, expected_num_keys) - self._verify_saved_data( - worker, - worker.runner.kv_caches, - processed_token_ids, - local_block_ids, - num_skip_leading_tokens, - self.num_layers, - self.block_size, - ) + # verify the saved chunks + assert req_id in worker.offload_stats.data["finished_save_blocks"] + assert dst_chunk_ids == worker.offload_stats.data[ + "finished_save_blocks"][req_id] + + for tpu_block_id, cpu_chunk_id in zip(src_block_ids, dst_chunk_ids): + cpu_kv_chunk = cpu_backend.get(cpu_chunk_id) + for layer_idx in range(self.num_layers): + tpu_kv_block = kv_caches[layer_idx][tpu_block_id] + self.assertArraysEqual(np.array(tpu_kv_block), + np.array(cpu_kv_chunk[layer_idx])) + logger.info("Saved data verification completed.") if is_final_save: @@ -532,310 +252,49 @@ def test_tpu_connector_save( logger.info( f"is_final_save is True. Finished requests: {finished_saves}") self.assertIn(req_id, finished_saves) - logger.info("Test test_tpu_connector_save completed successfully.") - - @parameterized.named_parameters( - dict( - testcase_name="_2_steps_nobucket", - use_precompiled_swap_ops=False, - num_blocks_step1=2, - num_blocks_step2=1, - ), - dict( - testcase_name="_2_steps_bucketed_precompiled", - use_precompiled_swap_ops=True, - num_blocks_step1=2, - num_blocks_step2=1, - ), - dict( - testcase_name="_zero_token_step2", - use_precompiled_swap_ops=False, - num_blocks_step1=2, - num_blocks_step2=0, - ), - dict( - testcase_name="_zero_token_step2_bucketed_precompiled", - use_precompiled_swap_ops=True, - num_blocks_step1=2, - num_blocks_step2=0, - ), - ) - def test_tpu_connector_multi_step_save( - self, - use_precompiled_swap_ops: bool, - num_blocks_step1: int, - num_blocks_step2: int, - ): - """ - Tests that the TPUOffloadConnectorWorker correctly saves the KV cache in multiple - steps, respecting the save watermark (skip_leading_tokens). 
- """ - os.environ[ - "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1" - num_tokens_step1 = num_blocks_step1 * self.block_size - num_tokens_step2 = num_blocks_step2 * self.block_size - logger.info( - f"Starting test_tpu_connector_multi_step_save with " - f"num_tokens_step1={num_tokens_step1}, num_tokens_step2={num_tokens_step2}" - ) - - connector = self._create_connector() - worker = connector.connector_worker - available_blocks = list(range(self.num_blocks)) - - # --- Step 1: Initial Save --- - logger.info("--- Multi-step save: Step 1 ---") - skip_leading_tokens_step1 = 0 - - total_tokens_step1 = num_tokens_step1 - token_ids_step1 = list(range(total_tokens_step1)) - logger.info( - f"Step 1: num_tokens_step1={num_tokens_step1}, total_tokens_step1={total_tokens_step1}, num_blocks_step1={num_blocks_step1}" - ) - if num_blocks_step1 > self.num_blocks: - self.skipTest("Not enough blocks for step 1") - - local_block_ids_step1 = sorted( - random.sample(available_blocks, num_blocks_step1)) - num_skip_blocks_step1 = skip_leading_tokens_step1 // self.block_size - src_blocks_to_save_step1 = local_block_ids_step1[ - num_skip_blocks_step1:(num_skip_blocks_step1 + num_blocks_step1)] - - logger.info( - f"Step 1: local_block_ids_step1={local_block_ids_step1}, src_blocks_to_save_step1={src_blocks_to_save_step1}" - ) - - req_id = "multi_step_save_req" - save_spec_step1 = SaveSpec( - num_skip_leading_tokens=skip_leading_tokens_step1, - num_total_tokens=total_tokens_step1, - is_final_save=False, - skip_save=False, - src_blocks=src_blocks_to_save_step1, - ) - - req_meta_step1 = TPUReqMeta( - req_id=req_id, - token_ids=token_ids_step1, - local_block_ids=local_block_ids_step1, - save_spec=save_spec_step1, - ) - logger.info( - f"Step 1: req_meta_step1.token_ids={req_meta_step1.token_ids}, req_meta_step1.local_block_ids={req_meta_step1.local_block_ids}, req_meta_step1.save_spec.skip_leading_tokens={req_meta_step1.save_spec.num_skip_leading_tokens}" - ) - connector_metadata_step1 = TPUOffloadConnectorMetadata( - requests_meta=[req_meta_step1]) - connector.bind_connector_metadata(connector_metadata_step1) - worker.wait_for_save() - - # Verification for Step 1 - logger.info("--- Verifying Step 1 ---") - self._verify_saved_data( - worker, - worker.runner.kv_caches, - token_ids_step1, - local_block_ids_step1, - skip_leading_tokens_step1, - self.num_layers, - self.block_size, - ) - - # --- Step 2: Incremental Save --- - logger.info("--- Multi-step save: Step 2 ---") - skip_leading_tokens_step2 = num_tokens_step1 - total_tokens_step2 = skip_leading_tokens_step2 + num_tokens_step2 - token_ids_step2 = list(range(total_tokens_step2)) - num_blocks_step2_total = total_tokens_step2 // self.block_size - logger.info( - f"Step 2: num_tokens_step2={num_tokens_step2}, skip_leading_tokens_step2={skip_leading_tokens_step2}, total_tokens_step2={total_tokens_step2}, num_blocks_step2_total={num_blocks_step2_total}" - ) - if num_blocks_step2_total > self.num_blocks: - self.skipTest("Not enough blocks for step 2") - - num_additional_blocks = num_blocks_step2_total - num_blocks_step1 - logger.info(f"Step 2: num_additional_blocks={num_additional_blocks}") - remaining_blocks = sorted( - list(set(available_blocks) - set(local_block_ids_step1))) - if num_additional_blocks > len(remaining_blocks): - self.skipTest("Not enough remaining blocks for step 2") - - additional_blocks = random.sample(remaining_blocks, - num_additional_blocks) - local_block_ids_step2 = local_block_ids_step1 + additional_blocks - 
src_blocks_to_save_step2 = additional_blocks - logger.info( - f"Step 2: local_block_ids_step2={local_block_ids_step2}, src_blocks_to_save_step2={src_blocks_to_save_step2}" - ) - - save_spec_step2 = SaveSpec( - num_skip_leading_tokens=skip_leading_tokens_step2, - num_total_tokens=total_tokens_step2, - is_final_save=False, - skip_save=False, - src_blocks=src_blocks_to_save_step2, - ) - req_meta_step2 = TPUReqMeta( - req_id=req_id, - token_ids=token_ids_step2, - local_block_ids=local_block_ids_step2, - save_spec=save_spec_step2, - ) - logger.info( - f"Step 2: req_meta_step2.token_ids={req_meta_step2.token_ids}, req_meta_step2.local_block_ids={req_meta_step2.local_block_ids}, req_meta_step2.save_spec.skip_leading_tokens={req_meta_step2.save_spec.num_skip_leading_tokens}" - ) - connector_metadata_step2 = TPUOffloadConnectorMetadata( - requests_meta=[req_meta_step2]) - - # Manually reset worker state to simulate a new scheduler step - worker._processed_save_for_step = False - connector.bind_connector_metadata(connector_metadata_step2) - worker.wait_for_save() - - # Verification for Step 2 (only the new data) - logger.info("--- Verifying Step 2 (new data) ---") - self._verify_saved_data( - worker, - worker.runner.kv_caches, - token_ids_step2, - local_block_ids_step2, - skip_leading_tokens_step2, - self.num_layers, - self.block_size, - ) - - # Verification for Step 1 data (to ensure it is not corrupted) - logger.info("--- Verifying Step 1 data after Step 2 ---") - self._verify_saved_data( - worker, - worker.runner.kv_caches, - token_ids_step1, - local_block_ids_step1, - skip_leading_tokens_step1, - self.num_layers, - self.block_size, - ) - logger.info( - "Test test_tpu_connector_multi_step_save completed successfully.") @parameterized.named_parameters( dict( - testcase_name="_full_load_jax", - use_precompiled_swap_ops=False, - swap_op_type="jax", - num_matched_blocks=4, - num_computed_blocks=0, + testcase_name="_single_block_", + num_blocks_to_operate=1, ), dict( - testcase_name="_full_load_jax_precompiled", + testcase_name="_multi_blocks_compile_jax", + num_blocks_to_operate=5, use_precompiled_swap_ops=True, swap_op_type="jax", - num_matched_blocks=4, - num_computed_blocks=0, - ), - dict( - testcase_name="_delta_load_jax", - use_precompiled_swap_ops=False, - swap_op_type="jax", - num_matched_blocks=4, - num_computed_blocks=1, ), dict( - testcase_name="_delta_load_jax_precompiled", + testcase_name="_multi_blocks_compile_pallas", + num_blocks_to_operate=5, use_precompiled_swap_ops=True, - swap_op_type="jax", - num_matched_blocks=4, - num_computed_blocks=1, - ), - dict( - testcase_name="_delta_load_pallas", - use_precompiled_swap_ops=False, swap_op_type="pallas", - num_matched_blocks=4, - num_computed_blocks=1, - ), - dict( - testcase_name="_delta_load_pallas_precompiled", - use_precompiled_swap_ops=True, - swap_op_type="pallas", - num_matched_blocks=4, - num_computed_blocks=1, - ), - dict( - testcase_name="_no_load_jax", - use_precompiled_swap_ops=False, - swap_op_type="jax", - num_matched_blocks=1, - num_computed_blocks=1, - ), - dict( - testcase_name="_no_load_jax_precompiled", - use_precompiled_swap_ops=True, - swap_op_type="jax", - num_matched_blocks=1, - num_computed_blocks=1, ), ) def test_tpu_connector_load( self, - use_precompiled_swap_ops: bool, - swap_op_type: str, - num_matched_blocks: int, - num_computed_blocks: int = 0, + num_blocks_to_operate: int, + use_precompiled_swap_ops: bool = False, + swap_op_type: str = "jax", ): """ - Tests that the TPUOffloadConnectorWorker correctly loads 
only the delta of - the KV cache when a prefix is already computed by vLLM. - - This test simulates a scenario where vLLM has already computed a certain - number of tokens (prefix) and the TPUOffloadConnectorWorker needs to load - only the remaining "delta" of the KV cache from the CPU backend. + This test simulates a scenario where some amount of blocks get + offloaded to cpu cache, and then get loaded into tpu kv cache. + Both swap-out and swap-in are tested. Steps: - 1. Setup: - - Create a device mesh and sharding configurations. - - Instantiate a TPUOffloadConnector with a worker role. - - Create mock source (ground truth) and destination KV caches on the TPU. - - Register a mock TPUModelRunner with the worker. - - 2. Populate CPU Cache: - - Simulate a save operation to the CPU backend for the "matched" prefix. - - This represents the KV cache state on the CPU that corresponds to - the tokens already processed by vLLM. - - 3. Prepare and Execute Delta Load: - - Calculate the number of tokens to load (the delta). - - Construct the necessary metadata (`TPUOffloadConnectorMetadata`) and `LoadSpec` - to trigger a delta load operation, skipping the already computed tokens. - - Bind this metadata to the connector and call the worker's `start_load_kv` - method to perform the host-to-device (h2d) load for the delta tokens. - - 4. Verification: - - If no tokens were expected to be loaded, assert that the destination - KV cache remains zero. - - Otherwise, extract the expected delta data from the source KV cache - and the actually loaded data from the destination KV cache. - - Compare these two sets of data to ensure the loaded delta is correct. - - Assert that the parts of the destination cache that should not have - been touched remain zero. + 1. Setup: + 2. Simulate a save operation + 3. Load the data + 4. Verification """ - os.environ[ - "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1" - num_matched_tokens = num_matched_blocks * self.block_size - num_computed_tokens = num_computed_blocks * self.block_size - if num_matched_blocks > self.num_blocks: - self.skipTest( - f"num_matched_blocks {num_matched_blocks} > vllm_config.num_blocks {self.num_blocks}" - ) - if num_computed_blocks > num_matched_blocks: + if num_blocks_to_operate > self.num_blocks or num_blocks_to_operate > self.num_cpu_chunks: self.skipTest( - f"num_computed_blocks {num_computed_blocks} > num_matched_blocks {num_matched_blocks}" + f"num_blocks_to_save {num_blocks_to_operate} exceeds ModelRunner / OffloadConnectorWorker's capacity" ) - - logger.info( - f"Starting test_tpu_connector_load with num_computed_tokens={num_computed_tokens}, num_matched_tokens={num_matched_tokens}, swap_op_type={swap_op_type}." - ) # 1. 
Setup - connector = self._create_connector(swap_op_type) + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops) worker = connector.connector_worker # Ground truth cache on TPU src_kv_cache = worker.runner.kv_caches @@ -847,174 +306,65 @@ def test_tpu_connector_load( ] jax.block_until_ready(dst_kv_cache) + # Prepare + all_block_ids = list(range(self.num_blocks)) + all_chunk_ids = list(range(self.num_cpu_chunks)) + src_block_ids = random.sample(all_block_ids, num_blocks_to_operate) + dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_operate) + num_tokens_to_save = num_blocks_to_operate * self.block_size + num_total_tokens = num_tokens_to_save + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=num_tokens_to_save, + is_final_save=False, + skip_save=False, + src_blocks=src_block_ids, + dst_chunks=dst_chunk_ids, + ) + total_token_ids = list(range(num_total_tokens)) req_id = "save_req" - matched_token_ids = list(range(num_matched_tokens)) - total_blocks = list(range(self.num_blocks)) - local_block_ids = sorted( - random.sample(total_blocks, num_matched_blocks)) - - # 2. Populate CPU Cache - # Save the part of the source cache that represents the "matched" prefix - if num_matched_tokens > 0: - logger.info( - f"Populating CPU cache with {num_matched_tokens} matched tokens." - ) - - src_blocks_to_save = local_block_ids[ - num_computed_blocks:num_matched_blocks] - save_spec = SaveSpec( - num_skip_leading_tokens=num_computed_tokens, - num_total_tokens=num_matched_tokens, - is_final_save=False, - skip_save=False, - src_blocks=src_blocks_to_save, - ) - req_meta = TPUReqMeta( - req_id=req_id, - token_ids=matched_token_ids, - local_block_ids=local_block_ids, - save_spec=save_spec, - ) - connector_metadata = TPUOffloadConnectorMetadata( - requests_meta=[req_meta]) - connector.bind_connector_metadata(connector_metadata) - worker.wait_for_save() - logger.info( - f"Simulated save operation to CPU for {num_matched_tokens} tokens." - ) - else: - logger.info("No matched tokens, skipping CPU cache population.") - - # 3. Prepare and Execute Delta Load - worker.runner.kv_caches = dst_kv_cache - num_tokens_to_load = max(0, num_matched_tokens - num_computed_tokens) - # `num_tokens_to_load` cannot be negative. If `num_computed_tokens` - # is greater than or equal to `num_matched_tokens`, it means all - # relevant tokens are already on the TPU, and no new tokens need - # to be loaded from the CPU backend. In such cases, the value should - # be clamped to 0. 
- logger.info( - f"Calculated num_tokens_to_load: {num_tokens_to_load} (num_matched_tokens={num_matched_tokens} - num_computed_tokens={num_computed_tokens})" + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_block_ids, + save_spec=save_spec, ) - if num_tokens_to_load > 0: - dst_blocks = local_block_ids[ - num_computed_blocks:num_matched_blocks] - load_spec = LoadSpec( - num_matched_tokens=num_matched_tokens, - dst_blocks=dst_blocks, - can_load=True, - num_skip_leading_tokens=num_computed_tokens, - ) - - logger.info(f"LoadSpec created: {load_spec}") - # The worker needs the full token list to generate keys correctly - req_meta = TPUReqMeta( - req_id="load_req", - token_ids=matched_token_ids, - local_block_ids=local_block_ids, - load_spec=load_spec, - ) - connector_metadata = TPUOffloadConnectorMetadata( - requests_meta=[req_meta]) - connector.bind_connector_metadata(connector_metadata) - logger.info("Connector metadata bound, calling start_load_kv.") - worker.start_load_kv(fwd_ctx=None) - jax.block_until_ready(worker.runner.kv_caches) - logger.info("start_load_kv completed and blocked until ready.") - # we will donate the original kv_cache ref - dst_kv_cache = worker.runner.kv_caches - - # worker.runner.kv_caches = src_kv_cache - - # 4. Verification - logger.info("Starting verification phase.") - - if num_tokens_to_load <= 0: - logger.info( - "num_tokens_to_load is 0 or less, asserting nothing was loaded." - ) - # Assert that the entire destination cache remains untouched (all zeros). - for i in range(self.num_layers): - self.assertArraysEqual( - dst_kv_cache[i], - jnp.zeros(self.cache_shape, dtype=self.cache_dtype), - ) - logger.info("Assertion passed: Destination KV cache is all zeros.") - return - - # Helper to flatten and extract a token range from a cache given a block map - def get_token_slice(kv_cache, start_token, num_tokens, block_map): - if num_tokens <= 0: - return jnp.empty((0, *kv_cache.shape[2:]), - dtype=kv_cache.dtype) - start_block_logical = start_token // self.block_size - start_offset = start_token % self.block_size - end_token = start_token + num_tokens - end_block_logical = (end_token + self.block_size - - 1) // self.block_size - - if end_block_logical > len(block_map): - raise ValueError( - f"Not enough blocks in block_map to satisfy token range. " - f"Need {end_block_logical} blocks, but map has {len(block_map)}." - ) - - physical_blocks_to_gather = [ - block_map[i] - for i in range(start_block_logical, end_block_logical) - ] - - flat_cache = kv_cache[physical_blocks_to_gather, - ...].reshape(-1, *kv_cache.shape[2:]) - return flat_cache[start_offset:start_offset + num_tokens, ...] - - # Get the ground truth data from the source cache - expected_data_from_source_tpu = [ - get_token_slice( - src_kv_cache[i], - start_token=num_computed_tokens, - num_tokens=num_tokens_to_load, - block_map=local_block_ids, - ) for i in range(self.num_layers) - ] + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) + connector.bind_connector_metadata(connector_metadata) logger.info( - f"Extracted expected data from source cache. 
Shape of first layer: {expected_data_from_source_tpu[0].shape}" - ) + "Connector metadata bound, calling worker.wait_for_save().") + worker.wait_for_save() + logger.info("worker.wait_for_save() completed.") - # Get the data that was actually loaded into the destination cache - loaded_data_on_dest_tpu = [ - get_token_slice( - dst_kv_cache[i], - start_token=num_computed_tokens, - num_tokens=num_tokens_to_load, - block_map=local_block_ids, - ) for i in range(self.num_layers) - ] - logger.info( - f"Extracted loaded data from destination cache. Shape of first layer: {loaded_data_on_dest_tpu[0].shape}" + # 3. Prepare and Execute Delta Load + worker.runner.kv_caches = dst_kv_cache + load_spec = LoadSpec( + num_matched_tokens=num_tokens_to_save, + dst_blocks=src_block_ids, + src_chunks=dst_chunk_ids, + can_load=True, + num_skip_leading_tokens=0, ) - - # Assert that the loaded delta is correct. This works for no-load cases too. - for i in range(self.num_layers): - self.assertArraysEqual(np.array(expected_data_from_source_tpu[i]), - np.array(loaded_data_on_dest_tpu[i])) - logger.info("Assertion passed: Loaded delta matches expected data.") - - # Assert that blocks not in local_block_ids are still zero - untouched_blocks = sorted( - list(set(range(self.num_blocks)) - set(local_block_ids))) - logger.info( - f"Asserting that {len(untouched_blocks)} untouched blocks are still zero." + req_meta = TPUReqMeta( + req_id="load_req", + token_ids=total_token_ids, + local_block_ids=src_block_ids, + load_spec=load_spec, ) - if untouched_blocks: - for i in range(self.num_layers): - zero_slice = worker.runner.kv_caches[i][untouched_blocks, ...] - self.assertTrue(jnp.all(zero_slice == 0)) - expected_zeros = jnp.zeros( - (len(untouched_blocks), *self.cache_shape[1:]), - dtype=self.cache_dtype) - self.assertArraysEqual(np.array(zero_slice), - np.array(expected_zeros)) - logger.info("Assertion passed: Untouched blocks are zero.") - logger.info( - "Test test_tpu_connector_delta_load completed successfully.") + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) + connector.bind_connector_metadata(connector_metadata) + logger.info("Connector metadata bound, calling start_load_kv.") + worker.start_load_kv(fwd_ctx=None) + jax.block_until_ready(worker.runner.kv_caches) + logger.info("start_load_kv completed and blocked until ready.") + + # verify the data + # we will donate the original kv_cache ref + dst_kv_cache = worker.runner.kv_caches + for src_block_id in src_block_ids: + for layer_idx in range(self.num_layers): + self.assertArraysEqual( + np.array(src_kv_cache[layer_idx][src_block_id]), + np.array(dst_kv_cache[layer_idx][src_block_id])) diff --git a/tpu_inference/distributed/offload/cpu_backend.py b/tpu_inference/distributed/offload/cpu_backend.py index e94939d18..37352c504 100644 --- a/tpu_inference/distributed/offload/cpu_backend.py +++ b/tpu_inference/distributed/offload/cpu_backend.py @@ -31,11 +31,15 @@ def __init__(self, num_cpu_chunks: int): self.max_num_cpu_chunks = num_cpu_chunks self.cache: OrderedDict[CpuChunkId, Any] = OrderedDict() self.current_size_bytes = 0 - self.num_occupied_cpu_chunks = 0 + self._num_saved_cpu_chunks = 0 logger.info( "LocalCPUBackend initialized." 
f"CPU cache capacity: {self.max_num_cpu_chunks} chunks / pages.") + @property + def num_saved_cpu_chunks(self) -> int: + return self._num_saved_cpu_chunks + def _get_value_size(self, value: Any) -> int: """Calculates the size of a cache value in bytes.""" size_in_bytes = 0 @@ -66,16 +70,16 @@ def add(self, chunk_id: CpuChunkId, value: Any) -> bool: old_value = self.cache.pop(chunk_id) self.current_size_bytes -= self._get_value_size(old_value) del old_value - self.num_occupied_cpu_chunks -= 1 + self._num_saved_cpu_chunks -= 1 self.cache[chunk_id] = value - self.num_occupied_cpu_chunks += 1 + self._num_saved_cpu_chunks += 1 value_size = self._get_value_size(value) self.current_size_bytes += value_size logger.info( f"Added chunk_id: {chunk_id} (size:{value_size}) to CPU backend.") logger.info( - f"Cache: {self.current_size_bytes} bytes, {self.num_occupied_cpu_chunks} occupied chunks." + f"Cache: {self.current_size_bytes} bytes, {self._num_saved_cpu_chunks} occupied chunks." ) return True diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index f0a9985b3..f7459df79 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -1616,6 +1616,7 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], f"extracted_blocks_tpu: {flat_kv_caches_tpu[0].shape}, {flat_kv_caches_tpu[0].sharding}" ) + chunks_on_cpu = None if self.use_bucketed_swap_ops: chunks_on_cpu = self._bucketed_swap_out_fn(flat_kv_caches_tpu) else: From 3551c9d3f588315d8ab5ac8307cfeb8da09bd2e6 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 20 Nov 2025 05:39:27 +0000 Subject: [PATCH 10/19] rename Signed-off-by: Juncheng Gu --- .../offload/cpu_offloading_worker_test.py | 370 ------------------ .../tpu_offload_connector_worker_test.py | 266 ++++++++++++- 2 files changed, 256 insertions(+), 380 deletions(-) delete mode 100644 tests/distributed/offload/cpu_offloading_worker_test.py diff --git a/tests/distributed/offload/cpu_offloading_worker_test.py b/tests/distributed/offload/cpu_offloading_worker_test.py deleted file mode 100644 index 26df906d0..000000000 --- a/tests/distributed/offload/cpu_offloading_worker_test.py +++ /dev/null @@ -1,370 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import functools -import os -import random -from typing import List - -import jax -import jax.numpy as jnp -import numpy as np -from absl.testing import parameterized -from jax._src import compilation_cache as cc -from jax._src import test_util as jtu -from jax.sharding import Mesh, NamedSharding, PartitionSpec -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole - -from tpu_inference.distributed.offload.tpu_offload_connector import (LoadSpec, - SaveSpec) -from tpu_inference.distributed.offload.tpu_offload_connector import \ - TPUOffloadConnector as CPUOffloadingConnector -from tpu_inference.distributed.offload.tpu_offload_connector import ( - TPUOffloadConnectorMetadata, TPUReqMeta) -from tpu_inference.logger import init_logger -from tpu_inference.runner.tpu_jax_runner import TPUModelRunner - -logger = init_logger(__name__) - -_DEFAULT_BLOCK_SIZE = 64 - - -class MockTPUModelRunner(TPUModelRunner): - """A mock TPUModelRunner for testing purposes.""" - - def __init__(self, kv_caches: List[jax.Array], mesh: Mesh): - self.kv_caches = kv_caches - self.mesh = mesh - self.model_config = None - self.sampler = None - self.devices = jax.devices() - - def 
get_kv_cache_layout(self): - return "NHD" - - -class MockVllmConfig: - - def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): - self.model_config = self.Model() - self.cache_config = self.Cache(block_size) - - class Model: - model = "test-model" - - class Cache: - - def __init__(self, block_size): - self.block_size = block_size - - -class TestCpuOffloadingSave(jtu.JaxTestCase): - """Test the save functionality of the TPUOffloadConnectorWorker.""" - - def setUp(self): - super().setUp() - self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE) - self.num_layers = 2 - self.num_blocks = 24 - self.num_cpu_chunks = 24 - self.block_size = self.vllm_config.cache_config.block_size - self.num_heads = 8 - self.head_size = 128 - self.mesh = self.create_mesh((1, 8), ("data", "model")) - if self.mesh is None: - self.skipTest("Cannot create mesh. Must be run on a TPU node.") - return - - # Define cache properties - self.cache_shape = ( - self.num_blocks, - self.block_size, - self.num_heads, - 2, - self.head_size, - ) - self.cache_dtype = jnp.bfloat16 - partition_spec = PartitionSpec(None, None, "model") - self.device_sharding = NamedSharding(self.mesh, partition_spec) - - def tearDown(self): - super().tearDown() - cc.reset_cache() - - def create_mesh(self, axis_shapes, axis_names): - """Creates a JAX device mesh with the default device order.""" - try: - num_required_devices = np.prod(axis_shapes) - devices = np.array(jax.devices()) - if len(devices) < num_required_devices: - self.skipTest( - f"Not enough devices to create mesh of shape {axis_shapes}." - ) - device_array = devices[:num_required_devices].reshape(axis_shapes) - return jax.sharding.Mesh(device_array, axis_names) - except RuntimeError: - return None - - def _create_connector(self, - swap_op_type: str = "jax", - use_precompiled_swap_ops: bool = False): - os.environ["TPU_OFFLOAD_SWAP_OP_TYPE"] = swap_op_type - os.environ[ - "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1" - os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str(self.num_cpu_chunks) - - connector = CPUOffloadingConnector(self.vllm_config, - KVConnectorRole.WORKER) - worker = connector.connector_worker - assert worker is not None - - @functools.partial(jax.jit, out_shardings=self.device_sharding) - def create_on_device(key): - return jax.random.uniform(key, - shape=self.cache_shape, - dtype=self.cache_dtype) - - source_kv_cache = [ - create_on_device(jax.random.key(i)) for i in range(self.num_layers) - ] - jax.block_until_ready(source_kv_cache) - - mock_runner = MockTPUModelRunner(kv_caches=source_kv_cache, - mesh=self.mesh) - worker.register_runner(mock_runner) - return connector - - @parameterized.named_parameters( - dict( - testcase_name="_regular_multi_block_save", - num_blocks_to_save=5, - ), - dict( - testcase_name="_regular_multi_block_save_with_compile_jax", - num_blocks_to_save=5, - use_precompiled_swap_ops=True, - ), - dict( - testcase_name="_regular_multi_block_save_with_compile_pallas", - num_blocks_to_save=5, - use_precompiled_swap_ops=True, - swap_op_type="pallas", - ), - dict( - testcase_name="_final_save", - num_blocks_to_save=1, - is_final_save=True, - skip_save=False, - ), - dict( - testcase_name="_final_skip_save", - num_blocks_to_save=0, - is_final_save=True, - skip_save=True, - ), - ) - def test_tpu_connector_save( - self, - num_blocks_to_save: int, - is_final_save: bool = False, - skip_save: bool = False, - use_precompiled_swap_ops: bool = False, - swap_op_type: str = "jax", - ): - if num_blocks_to_save > self.num_blocks or 
num_blocks_to_save > self.num_cpu_chunks: - self.skipTest( - f"num_blocks_to_save {num_blocks_to_save} exceeds ModelRunner / OffloadConnectorWorker's capacity" - ) - - # Prepare and Execute Save - all_block_ids = list(range(self.num_blocks)) - all_chunk_ids = list(range(self.num_cpu_chunks)) - src_block_ids = random.sample(all_block_ids, num_blocks_to_save) - dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_save) - num_tokens_to_save = num_blocks_to_save * self.block_size - num_total_tokens = num_tokens_to_save - save_spec = SaveSpec( - num_skip_leading_tokens=0, - num_total_tokens=num_total_tokens, - is_final_save=is_final_save, - skip_save=skip_save, - src_blocks=src_block_ids, - dst_chunks=dst_chunk_ids, - ) - - logger.info(f"Starting test_tpu_connector_save with: " - f"num_blocks_to_save={num_blocks_to_save}, " - f"is_final_save={is_final_save}, " - f"skip_save={skip_save}, " - f"use_precompiled_swap_ops={use_precompiled_swap_ops}, " - f"swap_op_type={swap_op_type};" - f"Swapspec: {save_spec}") - - total_token_ids = list(range(num_total_tokens)) - - req_id = "save_req" - req_meta = TPUReqMeta( - req_id=req_id, - token_ids=total_token_ids, - local_block_ids=src_block_ids, - save_spec=save_spec, - ) - - connector_metadata = TPUOffloadConnectorMetadata( - requests_meta=[req_meta]) - - connector = self._create_connector(swap_op_type, - use_precompiled_swap_ops) - worker = connector.connector_worker - connector.bind_connector_metadata(connector_metadata) - logger.info( - "Connector metadata bound, calling worker.wait_for_save().") - worker.wait_for_save() - logger.info("worker.wait_for_save() completed.") - - # Verification - logger.info("Starting verification phase.") - cpu_backend = worker.cpu_backend - kv_caches = worker.runner.kv_caches - - if skip_save or num_tokens_to_save == 0: - logger.info(" no blocks to save") - assert cpu_backend.num_saved_cpu_chunks == 0 - self.assertEmpty(worker.finished_save_reqs) - self.assertEmpty(worker.offload_stats.data["finished_save_blocks"]) - return - - # verify the saved chunks - assert req_id in worker.offload_stats.data["finished_save_blocks"] - assert dst_chunk_ids == worker.offload_stats.data[ - "finished_save_blocks"][req_id] - - for tpu_block_id, cpu_chunk_id in zip(src_block_ids, dst_chunk_ids): - cpu_kv_chunk = cpu_backend.get(cpu_chunk_id) - for layer_idx in range(self.num_layers): - tpu_kv_block = kv_caches[layer_idx][tpu_block_id] - self.assertArraysEqual(np.array(tpu_kv_block), - np.array(cpu_kv_chunk[layer_idx])) - - logger.info("Saved data verification completed.") - - if is_final_save: - finished_saves, _ = worker.get_finished() - logger.info( - f"is_final_save is True. Finished requests: {finished_saves}") - self.assertIn(req_id, finished_saves) - - @parameterized.named_parameters( - dict( - testcase_name="_single_block_", - num_blocks_to_operate=1, - ), - dict( - testcase_name="_multi_blocks_compile_jax", - num_blocks_to_operate=5, - use_precompiled_swap_ops=True, - swap_op_type="jax", - ), - dict( - testcase_name="_multi_blocks_compile_pallas", - num_blocks_to_operate=5, - use_precompiled_swap_ops=True, - swap_op_type="pallas", - ), - ) - def test_tpu_connector_load( - self, - num_blocks_to_operate: int, - use_precompiled_swap_ops: bool = False, - swap_op_type: str = "jax", - ): - """ - This test simulates a scenario where some amount of blocks get - offloaded to cpu cache, and then get loaded into tpu kv cache. - Both swap-out and swap-in are tested. - - Steps: - 1. Setup: - 2. Simulate a save operation - 3. 
Load the data - 4. Verification - """ - if num_blocks_to_operate > self.num_blocks or num_blocks_to_operate > self.num_cpu_chunks: - self.skipTest( - f"num_blocks_to_save {num_blocks_to_operate} exceeds ModelRunner / OffloadConnectorWorker's capacity" - ) - # 1. Setup - connector = self._create_connector(swap_op_type, - use_precompiled_swap_ops) - worker = connector.connector_worker - # Ground truth cache on TPU - src_kv_cache = worker.runner.kv_caches - # Destination cache on TPU, should be modified by the load operation - dst_kv_cache = [ - jax.device_put(jnp.zeros(self.cache_shape, dtype=self.cache_dtype), - self.device_sharding) - for _ in range(self.num_layers) - ] - jax.block_until_ready(dst_kv_cache) - - # Prepare - all_block_ids = list(range(self.num_blocks)) - all_chunk_ids = list(range(self.num_cpu_chunks)) - src_block_ids = random.sample(all_block_ids, num_blocks_to_operate) - dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_operate) - num_tokens_to_save = num_blocks_to_operate * self.block_size - num_total_tokens = num_tokens_to_save - save_spec = SaveSpec( - num_skip_leading_tokens=0, - num_total_tokens=num_tokens_to_save, - is_final_save=False, - skip_save=False, - src_blocks=src_block_ids, - dst_chunks=dst_chunk_ids, - ) - total_token_ids = list(range(num_total_tokens)) - req_id = "save_req" - req_meta = TPUReqMeta( - req_id=req_id, - token_ids=total_token_ids, - local_block_ids=src_block_ids, - save_spec=save_spec, - ) - connector_metadata = TPUOffloadConnectorMetadata( - requests_meta=[req_meta]) - connector.bind_connector_metadata(connector_metadata) - logger.info( - "Connector metadata bound, calling worker.wait_for_save().") - worker.wait_for_save() - logger.info("worker.wait_for_save() completed.") - - # 3. Prepare and Execute Delta Load - worker.runner.kv_caches = dst_kv_cache - load_spec = LoadSpec( - num_matched_tokens=num_tokens_to_save, - dst_blocks=src_block_ids, - src_chunks=dst_chunk_ids, - can_load=True, - num_skip_leading_tokens=0, - ) - req_meta = TPUReqMeta( - req_id="load_req", - token_ids=total_token_ids, - local_block_ids=src_block_ids, - load_spec=load_spec, - ) - connector_metadata = TPUOffloadConnectorMetadata( - requests_meta=[req_meta]) - connector.bind_connector_metadata(connector_metadata) - logger.info("Connector metadata bound, calling start_load_kv.") - worker.start_load_kv(fwd_ctx=None) - jax.block_until_ready(worker.runner.kv_caches) - logger.info("start_load_kv completed and blocked until ready.") - - # verify the data - # we will donate the original kv_cache ref - dst_kv_cache = worker.runner.kv_caches - for src_block_id in src_block_ids: - for layer_idx in range(self.num_layers): - self.assertArraysEqual( - np.array(src_kv_cache[layer_idx][src_block_id]), - np.array(dst_kv_cache[layer_idx][src_block_id])) diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py index 53121a717..ef7935d34 100644 --- a/tests/distributed/offload/tpu_offload_connector_worker_test.py +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -2,6 +2,7 @@ import functools import os +import random from typing import List import jax @@ -13,8 +14,12 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole +from tpu_inference.distributed.offload.tpu_offload_connector import (LoadSpec, + SaveSpec) from tpu_inference.distributed.offload.tpu_offload_connector import \ 
TPUOffloadConnector as CPUOffloadingConnector +from tpu_inference.distributed.offload.tpu_offload_connector import ( + TPUOffloadConnectorMetadata, TPUReqMeta) from tpu_inference.logger import init_logger from tpu_inference.runner.tpu_jax_runner import TPUModelRunner @@ -52,14 +57,15 @@ def __init__(self, block_size): self.block_size = block_size -class TestTPUOffloadWorkerPrecompile(jtu.JaxTestCase): - """Test the host offloading precompilation and related functionalities.""" +class TestCpuOffloadingSave(jtu.JaxTestCase): + """Test the save functionality of the TPUOffloadConnectorWorker.""" def setUp(self): super().setUp() self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE) self.num_layers = 2 - self.num_blocks = 128 # Increased for larger tests + self.num_blocks = 24 + self.num_cpu_chunks = 24 self.block_size = self.vllm_config.cache_config.block_size self.num_heads = 8 self.head_size = 128 @@ -98,9 +104,14 @@ def create_mesh(self, axis_shapes, axis_names): except RuntimeError: return None - def _create_connector(self, swap_op_type: str = "jax"): - # Clean the singleton backend instance before each test + def _create_connector(self, + swap_op_type: str = "jax", + use_precompiled_swap_ops: bool = False): os.environ["TPU_OFFLOAD_SWAP_OP_TYPE"] = swap_op_type + os.environ[ + "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1" + os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str(self.num_cpu_chunks) + connector = CPUOffloadingConnector(self.vllm_config, KVConnectorRole.WORKER) worker = connector.connector_worker @@ -149,8 +160,7 @@ def test_decompose_into_buckets(self, num_blocks: int, """ Tests the _decompose_into_buckets function for correct greedy decomposition. """ - os.environ["TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" - connector = self._create_connector() + connector = self._create_connector(use_precompiled_swap_ops="0") worker = connector.connector_worker self.assertEqual(worker._decompose_into_buckets(num_blocks), expected_buckets) @@ -167,9 +177,9 @@ def test_precompile_run_success(self, swap_op_type: str): Tests that _precompile_kv_swap_operations runs without errors and modifies the cache content. 
""" - # Unset skip flag to allow precompilation to run - os.environ["TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" - connector = self._create_connector(swap_op_type=swap_op_type) + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops="0") + worker = connector.connector_worker # Keep a copy of the original cache content on the host @@ -187,3 +197,239 @@ def test_precompile_run_success(self, swap_op_type: str): for orig, new in zip(original_cache_host, new_cache_host)), "Cache content should not have changed after precompilation.", ) + + @parameterized.named_parameters( + dict( + testcase_name="_regular_multi_block_save", + num_blocks_to_save=5, + ), + dict( + testcase_name="_regular_multi_block_save_with_compile_jax", + num_blocks_to_save=5, + use_precompiled_swap_ops=True, + ), + dict( + testcase_name="_regular_multi_block_save_with_compile_pallas", + num_blocks_to_save=5, + use_precompiled_swap_ops=True, + swap_op_type="pallas", + ), + dict( + testcase_name="_final_save", + num_blocks_to_save=1, + is_final_save=True, + skip_save=False, + ), + dict( + testcase_name="_final_skip_save", + num_blocks_to_save=0, + is_final_save=True, + skip_save=True, + ), + ) + def test_tpu_connector_save( + self, + num_blocks_to_save: int, + is_final_save: bool = False, + skip_save: bool = False, + use_precompiled_swap_ops: bool = False, + swap_op_type: str = "jax", + ): + if num_blocks_to_save > self.num_blocks or num_blocks_to_save > self.num_cpu_chunks: + self.skipTest( + f"num_blocks_to_save {num_blocks_to_save} exceeds ModelRunner / OffloadConnectorWorker's capacity" + ) + + # Prepare and Execute Save + all_block_ids = list(range(self.num_blocks)) + all_chunk_ids = list(range(self.num_cpu_chunks)) + src_block_ids = random.sample(all_block_ids, num_blocks_to_save) + dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_save) + num_tokens_to_save = num_blocks_to_save * self.block_size + num_total_tokens = num_tokens_to_save + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=num_total_tokens, + is_final_save=is_final_save, + skip_save=skip_save, + src_blocks=src_block_ids, + dst_chunks=dst_chunk_ids, + ) + + logger.info(f"Starting test_tpu_connector_save with: " + f"num_blocks_to_save={num_blocks_to_save}, " + f"is_final_save={is_final_save}, " + f"skip_save={skip_save}, " + f"use_precompiled_swap_ops={use_precompiled_swap_ops}, " + f"swap_op_type={swap_op_type};" + f"Swapspec: {save_spec}") + + total_token_ids = list(range(num_total_tokens)) + + req_id = "save_req" + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_block_ids, + save_spec=save_spec, + ) + + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) + + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops) + worker = connector.connector_worker + connector.bind_connector_metadata(connector_metadata) + logger.info( + "Connector metadata bound, calling worker.wait_for_save().") + worker.wait_for_save() + logger.info("worker.wait_for_save() completed.") + + # Verification + logger.info("Starting verification phase.") + cpu_backend = worker.cpu_backend + kv_caches = worker.runner.kv_caches + + if skip_save or num_tokens_to_save == 0: + logger.info(" no blocks to save") + assert cpu_backend.num_saved_cpu_chunks == 0 + self.assertEmpty(worker.finished_save_reqs) + self.assertEmpty(worker.offload_stats.data["finished_save_blocks"]) + return + + # verify the saved chunks + assert req_id in 
worker.offload_stats.data["finished_save_blocks"] + assert dst_chunk_ids == worker.offload_stats.data[ + "finished_save_blocks"][req_id] + + for tpu_block_id, cpu_chunk_id in zip(src_block_ids, dst_chunk_ids): + cpu_kv_chunk = cpu_backend.get(cpu_chunk_id) + for layer_idx in range(self.num_layers): + tpu_kv_block = kv_caches[layer_idx][tpu_block_id] + self.assertArraysEqual(np.array(tpu_kv_block), + np.array(cpu_kv_chunk[layer_idx])) + + logger.info("Saved data verification completed.") + + if is_final_save: + finished_saves, _ = worker.get_finished() + logger.info( + f"is_final_save is True. Finished requests: {finished_saves}") + self.assertIn(req_id, finished_saves) + + @parameterized.named_parameters( + dict( + testcase_name="_single_block_", + num_blocks_to_operate=1, + ), + dict( + testcase_name="_multi_blocks_compile_jax", + num_blocks_to_operate=5, + use_precompiled_swap_ops=True, + swap_op_type="jax", + ), + dict( + testcase_name="_multi_blocks_compile_pallas", + num_blocks_to_operate=5, + use_precompiled_swap_ops=True, + swap_op_type="pallas", + ), + ) + def test_tpu_connector_load( + self, + num_blocks_to_operate: int, + use_precompiled_swap_ops: bool = False, + swap_op_type: str = "jax", + ): + """ + This test simulates a scenario where some amount of blocks get + offloaded to cpu cache, and then get loaded into tpu kv cache. + Both swap-out and swap-in are tested. + + Steps: + 1. Setup: + 2. Simulate a save operation + 3. Load the data + 4. Verification + """ + if num_blocks_to_operate > self.num_blocks or num_blocks_to_operate > self.num_cpu_chunks: + self.skipTest( + f"num_blocks_to_save {num_blocks_to_operate} exceeds ModelRunner / OffloadConnectorWorker's capacity" + ) + # 1. Setup + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops) + worker = connector.connector_worker + # Ground truth cache on TPU + src_kv_cache = worker.runner.kv_caches + # Destination cache on TPU, should be modified by the load operation + dst_kv_cache = [ + jax.device_put(jnp.zeros(self.cache_shape, dtype=self.cache_dtype), + self.device_sharding) + for _ in range(self.num_layers) + ] + jax.block_until_ready(dst_kv_cache) + + # Prepare + all_block_ids = list(range(self.num_blocks)) + all_chunk_ids = list(range(self.num_cpu_chunks)) + src_block_ids = random.sample(all_block_ids, num_blocks_to_operate) + dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_operate) + num_tokens_to_save = num_blocks_to_operate * self.block_size + num_total_tokens = num_tokens_to_save + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=num_tokens_to_save, + is_final_save=False, + skip_save=False, + src_blocks=src_block_ids, + dst_chunks=dst_chunk_ids, + ) + total_token_ids = list(range(num_total_tokens)) + req_id = "save_req" + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_block_ids, + save_spec=save_spec, + ) + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) + connector.bind_connector_metadata(connector_metadata) + logger.info( + "Connector metadata bound, calling worker.wait_for_save().") + worker.wait_for_save() + logger.info("worker.wait_for_save() completed.") + + # 3. 
Prepare and Execute Delta Load + worker.runner.kv_caches = dst_kv_cache + load_spec = LoadSpec( + num_matched_tokens=num_tokens_to_save, + dst_blocks=src_block_ids, + src_chunks=dst_chunk_ids, + can_load=True, + num_skip_leading_tokens=0, + ) + req_meta = TPUReqMeta( + req_id="load_req", + token_ids=total_token_ids, + local_block_ids=src_block_ids, + load_spec=load_spec, + ) + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) + connector.bind_connector_metadata(connector_metadata) + logger.info("Connector metadata bound, calling start_load_kv.") + worker.start_load_kv(fwd_ctx=None) + jax.block_until_ready(worker.runner.kv_caches) + logger.info("start_load_kv completed and blocked until ready.") + + # verify the data + # we will donate the original kv_cache ref + dst_kv_cache = worker.runner.kv_caches + for src_block_id in src_block_ids: + for layer_idx in range(self.num_layers): + self.assertArraysEqual( + np.array(src_kv_cache[layer_idx][src_block_id]), + np.array(dst_kv_cache[layer_idx][src_block_id])) From 17647673c19e7d1f31cadf66f6b839ea2184ebc0 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 20 Nov 2025 05:46:27 +0000 Subject: [PATCH 11/19] rename Signed-off-by: Juncheng Gu --- .../offload/tpu_offload_connector.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index f7459df79..a8ce1e462 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -277,20 +277,20 @@ def __post_init__(self): def reset(self): # Must be serializable self.data: dict[str, dict[str, list[int]]] = { - "finished_save_blocks": dict(), - "finished_load_blocks": dict(), + "finished_save_chunks": dict(), + "finished_load_chunks": dict(), } def record_save(self, req: ReqId, saved_chunk_ids: list[int]): - if req not in self.data["finished_save_blocks"]: - self.data["finished_save_blocks"][req] = [] - self.data["finished_save_blocks"][req].extend( + if req not in self.data["finished_save_chunks"]: + self.data["finished_save_chunks"][req] = [] + self.data["finished_save_chunks"][req].extend( copy.deepcopy(saved_chunk_ids)) def record_load(self, req: ReqId, loaded_chunk_ids: list[int]): - if req not in self.data["finished_load_blocks"]: - self.data["finished_load_blocks"][req] = [] - self.data["finished_load_blocks"][req].extend( + if req not in self.data["finished_load_chunks"]: + self.data["finished_load_chunks"][req] = [] + self.data["finished_load_chunks"][req].extend( copy.deepcopy(loaded_chunk_ids)) def clone_and_reset(self) -> "KVOffloadConnectorStats": @@ -312,18 +312,18 @@ def reduce(self) -> dict[str, int | float]: "Num finished load blocks ": 0, } - finished_save_blocks = sum(self.data["finished_save_blocks"].values()) - finished_load_blocks = sum(self.data["finished_load_blocks"].values()) + finished_save_chunks = sum(self.data["finished_save_chunks"].values()) + finished_load_chunks = sum(self.data["finished_load_chunks"].values()) return { - "Num finished save blocks ": finished_save_blocks, - "Num finished load blocks ": finished_load_blocks, + "Num finished save chunks ": finished_save_chunks, + "Num finished load chunks": finished_load_chunks, } @property def num_finished_blocks(self) -> int: - return len(self.data["finished_save_blocks"]) + len( - self.data["finished_load_blocks"]) + return len(self.data["finished_save_chunks"]) + 
len( + self.data["finished_load_chunks"]) # The metadata used for communicating between scheduler and worker connectors. @@ -1096,13 +1096,13 @@ def update_connector_output(self, connector_output: KVConnectorOutput): if connector_output.kv_connector_stats and connector_output.kv_connector_stats.data is not None: assert isinstance(connector_output.kv_connector_stats, KVOffloadConnectorStats) - assert "finished_save_blocks" in connector_output.kv_connector_stats.data - assert "finished_load_blocks" in connector_output.kv_connector_stats.data + assert "finished_save_chunks" in connector_output.kv_connector_stats.data + assert "finished_load_chunks" in connector_output.kv_connector_stats.data for req_id, saved_chunk_ids in connector_output.kv_connector_stats.data[ - "finished_save_blocks"].items(): + "finished_save_chunks"].items(): num_saved_chunks = len(saved_chunk_ids) logger.info( - f" finished_save_blocks for {req_id}: {saved_chunk_ids}") + f" finished_save_chunks for {req_id}: {saved_chunk_ids}") # free staging blocks self.staging_buffer_manager.free( req_id, usage="save", num_finished_blocks=num_saved_chunks) @@ -1116,10 +1116,10 @@ def update_connector_output(self, connector_output: KVConnectorOutput): self.offload_manager.mark_completion(saved_chunk_ids, "save") for req_id, loaded_chunk_ids in connector_output.kv_connector_stats.data[ - "finished_load_blocks"].items(): + "finished_load_chunks"].items(): num_loaded_chunks = len(loaded_chunk_ids) logger.info( - f" finished_load_blocks for {req_id}: {num_loaded_chunks}" + f" finished_load_chunks for {req_id}: {num_loaded_chunks}" ) self.staging_buffer_manager.free( req_id, From 16ff9517e3675a75e081bbad853c2bf6f425d9cc Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 20 Nov 2025 05:53:59 +0000 Subject: [PATCH 12/19] update gke unit tests Signed-off-by: Juncheng Gu --- examples/gke/pod_tpu_host_offload_unit_tests.yaml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/gke/pod_tpu_host_offload_unit_tests.yaml index 5a9fb73db..641bc8ec4 100644 --- a/examples/gke/pod_tpu_host_offload_unit_tests.yaml +++ b/examples/gke/pod_tpu_host_offload_unit_tests.yaml @@ -17,12 +17,10 @@ spec: command: - /bin/bash - -c - - "pytest -sv tests/distributed/offload/host_offloading_precompile_test.py" - # - "pytest -sv tests/distributed/offload/cpu_offloading_worker_test.py" - # - "pytest -sv tests/distributed/offload/cpu_offloading_cache_util_test.py" - # - "pytest -sv tests/distributed/offload/host_offloading_accuracy_test.py" - # - "pytest -sv tests/distributed/offload/local_cpu_backend_test.py" - # - "pytest -sv tests/distributed/offload/host_offloading_precompile_test.py" + - "pytest -sv tests/distributed/offload/cpu_backend_test.py" + - "pytest -sv tests/distributed/offload/tpu_offload_connector_worker_test.py" + - "pytest -sv tests/distributed/offload/tpu_offload_utils_test.py" + - "pytest -sv tests/distributed/offload/tpu_offload_accuracy_test.py" env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: From e9c45ac6b7d5f3e03f1777c4311be693f77912ac Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 20 Nov 2025 06:38:49 +0000 Subject: [PATCH 13/19] offload manager tests Signed-off-by: Juncheng Gu --- .../gke/pod_tpu_host_offload_unit_tests.yaml | 1 + .../offload/cpu_offloading_scheduler_test.py | 840 ------------------ .../offload/tpu_offload_manager_test.py | 357 ++++++++ .../distributed/offload/offload_manager.py | 40 +- 4 files changed, 381 insertions(+), 857 
deletions(-) delete mode 100644 tests/distributed/offload/cpu_offloading_scheduler_test.py create mode 100644 tests/distributed/offload/tpu_offload_manager_test.py diff --git a/examples/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/gke/pod_tpu_host_offload_unit_tests.yaml index 641bc8ec4..8daf6c035 100644 --- a/examples/gke/pod_tpu_host_offload_unit_tests.yaml +++ b/examples/gke/pod_tpu_host_offload_unit_tests.yaml @@ -20,6 +20,7 @@ spec: - "pytest -sv tests/distributed/offload/cpu_backend_test.py" - "pytest -sv tests/distributed/offload/tpu_offload_connector_worker_test.py" - "pytest -sv tests/distributed/offload/tpu_offload_utils_test.py" + - "pytest -sv tests/distributed/offload/tpu_offload_manager_test.py" - "pytest -sv tests/distributed/offload/tpu_offload_accuracy_test.py" env: - name: HUGGING_FACE_HUB_TOKEN diff --git a/tests/distributed/offload/cpu_offloading_scheduler_test.py b/tests/distributed/offload/cpu_offloading_scheduler_test.py deleted file mode 100644 index fda3cb7f8..000000000 --- a/tests/distributed/offload/cpu_offloading_scheduler_test.py +++ /dev/null @@ -1,840 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import os -from unittest.mock import MagicMock - -import pytest -from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput -from vllm.v1.request import Request - -from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend -from tpu_inference.distributed.offload.tpu_offload_connector import ( - ReqId, RequestTracker, Scheduler, _StagingBufferManager) -from tpu_inference.logger import init_logger - -from .cpu_offloading_worker_test import MockVllmConfig - -logger = init_logger(__name__) - -_DEFAULT_BLOCK_SIZE = 4 - - -def create_request( - request_id: str, - prompt_token_ids: list[int], - block_size: int, - num_computed_tokens: int = 0, -) -> Request: - """Creates a mock vLLM request object.""" - req = MagicMock(spec=Request) - req.request_id = request_id - req.req_id = request_id - req.prompt_token_ids = prompt_token_ids - req.all_token_ids = prompt_token_ids - req.num_computed_tokens = num_computed_tokens - req.block_size = block_size - req.block_ids = [[]] # Mock structure - return req - - -@pytest.fixture -def clean_backend_instance(): - """ - Provides a clean instance of the LocalCPUBackend for each test. 
- """ - LocalCPUBackend._instance = None - LocalCPUBackend._initialized = False - yield - LocalCPUBackend._instance = None - LocalCPUBackend._initialized = False - - -@pytest.fixture -def scheduler_factory(): - """Provides a factory function for Scheduler instances.""" - - def _scheduler( - block_size: int = _DEFAULT_BLOCK_SIZE, - offload_decode_save: int = 0, - offload_partial_block_save_behavior: str = "drop", - offload_partial_block_dynamic_pad_lower_limit: int = 0, - offload_staging_buffer_tokens: int = -1, - ): - # update config - vllm_config = MockVllmConfig(block_size=block_size) - os.environ["TPU_OFFLOAD_DECODE_SAVE"] = str(offload_decode_save) - os.environ[ - "TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR"] = offload_partial_block_save_behavior - os.environ["TPU_OFFLOAD_PARTIAL_BLOCK_DYNAMIC_PAD_LOWER_LIMIT"] = str( - offload_partial_block_dynamic_pad_lower_limit) - if offload_staging_buffer_tokens >= 0: - os.environ["TPU_OFFLOAD_STAGING_BUFFER_TOKENS"] = str( - offload_staging_buffer_tokens) - return Scheduler(vllm_config) - - return _scheduler - - -class TestStagingBufferManager: - - def test_initialization(self): - manager = _StagingBufferManager(num_blocks=100) - assert manager.num_blocks == 100 - assert manager.get_num_free_staging_blocks() == 100 - assert manager.get_num_used_staging_blocks() == 0 - - def test_allocate_simple(self): - manager = _StagingBufferManager(num_blocks=100) - req_id1: ReqId = "req1" - req_id2: ReqId = "req2" - - allocated1 = manager.allocate(req_id1, 10, "load") - assert allocated1 == 10 - assert manager.get_num_free_staging_blocks() == 90 - assert manager.get_num_used_staging_blocks() == 10 - assert manager._num_blocks_for_load == 10 - assert manager._num_blocks_for_save == 0 - - allocated2 = manager.allocate(req_id2, 20, "save") - assert allocated2 == 20 - assert manager.get_num_free_staging_blocks() == 70 - assert manager.get_num_used_staging_blocks() == 30 - assert manager._num_blocks_for_load == 10 - assert manager._num_blocks_for_save == 20 - - def test_allocate_insufficient_capacity(self): - manager = _StagingBufferManager(num_blocks=10) - req_id: ReqId = "req1" - allocated = manager.allocate(req_id, 20, "load") - assert allocated == 0 - assert manager.get_num_free_staging_blocks() == 10 - assert manager.get_num_used_staging_blocks() == 0 - - def test_allocate_existing_load_request(self): - manager = _StagingBufferManager(num_blocks=100) - req_id: ReqId = "req1" - manager.allocate(req_id, 10, "load") - with pytest.raises(ValueError): - # multiple concurrent loads from a single request is not allowed. 
- manager.allocate(req_id, 5, "load") - - def test_allocate_existing_save_request(self): - manager = _StagingBufferManager(num_blocks=100) - req_id: ReqId = "req1" - manager.allocate(req_id, 10, "save") - assert manager._blocks_for_save[req_id] == 10 - manager.allocate(req_id, 5, "save") - assert manager._blocks_for_save[req_id] == 15 - assert manager.get_num_free_staging_blocks() == 85 - assert manager.get_num_used_staging_blocks() == 15 - - def test_allocate_negative_blocks(self): - manager = _StagingBufferManager(num_blocks=100) - req_id: ReqId = "req1" - allocated = manager.allocate(req_id, -5, "load") - assert allocated == -5 - assert manager.get_num_free_staging_blocks() == 100 - - def test_free_full(self): - manager = _StagingBufferManager(num_blocks=100) - req_id: ReqId = "req1" - manager.allocate(req_id, 10, "load") - freed = manager.free(req_id, "load") - assert freed == 10 - assert manager.get_num_free_staging_blocks() == 100 - assert manager.get_num_used_staging_blocks() == 0 - assert req_id not in manager._blocks_for_load - - def test_free_partial(self): - manager = _StagingBufferManager(num_blocks=100) - req_id: ReqId = "req1" - manager.allocate(req_id, 10, "save") - freed = manager.free(req_id, "save", num_finished_blocks=4) - assert freed == 4 - assert manager.get_num_free_staging_blocks() == 94 - assert manager.get_num_used_staging_blocks() == 6 - assert manager._blocks_for_save[req_id] == 6 - - def test_free_more_than_allocated(self): - manager = _StagingBufferManager(num_blocks=100) - req_id: ReqId = "req1" - manager.allocate(req_id, 10, "load") - manager.free(req_id, "load", num_finished_blocks=15) - assert req_id not in manager._blocks_for_load - - def test_free_non_existent_request(self): - manager = _StagingBufferManager(num_blocks=100) - req_id: ReqId = "req1" - freed = manager.free(req_id, "load") - assert freed == 0 - - def test_get_usage(self): - manager = _StagingBufferManager(num_blocks=100) - req_id1: ReqId = "req1" - req_id2: ReqId = "req2" - manager.allocate(req_id1, 10, "load") - manager.allocate(req_id2, 20, "save") - - usage_str = manager.get_usage() - expected_str = "Staging Buffer: total=100, free=70, used_for_load=10, used_for_save=20;" - assert usage_str == expected_str - - usage_str_details = manager.get_usage(with_details=True) - assert "save_details:{req2:20,}" in usage_str_details - assert "load_details:{req1:10,}" in usage_str_details - - def test_complex_scenario(self): - manager = _StagingBufferManager(num_blocks=50) - req1, req2, req3 = "req1", "req2", "req3" - - # req1 loads 10, req2 saves 15 - assert manager.allocate(req1, 10, "load") == 10 - assert manager.allocate(req2, 15, "save") == 15 - assert manager.get_num_free_staging_blocks() == 25 - assert manager.get_num_used_staging_blocks() == 25 - - # req3 tries to load 30, fails - assert manager.allocate(req3, 30, "load") == 0 - assert manager.get_num_free_staging_blocks() == 25 - - # req1 finishes loading - assert manager.free(req1, "load") == 10 - assert manager.get_num_free_staging_blocks() == 35 - - # req3 can now load 20 - assert manager.allocate(req3, 20, "load") == 20 - assert manager.get_num_free_staging_blocks() == 15 - assert manager.get_num_used_staging_blocks( - ) == 35 # 15 for save (req2) + 20 for load (req3) - - # req2 saves another 5 - assert manager.allocate(req2, 5, "save") == 5 - assert manager.get_num_free_staging_blocks() == 10 - assert manager._blocks_for_save[req2] == 20 - - # req2 frees 8 blocks - assert manager.free(req2, "save", 8) == 8 - assert 
manager.get_num_free_staging_blocks() == 18 - assert manager._blocks_for_save[req2] == 12 - - # req2 and req3 finish - assert manager.free(req2, "save") == 12 - assert manager.free(req3, "load") == 20 - assert manager.get_num_free_staging_blocks() == 50 - assert manager.get_num_used_staging_blocks() == 0 - - -class TestScheduler: - - def _add_prompt_to_scheduler_cpu_backend(self, scheduler, prompt_tokens): - """ add """ - keys_gen = scheduler.token_processor.process_tokens(prompt_tokens) - keys = list(keys_gen) - for i in range(len(keys)): - start, end, key = keys[i] - scheduler.cpu_backend.add(key, "dummy_data") - - def test_get_num_new_matched_tokens_no_hit(self, scheduler_factory, - clean_backend_instance): - """ - Tests that get_num_new_matched_tokens returns 0 when there is no - matching prefix in the CPU cache. - """ - scheduler = scheduler_factory() - assert len(scheduler.cpu_backend.cache) == 0 - request = create_request("req1", - list(range(scheduler.block_size * 2)), - block_size=scheduler.block_size) - num_matched, _ = scheduler.get_num_new_matched_tokens(request, 0) - assert num_matched == 0 - assert request.request_id not in scheduler.load_specs - - @pytest.mark.parametrize( - "num_computed_blocks, num_matched_blocks, num_prompt_blocks", - [(0, 3, 4), (1, 3, 4), (3, 3, 4)], - ) - def test_get_num_new_matched_tokens_partial_hit(self, scheduler_factory, - clean_backend_instance, - num_computed_blocks, - num_matched_blocks, - num_prompt_blocks): - """ - Tests that get_num_new_matched_tokens correctly identifies a partial - prefix hit and creates a LoadSpec. - """ - - scheduler = scheduler_factory() - assert len(scheduler.cpu_backend.cache) == 0 - num_computed_tokens = num_computed_blocks * scheduler.block_size - num_matched_tokens = num_matched_blocks * scheduler.block_size - num_prompt_tokens = num_prompt_blocks * scheduler.block_size - - prompt_tokens = list(range(num_prompt_tokens)) - request = create_request("req1", - prompt_tokens, - block_size=scheduler.block_size) - - # Simulate a cache hit for the first 3 block - self._add_prompt_to_scheduler_cpu_backend( - scheduler, prompt_tokens[:num_matched_tokens]) - - num_tokens_to_load, _ = scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) - - assert num_tokens_to_load == num_matched_tokens - num_computed_tokens - if num_tokens_to_load > 0: - assert request.request_id in scheduler.load_specs - load_spec = scheduler.load_specs[request.request_id] - assert load_spec.num_matched_tokens == num_matched_tokens - assert load_spec.num_skip_leading_tokens == num_computed_tokens - assert not load_spec.can_load - - @pytest.mark.parametrize( - "num_computed_blocks, num_prompt_blocks", - [(0, 4), (3, 4), (4, 4)], - ) - def test_get_num_new_matched_tokens_full_hit(self, scheduler_factory, - clean_backend_instance, - num_computed_blocks, - num_prompt_blocks): - """ - Tests the special case of a full prefix hit, where N-1 tokens are - reported to the vLLM scheduler. 
- """ - scheduler = scheduler_factory() - assert len(scheduler.cpu_backend.cache) == 0 - - num_computed_tokens = num_computed_blocks * scheduler.block_size - num_prompt_tokens = num_prompt_blocks * scheduler.block_size - num_matched_tokens = num_prompt_tokens - - prompt_tokens = list(range(num_prompt_tokens)) - request = create_request("req1", - prompt_tokens, - block_size=scheduler.block_size) - - # Simulate a cache hit for the entire prompt - self._add_prompt_to_scheduler_cpu_backend(scheduler, prompt_tokens) - - num_tokens_to_load, _ = scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) - - # Should report N-1 to scheduler, but LoadSpec should have the full N - assert num_tokens_to_load == max( - 0, num_matched_tokens - num_computed_tokens - 1) - if num_matched_tokens > num_computed_tokens: - assert request.request_id in scheduler.load_specs - load_spec = scheduler.load_specs[request.request_id] - assert load_spec.num_matched_tokens == num_matched_tokens - assert load_spec.num_skip_leading_tokens == num_computed_tokens - assert not load_spec.can_load - - @pytest.mark.parametrize( - "num_computed_blocks, num_prompt_blocks, num_staging_blocks", - [(0, 4, 0), (0, 4, 2), (2, 4, 1)], - ) - def test_get_num_new_matched_tokens_hit_with_limited_staging_buffer( - self, scheduler_factory, clean_backend_instance, - num_computed_blocks, num_prompt_blocks, num_staging_blocks): - """ - Tests the special case of a full prefix hit, where N-1 tokens are - reported to the vLLM scheduler. - """ - num_staging_tokens = num_staging_blocks * _DEFAULT_BLOCK_SIZE - scheduler = scheduler_factory( - offload_staging_buffer_tokens=num_staging_tokens) - assert len(scheduler.cpu_backend.cache) == 0 - - num_computed_tokens = num_computed_blocks * scheduler.block_size - num_prompt_tokens = num_prompt_blocks * scheduler.block_size - num_matched_tokens = num_prompt_tokens - - prompt_tokens = list(range(num_prompt_tokens)) - request = create_request("req1", - prompt_tokens, - block_size=scheduler.block_size) - - # Simulate a cache hit for the entire prompt - self._add_prompt_to_scheduler_cpu_backend(scheduler, prompt_tokens) - - num_tokens_to_load, _ = scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) - - gt_num_tokens_to_load = min(num_matched_tokens - num_computed_tokens, - num_staging_tokens) - gt_num_matched_tokens = gt_num_tokens_to_load + num_computed_tokens - - # Should report N-1 to scheduler, but LoadSpec should have the full N - if gt_num_matched_tokens == num_prompt_tokens: - assert num_tokens_to_load == gt_num_tokens_to_load - 1 - - if gt_num_matched_tokens > num_computed_tokens: - assert request.request_id in scheduler.load_specs - load_spec = scheduler.load_specs[request.request_id] - assert load_spec.num_matched_tokens == gt_num_matched_tokens - assert load_spec.num_skip_leading_tokens == num_computed_tokens - assert len(load_spec.dst_blocks - ) == gt_num_tokens_to_load // scheduler.block_size - assert not load_spec.can_load - - @pytest.mark.parametrize( - "num_skip_leading_tokens, num_matched_tokens, save_behavior, dynamic_pad_lower_limit", - [(0, _DEFAULT_BLOCK_SIZE * 4, "drop", 0), - (0, _DEFAULT_BLOCK_SIZE * 4, "pad", 0), - (0, _DEFAULT_BLOCK_SIZE * 4, "dynamic", 1), - (0, _DEFAULT_BLOCK_SIZE * 4, "dynamic", _DEFAULT_BLOCK_SIZE - 1), - (_DEFAULT_BLOCK_SIZE, _DEFAULT_BLOCK_SIZE * 4 + 2, "drop", 0), - (_DEFAULT_BLOCK_SIZE, _DEFAULT_BLOCK_SIZE * 4 + 2, "pad", 0), - (_DEFAULT_BLOCK_SIZE, _DEFAULT_BLOCK_SIZE * 4 + 1, "dynamic", 1), - (_DEFAULT_BLOCK_SIZE, 
_DEFAULT_BLOCK_SIZE * 4 + 2, "dynamic", 1), - (_DEFAULT_BLOCK_SIZE, _DEFAULT_BLOCK_SIZE * 5 - 1, "dynamic", - _DEFAULT_BLOCK_SIZE - 1)], - ) - def test_update_state_after_alloc(self, scheduler_factory, - clean_backend_instance, - num_skip_leading_tokens, - num_matched_tokens, save_behavior, - dynamic_pad_lower_limit): - """ - Tests that update_state_after_alloc correctly updates the LoadSpec - when blocks are allocated for a request with a cache hit. - """ - scheduler = scheduler_factory( - offload_partial_block_save_behavior=save_behavior, - offload_partial_block_dynamic_pad_lower_limit= - dynamic_pad_lower_limit) - assert len(scheduler.cpu_backend.cache) == 0 - - # The ground truth of loading decisions - num_partial_block_tokens = num_matched_tokens % scheduler.block_size - if save_behavior == "drop" or \ - (save_behavior == "dynamic" - and num_partial_block_tokens < dynamic_pad_lower_limit): - # the last partial blocks needs to be dropped - num_matched_tokens -= num_partial_block_tokens - - num_external_tokens = num_matched_tokens - num_skip_leading_tokens - num_blocks_to_skip = num_skip_leading_tokens // scheduler.block_size - num_blocks_to_load = (num_external_tokens + scheduler.block_size - - 1) // scheduler.block_size - - prompt_tokens = list(range(num_matched_tokens)) - request = create_request("req1", - prompt_tokens, - block_size=scheduler.block_size, - num_computed_tokens=num_skip_leading_tokens) - - # Setup a pending load operation - scheduler.load_specs[request.request_id] = MagicMock( - num_matched_tokens=num_matched_tokens, - num_skip_leading_tokens=num_skip_leading_tokens, - dst_blocks=[-1] * num_blocks_to_load, - can_load=False) - - # Mock allocated blocks - allocated_blocks = MagicMock(spec=KVCacheBlocks) - num_blocks = (num_matched_tokens + scheduler.block_size - - 1) // scheduler.block_size - allocated_block_ids = [i for i in range(num_blocks)] - allocated_blocks.get_block_ids.return_value = [allocated_block_ids] - - scheduler.update_state_after_alloc(request, allocated_blocks, - num_external_tokens) - - load_spec = scheduler.load_specs[request.request_id] - assert load_spec.can_load - assert len(load_spec.dst_blocks) == num_blocks_to_load - assert load_spec.dst_blocks == allocated_block_ids[num_blocks_to_skip:( - num_blocks_to_load + num_blocks_to_skip)] - - @pytest.mark.parametrize( - "save_behavior, dynamic_pad_lower_limit, prompt_len, num_computed_tokens", - [("drop", 0, _DEFAULT_BLOCK_SIZE * 4 + 2, 0), - ("pad", 0, _DEFAULT_BLOCK_SIZE * 4 + 2, _DEFAULT_BLOCK_SIZE), - ("dynamic", 1, _DEFAULT_BLOCK_SIZE * 4 + 2, 0), - ("dynamic", _DEFAULT_BLOCK_SIZE - 1, _DEFAULT_BLOCK_SIZE * 4 + 2, - _DEFAULT_BLOCK_SIZE)]) - def test_build_connector_meta_new_request(self, scheduler_factory, - clean_backend_instance, - save_behavior, - dynamic_pad_lower_limit, - prompt_len, num_computed_tokens): - """ - Tests metadata generation for a new request (prefill) that has no - cache hit and generates enough tokens to trigger a save. - - NOTE(jcgu): - 1. 
we will not cover load + save for new_request here, since load - is determined by `get_num_new_matched_tokens()` - - """ - scheduler = scheduler_factory( - offload_partial_block_save_behavior=save_behavior, - offload_partial_block_dynamic_pad_lower_limit= - dynamic_pad_lower_limit, - offload_staging_buffer_tokens=2 * prompt_len) - assert len(scheduler.cpu_backend.cache) == 0 - - prompt_tokens = list(range(prompt_len)) - request = create_request("req1", - prompt_tokens, - block_size=scheduler.block_size, - num_computed_tokens=num_computed_tokens) - num_blocks = (prompt_len + scheduler.block_size - - 1) // scheduler.block_size - request.block_ids = [[i for i in range(num_blocks)]] - new_scheduled_tokens = prompt_len - num_computed_tokens - - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[request], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={"req1": new_scheduled_tokens}, - total_num_scheduled_tokens=new_scheduled_tokens, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={}, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - structured_output_request_ids={}, - grammar_bitmask=None, - ) - - metadata = scheduler.build_connector_meta(scheduler_output) - - # ground_truth - num_tokens_in_partial_block = prompt_len % scheduler.block_size - num_processed_tokens = prompt_len - if save_behavior == "drop" or (save_behavior == "dynamic" - and num_tokens_in_partial_block - < dynamic_pad_lower_limit): - num_processed_tokens = ( - prompt_len // scheduler.block_size) * scheduler.block_size - num_skip_blocks = num_computed_tokens // scheduler.block_size - num_blocks_to_save = (num_processed_tokens + scheduler.block_size - - 1) // scheduler.block_size - num_skip_blocks - - assert len(metadata.requests_meta) == 1 - req_meta = metadata.requests_meta[0] - assert req_meta.req_id == "req1" - assert req_meta.load_spec is None - assert req_meta.save_spec is not None - assert req_meta.save_spec.num_total_tokens == num_processed_tokens - assert req_meta.save_spec.num_skip_leading_tokens == num_computed_tokens - assert len(req_meta.save_spec.src_blocks) == num_blocks_to_save - assert req_meta.save_spec.src_blocks == request.block_ids[0][ - num_skip_blocks:(num_skip_blocks + num_blocks_to_save)] - assert not req_meta.save_spec.is_final_save - - tracker = scheduler._request_trackers["req1"] - assert tracker.save_watermark == num_processed_tokens - assert tracker.block_ids == request.block_ids[0] - - @pytest.mark.parametrize( - "save_behavior, dynamic_pad_lower_limit, prompt_len, num_computed_tokens, num_staging_blocks", - [ - ("drop", 0, _DEFAULT_BLOCK_SIZE * 4 + 2, 0, 0), - ("drop", 0, _DEFAULT_BLOCK_SIZE * 4 + 2, 0, 2), - ]) - def test_build_connector_meta_new_request_with_limited_staging_buffer( - self, scheduler_factory, clean_backend_instance, save_behavior, - dynamic_pad_lower_limit, prompt_len, num_computed_tokens, - num_staging_blocks): - """ - get a new request, but limited staging buffer. 
- """ - num_staging_buffer_tokens = num_staging_blocks * _DEFAULT_BLOCK_SIZE - - scheduler = scheduler_factory( - offload_partial_block_save_behavior=save_behavior, - offload_partial_block_dynamic_pad_lower_limit= - dynamic_pad_lower_limit, - offload_staging_buffer_tokens=num_staging_buffer_tokens) - assert len(scheduler.cpu_backend.cache) == 0 - - prompt_tokens = list(range(prompt_len)) - request = create_request("req1", - prompt_tokens, - block_size=scheduler.block_size, - num_computed_tokens=num_computed_tokens) - num_blocks = (prompt_len + scheduler.block_size - - 1) // scheduler.block_size - request.block_ids = [[i for i in range(num_blocks)]] - new_scheduled_tokens = prompt_len - num_computed_tokens - - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[request], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={"req1": new_scheduled_tokens}, - total_num_scheduled_tokens=new_scheduled_tokens, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={}, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - structured_output_request_ids={}, - grammar_bitmask=None, - ) - - metadata = scheduler.build_connector_meta(scheduler_output) - - # ground_truth - num_tokens_in_partial_block = prompt_len % scheduler.block_size - num_processed_tokens = prompt_len - if save_behavior == "drop" or (save_behavior == "dynamic" - and num_tokens_in_partial_block - < dynamic_pad_lower_limit): - num_processed_tokens = ( - prompt_len // scheduler.block_size) * scheduler.block_size - num_skip_blocks = num_computed_tokens // scheduler.block_size - num_blocks_to_save = (num_processed_tokens + scheduler.block_size - - 1) // scheduler.block_size - num_skip_blocks - - # throttled by the limited staging buffer - if num_blocks_to_save > num_staging_blocks: - num_blocks_to_save = num_staging_blocks - num_processed_tokens = num_staging_buffer_tokens + num_computed_tokens - - if num_blocks_to_save > 0: - assert len(metadata.requests_meta) == 1 - req_meta = metadata.requests_meta[0] - assert req_meta.req_id == "req1" - assert req_meta.load_spec is None - assert req_meta.save_spec is not None - assert req_meta.save_spec.num_total_tokens == num_processed_tokens - assert req_meta.save_spec.num_skip_leading_tokens == num_computed_tokens - assert len(req_meta.save_spec.src_blocks) == num_blocks_to_save - assert req_meta.save_spec.src_blocks == request.block_ids[0][ - num_skip_blocks:(num_skip_blocks + num_blocks_to_save)] - assert not req_meta.save_spec.is_final_save - - tracker = scheduler._request_trackers["req1"] - assert tracker.save_watermark == num_processed_tokens - assert tracker.block_ids == request.block_ids[0] - - @pytest.mark.parametrize( - "save_behavior, dynamic_pad_lower_limit, decode_save, prompt_len", - [("drop", 0, 0, _DEFAULT_BLOCK_SIZE * 4 + 2), - ("drop", 0, 1, _DEFAULT_BLOCK_SIZE * 4 + 3), - ("pad", 0, 0, _DEFAULT_BLOCK_SIZE * 4 + 2), - ("pad", 0, 1, _DEFAULT_BLOCK_SIZE * 4 + 3), - ("dynamic", 1, 0, _DEFAULT_BLOCK_SIZE * 4 + 2), - ("dynamic", _DEFAULT_BLOCK_SIZE - 1, 1, _DEFAULT_BLOCK_SIZE * 4 + 3)]) - def test_build_connector_meta_cached_request_with_one_decode( - self, scheduler_factory, clean_backend_instance, save_behavior, - dynamic_pad_lower_limit, decode_save, prompt_len): - """ - Tests metadata generation for a running request (chunked prefill) - that gets more tokens scheduled, styled as a single unit test. 
- """ - scheduler = scheduler_factory( - offload_decode_save=decode_save, - offload_partial_block_save_behavior=save_behavior, - offload_partial_block_dynamic_pad_lower_limit= - dynamic_pad_lower_limit) - assert len(scheduler.cpu_backend.cache) == 0 - - gen_len = 1 # single decode step - num_total_tokens = prompt_len + gen_len - request_tokens = list(range(num_total_tokens)) - num_prompt_blocks = (prompt_len + scheduler.block_size - - 1) // scheduler.block_size - num_total_blocks = (num_total_tokens + scheduler.block_size - - 1) // scheduler.block_size - request = create_request("req1", - request_tokens, - block_size=scheduler.block_size, - num_computed_tokens=prompt_len) - request.block_ids = [[i for i in range(num_total_blocks)]] - - # Arrange: Set up the scheduler's state to simulate a request that has - # already been partially processed. - initial_tokens = request_tokens[:prompt_len] - initial_block_ids = [i for i in range(num_prompt_blocks)] - - initial_save_watermark = prompt_len - num_tokens_in_partial_block = prompt_len % scheduler.block_size - if save_behavior == "drop" or (save_behavior == "dynamic" - and num_tokens_in_partial_block - < dynamic_pad_lower_limit): - initial_save_watermark = ( - prompt_len // scheduler.block_size) * scheduler.block_size - - tracker = RequestTracker( - req_id="req1", - prompt_len=prompt_len, - token_ids=initial_tokens, - block_ids=initial_block_ids, - save_watermark=initial_save_watermark, - is_decode_phase=False, - ) - scheduler._request_trackers["req1"] = tracker - scheduler._unfinished_requests["req1"] = request - - # Act: Simulate a decode step - new_blocks_ids = [ - i for i in range(num_prompt_blocks, num_total_blocks) - ] - logger.info(f"new_blocks_ids: {new_blocks_ids}") - cached_req_data = CachedRequestData.make_empty() - cached_req_data.req_ids = ["req1"] - cached_req_data.new_block_ids = (new_blocks_ids, ) - - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=cached_req_data, - num_scheduled_tokens={"req1": gen_len}, - total_num_scheduled_tokens=gen_len, - finished_req_ids=set(), - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={}, - num_common_prefix_blocks=0, - free_encoder_mm_hashes=[], - structured_output_request_ids={}, - grammar_bitmask=None, - ) - - metadata = scheduler.build_connector_meta(scheduler_output) - - # The tracker should be updated with the new tokens and blocks. - updated_tracker = scheduler._request_trackers["req1"] - assert updated_tracker.token_ids == request_tokens - assert updated_tracker.block_ids == request.block_ids[0] - assert updated_tracker.is_decode_phase - - # ground-truth of save - if not decode_save: - assert updated_tracker.save_watermark == initial_save_watermark - else: - # NOTE(jcgu): currently still mimic the internal logic, - # find a better way to - next_save_boundary = ( - initial_save_watermark // scheduler.block_size + - 1) * scheduler.block_size - if num_total_tokens < next_save_boundary: - # nothing to save - assert updated_tracker.save_watermark == initial_save_watermark - else: - # Assert: Verify the generated metadata and the updated tracker state. - assert len(metadata.requests_meta) == 1 - req_meta = metadata.requests_meta[0] - assert req_meta.req_id == "req1" - assert req_meta.load_spec is None - - # a block (maybe part of its tokens has been saved) should be saved. 
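# (Worked example of the boundary arithmetic, using the
#  ("drop", 0, 1, _DEFAULT_BLOCK_SIZE * 4 + 3) case above with block_size=4:
#  the prefill drops the partial block, so save_watermark starts at 16; the
#  single decode token brings num_total_tokens to 20, which reaches
#  next_save_boundary = (16 // 4 + 1) * 4 = 20, so one full block is saved
#  with num_skip_leading_tokens = 20 - 4 = 16 and the watermark advances to 20.)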
- assert req_meta.save_spec.num_total_tokens == num_total_tokens - assert req_meta.save_spec.num_skip_leading_tokens == num_total_tokens - scheduler.block_size - assert req_meta.save_spec.src_blocks == [ - request.block_ids[0][-1] - ] - assert not req_meta.save_spec.is_final_save - - assert updated_tracker.save_watermark == num_total_tokens - - @pytest.mark.parametrize( - "save_behavior, dynamic_pad_lower_limit, decode_save, prompt_len, gen_len", - [("drop", 0, 0, _DEFAULT_BLOCK_SIZE * 4 + 2, 3), - ("pad", 0, 1, _DEFAULT_BLOCK_SIZE * 4 + 2, 1), - ("pad", 0, 1, _DEFAULT_BLOCK_SIZE * 4 + 2, 4), - ("dynamic", 1, 1, _DEFAULT_BLOCK_SIZE * 4 + 2, 1), - ("dynamic", _DEFAULT_BLOCK_SIZE - 1, 1, _DEFAULT_BLOCK_SIZE * 4 + 2, - 4)]) - def test_build_connector_meta_finished_request( - self, scheduler_factory, clean_backend_instance, save_behavior, - dynamic_pad_lower_limit, decode_save, prompt_len, gen_len): - """ - Tests metadata generation for a finishing request. - """ - scheduler = scheduler_factory( - offload_decode_save=decode_save, - offload_partial_block_save_behavior=save_behavior, - offload_partial_block_dynamic_pad_lower_limit= - dynamic_pad_lower_limit) - assert len(scheduler.cpu_backend.cache) == 0 - - num_total_tokens = prompt_len + gen_len - request_tokens = list(range(num_total_tokens)) - num_total_blocks = (num_total_tokens + scheduler.block_size - - 1) // scheduler.block_size - request = create_request("req1", - request_tokens, - block_size=scheduler.block_size, - num_computed_tokens=prompt_len) - request.block_ids = [[i for i in range(num_total_blocks)]] - - # Arrange: Set up the scheduler's state to simulate a request that has - # already been processed. - num_tokens_in_partial_block = prompt_len % scheduler.block_size - - adjusted_prompt_len = prompt_len - if save_behavior == "drop" or (save_behavior == "dynamic" - and num_tokens_in_partial_block - < dynamic_pad_lower_limit): - adjusted_prompt_len = (prompt_len // - scheduler.block_size) * scheduler.block_size - - latest_save_watermark = adjusted_prompt_len - if decode_save: - num_full_block_tokens = num_total_tokens // scheduler.block_size * scheduler.block_size - latest_save_watermark = max(num_full_block_tokens, - adjusted_prompt_len) - logger.info( - f"latest_save_watermark: {latest_save_watermark}, {num_full_block_tokens}, {adjusted_prompt_len}" - ) - - tracker = RequestTracker( - req_id="req1", - prompt_len=prompt_len, - token_ids=request_tokens, - block_ids=request.block_ids[0], - save_watermark=latest_save_watermark, - is_decode_phase=True, - ) - scheduler._request_trackers["req1"] = tracker - scheduler._unfinished_requests["req1"] = request - - finished_req_ids = set() - finished_req_ids.add("req1") - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={}, - total_num_scheduled_tokens=0, - finished_req_ids=finished_req_ids, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={}, - num_common_prefix_blocks=0, - free_encoder_mm_hashes=[], - structured_output_request_ids={}, - grammar_bitmask=None, - ) - - metadata = scheduler.build_connector_meta(scheduler_output) - req_meta = metadata.requests_meta[0] - assert req_meta.load_spec is None - - # ground-truth of save - if not decode_save: - assert req_meta.save_spec.is_final_save - assert req_meta.save_spec.skip_save - assert req_meta.save_spec.src_blocks == [] - else: - # since it's a finished request, tokens are saved until the last full block (thanks to decode_save's 
next_block_boundary) - num_tokens_in_last_partial_block = num_total_tokens % scheduler.block_size - if save_behavior == "drop" or (save_behavior == "dynamic" - and num_tokens_in_last_partial_block - < dynamic_pad_lower_limit): - # if drop, then no blocks to save - assert req_meta.save_spec.is_final_save - assert req_meta.save_spec.skip_save - assert req_meta.save_spec.src_blocks == [] - else: - # otherwise, save - num_skip_leading_blocks = tracker.save_watermark // scheduler.block_size - num_skip_leading_tokens = num_skip_leading_blocks * scheduler.block_size - assert req_meta.save_spec.num_total_tokens == num_total_tokens - assert req_meta.save_spec.num_skip_leading_tokens == num_skip_leading_tokens - assert req_meta.save_spec.src_blocks == request.block_ids[0][ - num_skip_leading_blocks:] - assert req_meta.save_spec.is_final_save - assert not req_meta.save_spec.skip_save diff --git a/tests/distributed/offload/tpu_offload_manager_test.py b/tests/distributed/offload/tpu_offload_manager_test.py new file mode 100644 index 000000000..7d63e3cbc --- /dev/null +++ b/tests/distributed/offload/tpu_offload_manager_test.py @@ -0,0 +1,357 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from tpu_inference.distributed.offload.offload_manager import ( + CPUChunkPool, LRUCacheManager, StagingBufferManager) +from tpu_inference.distributed.offload.utils import ReqId +from tpu_inference.logger import init_logger + +logger = init_logger(__name__) + + +class TestStagingBufferManager: + + def test_initialization(self): + manager = StagingBufferManager(num_blocks=100) + assert manager.num_blocks == 100 + assert manager.get_num_free_staging_blocks() == 100 + assert manager.get_num_used_staging_blocks() == 0 + + def test_allocate_simple(self): + manager = StagingBufferManager(num_blocks=100) + req_id1: ReqId = "req1" + req_id2: ReqId = "req2" + + allocated1 = manager.allocate(req_id1, 10, "load") + assert allocated1 == 10 + assert manager.get_num_free_staging_blocks() == 90 + assert manager.get_num_used_staging_blocks() == 10 + assert manager._num_blocks_for_load == 10 + assert manager._num_blocks_for_save == 0 + + allocated2 = manager.allocate(req_id2, 20, "save") + assert allocated2 == 20 + assert manager.get_num_free_staging_blocks() == 70 + assert manager.get_num_used_staging_blocks() == 30 + assert manager._num_blocks_for_load == 10 + assert manager._num_blocks_for_save == 20 + + def test_allocate_insufficient_capacity(self): + manager = StagingBufferManager(num_blocks=10) + req_id: ReqId = "req1" + allocated = manager.allocate(req_id, 20, "load") + assert allocated == 0 + assert manager.get_num_free_staging_blocks() == 10 + assert manager.get_num_used_staging_blocks() == 0 + + def test_allocate_existing_load_request(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "load") + with pytest.raises(ValueError): + # multiple concurrent loads from a single request is not allowed. 
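+            # (StagingBufferManager keeps one in-flight load region per request, so a
+            # second load allocation is rejected until the first one is released via
+            # free(req_id, "load"); see test_free_full below.)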
+ manager.allocate(req_id, 5, "load") + + def test_allocate_existing_save_request(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "save") + assert manager._blocks_for_save[req_id] == 10 + manager.allocate(req_id, 5, "save") + assert manager._blocks_for_save[req_id] == 15 + assert manager.get_num_free_staging_blocks() == 85 + assert manager.get_num_used_staging_blocks() == 15 + + def test_allocate_negative_blocks(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + allocated = manager.allocate(req_id, -5, "load") + assert allocated == -5 + assert manager.get_num_free_staging_blocks() == 100 + + def test_free_full(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "load") + freed = manager.free(req_id, "load") + assert freed == 10 + assert manager.get_num_free_staging_blocks() == 100 + assert manager.get_num_used_staging_blocks() == 0 + assert req_id not in manager._blocks_for_load + + def test_free_partial(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "save") + freed = manager.free(req_id, "save", num_finished_blocks=4) + assert freed == 4 + assert manager.get_num_free_staging_blocks() == 94 + assert manager.get_num_used_staging_blocks() == 6 + assert manager._blocks_for_save[req_id] == 6 + + def test_free_more_than_allocated(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "load") + manager.free(req_id, "load", num_finished_blocks=15) + assert req_id not in manager._blocks_for_load + + def test_free_non_existent_request(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + freed = manager.free(req_id, "load") + assert freed == 0 + + def test_get_usage(self): + manager = StagingBufferManager(num_blocks=100) + req_id1: ReqId = "req1" + req_id2: ReqId = "req2" + manager.allocate(req_id1, 10, "load") + manager.allocate(req_id2, 20, "save") + + usage_str = manager.get_usage() + expected_str = "Staging Buffer: total=100, free=70, used_for_load=10, used_for_save=20;" + assert usage_str == expected_str + + usage_str_details = manager.get_usage(with_details=True) + assert "save_details:{req2:20,}" in usage_str_details + assert "load_details:{req1:10,}" in usage_str_details + + def test_complex_scenario(self): + manager = StagingBufferManager(num_blocks=50) + req1, req2, req3 = "req1", "req2", "req3" + + # req1 loads 10, req2 saves 15 + assert manager.allocate(req1, 10, "load") == 10 + assert manager.allocate(req2, 15, "save") == 15 + assert manager.get_num_free_staging_blocks() == 25 + assert manager.get_num_used_staging_blocks() == 25 + + # req3 tries to load 30, fails + assert manager.allocate(req3, 30, "load") == 0 + assert manager.get_num_free_staging_blocks() == 25 + + # req1 finishes loading + assert manager.free(req1, "load") == 10 + assert manager.get_num_free_staging_blocks() == 35 + + # req3 can now load 20 + assert manager.allocate(req3, 20, "load") == 20 + assert manager.get_num_free_staging_blocks() == 15 + assert manager.get_num_used_staging_blocks( + ) == 35 # 15 for save (req2) + 20 for load (req3) + + # req2 saves another 5 + assert manager.allocate(req2, 5, "save") == 5 + assert manager.get_num_free_staging_blocks() == 10 + assert manager._blocks_for_save[req2] == 20 + + # req2 frees 8 blocks + assert manager.free(req2, "save", 8) == 8 + assert 
manager.get_num_free_staging_blocks() == 18 + assert manager._blocks_for_save[req2] == 12 + + # req2 and req3 finish + assert manager.free(req2, "save") == 12 + assert manager.free(req3, "load") == 20 + assert manager.get_num_free_staging_blocks() == 50 + assert manager.get_num_used_staging_blocks() == 0 + + +class TestCPUChunkPool: + + def test_initialization(self): + pool = CPUChunkPool(num_chunks=10) + assert pool.num_chunks == 10 + assert pool.num_free_chunks == 10 + assert pool.num_allocated_chunks == 0 + assert len(pool.free_chunk_list) == 10 + + def test_allocate_chunks(self): + pool = CPUChunkPool(num_chunks=10) + chunk_hashes = [101, 102, 103] + chunks = pool.allocate_chunks(chunk_hashes) + + assert len(chunks) == 3 + assert pool.num_free_chunks == 7 + assert pool.num_allocated_chunks == 3 + for i, chunk in enumerate(chunks): + assert chunk.chunk_hash == chunk_hashes[i] + assert chunk.chunk_id in pool.allocated_id_to_hash_map + + def test_allocate_chunks_insufficient_space(self): + pool = CPUChunkPool(num_chunks=2) + chunk_hashes = [101, 102, 103] + with pytest.raises(ValueError): + pool.allocate_chunks(chunk_hashes) + + def test_release_chunks(self): + pool = CPUChunkPool(num_chunks=10) + chunk_hashes = [101, 102, 103] + chunks = pool.allocate_chunks(chunk_hashes) + for chunk in chunks: + chunk.touch() + + for chunk in chunks: + pool.release_chunk(chunk) + + assert pool.num_free_chunks == 10 + assert pool.num_allocated_chunks == 0 + assert len(pool.free_chunk_list) == 10 + for chunk in chunks: + assert chunk.chunk_id not in pool.allocated_id_to_hash_map + assert chunk.chunk_hash is None + assert chunk.ref_cnt == -1 + + def test_release_chunks_in_use(self): + pool = CPUChunkPool(num_chunks=10) + chunk_hashes = [101] + chunks = pool.allocate_chunks(chunk_hashes) + chunks[0].touch() # ref_cnt = 0: saved + chunks[0].touch() # ref_cnt = 1: loading + + assert not pool.release_chunk(chunks[0]) + + +class TestLRUCacheManager: + + def test_initialization(self): + manager = LRUCacheManager(num_cpu_chunks=20) + assert manager.num_chunks == 20 + assert isinstance(manager.chunk_pool, CPUChunkPool) + assert len(manager.cpu_cache) == 0 + + def test_lookup(self): + manager = LRUCacheManager(num_cpu_chunks=20) + chunk_hashes = [101, 102, 103] + + # 1. Cache miss + assert manager.lookup(chunk_hashes) == 0 + + # 2. Cache hit + # Manually add to cache for testing + chunks = manager.chunk_pool.allocate_chunks(chunk_hashes) + for chunk, h in zip(chunks, chunk_hashes): + chunk.touch() # Make it ready to load + manager.cpu_cache[h] = chunk + + assert manager.lookup(chunk_hashes) == 3 + + # 3. 
Partial hit + assert manager.lookup([101, 102, 104]) == 2 + + def test_touch(self): + manager = LRUCacheManager(num_cpu_chunks=3) + chunk_hashes = [101, 102, 103] + chunks = manager.chunk_pool.allocate_chunks(chunk_hashes) + for chunk, h in zip(chunks, chunk_hashes): + manager.cpu_cache[h] = chunk + + manager.touch([101]) + assert list(manager.cpu_cache.keys()) == [102, 103, 101] + + manager.touch([102, 103]) + assert list(manager.cpu_cache.keys()) == [101, 103, 102] + + def test_allocate_for_save_simple(self): + manager = LRUCacheManager(num_cpu_chunks=5) + chunk_hashes = [101, 102] + + new_chunks, new_chunk_idxs = manager.allocate_for_save(chunk_hashes) + + assert len(new_chunks) == 2 + assert new_chunk_idxs == [0, 1] + assert manager.chunk_pool.num_free_chunks == 3 + assert len(manager.cpu_cache) == 2 + + def test_allocate_for_save_no_new_chunks(self): + manager = LRUCacheManager(num_cpu_chunks=5) + chunk_hashes = [101, 102] + manager.allocate_for_save(chunk_hashes) + + result = manager.allocate_for_save(chunk_hashes) + assert result is None + + def test_allocate_for_save_with_eviction(self): + manager = LRUCacheManager(num_cpu_chunks=2) + # Fill the cache + manager.allocate_for_save([101, 102]) + # Mark as evictable + manager.cpu_cache[101].touch() + manager.cpu_cache[102].touch() + + manager.touch([101, 102]) + + # This should evict 102 + new_chunks, new_chunk_idxs = manager.allocate_for_save([103]) + + assert len(new_chunks) == 1 + assert new_chunk_idxs == [0] + assert 102 not in manager.cpu_cache + assert 101 in manager.cpu_cache + assert 103 in manager.cpu_cache + assert manager.chunk_pool.num_free_chunks == 0 + + def test_allocate_for_save_cannot_evict(self): + manager = LRUCacheManager(num_cpu_chunks=2) + manager.allocate_for_save([101, 102]) + # Mark as in use, not evictable + manager.cpu_cache[101].touch() + manager.cpu_cache[101].touch() + manager.cpu_cache[102].touch() + manager.cpu_cache[102].touch() + + result = manager.allocate_for_save([103]) + assert result is None + assert len(manager.cpu_cache) == 2 + + def test_prepare_load(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + manager.allocate_for_save(chunk_hashes) + manager.complete_save(chunk_hashes) # ref_cnt = 0 + + chunks = manager.prepare_load(chunk_hashes) + assert len(chunks) == 1 + assert chunks[0].is_in_use # ref_cnt = 1 + + def test_complete_save(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + manager.allocate_for_save(chunk_hashes) + + chunk = manager.cpu_cache[101] + assert not chunk.is_ready_to_load # ref_cnt = -1 + + manager.complete_save(chunk_hashes) + assert chunk.is_ready_to_load # ref_cnt = 0 + + def test_complete_load(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + manager.allocate_for_save(chunk_hashes) + manager.complete_save(chunk_hashes) + chunks = manager.prepare_load(chunk_hashes) + + assert chunks[0].is_in_use # ref_cnt = 1 + manager.complete_load(chunk_hashes) + assert not chunks[0].is_in_use # ref_cnt = 0 + + def test_mark_completion(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + new_chunks, _ = manager.allocate_for_save(chunk_hashes) + chunk_ids = [c.chunk_id for c in new_chunks] + + manager.mark_completion(chunk_ids, 'save') + assert manager.cpu_cache[101].is_ready_to_load + + manager.prepare_load(chunk_hashes) + assert manager.cpu_cache[101].is_in_use + manager.mark_completion(chunk_ids, 'load') + assert not manager.cpu_cache[101].is_in_use + + def 
test_mark_completion_unknown_id(self): + manager = LRUCacheManager(num_cpu_chunks=2) + with pytest.raises(ValueError): + manager.mark_completion([999], 'save') diff --git a/tpu_inference/distributed/offload/offload_manager.py b/tpu_inference/distributed/offload/offload_manager.py index cd7260f3d..166d0dadd 100644 --- a/tpu_inference/distributed/offload/offload_manager.py +++ b/tpu_inference/distributed/offload/offload_manager.py @@ -79,6 +79,7 @@ def allocate_chunks(self, chunk_hashes: list[ChunkHash]) -> list[CPUChunk]: ret: list[CPUChunk] = [ self.free_chunk_list.pop() for _ in range(num_required_chunks) ] + self._num_allocated_chunks += num_required_chunks for chunk, chunk_hash in zip(ret, chunk_hashes): chunk._chunk_hash = chunk_hash assert chunk.chunk_id not in self.allocated_id_to_hash_map @@ -86,15 +87,16 @@ def allocate_chunks(self, chunk_hashes: list[ChunkHash]) -> list[CPUChunk]: return ret - def release_chunks(self, chunks: list[CPUChunk]): - for chunk in chunks: - if not chunk.is_ready_to_evict: - logger.warning(f" Chunk[{chunk.chunk_id}] is still in use.") - assert chunk.chunk_id in self.allocated_id_to_hash_map - self.allocated_id_to_hash_map.pop(chunk.chunk_id) - self.free_chunk_list.append(chunk) - chunk.reset() - self._num_allocated_chunks -= len(chunks) + def release_chunk(self, chunk: CPUChunk) -> bool: + if not chunk.is_ready_to_evict: + logger.warning(f" Chunk[{chunk.chunk_id}] is still in use.") + return False + assert chunk.chunk_id in self.allocated_id_to_hash_map + self.allocated_id_to_hash_map.pop(chunk.chunk_id) + chunk.reset() + self.free_chunk_list.append(chunk) + self._num_allocated_chunks -= 1 + return True class LRUCacheManager: @@ -155,10 +157,10 @@ def allocate_for_save( return None # evict chunks - self.chunk_pool.release_chunks([ - self.cpu_cache.pop(evicting_chunk_hash) - for evicting_chunk_hash in to_evict - ]) + for evicting_chunk_hash in to_evict: + evicting_chunk = self.cpu_cache.pop(evicting_chunk_hash) + # always true, since all evicting chunks are ready to evict + self.chunk_pool.release_chunk(evicting_chunk) new_chunk_hashes = [chunk_hashes[i] for i in new_chunk_idxs] # allocate @@ -200,10 +202,14 @@ def complete_load(self, chunk_hashes: list[ChunkHash]) -> None: def mark_completion(self, chunk_ids, operation: Literal['save', 'load']) -> None: - chunk_hashes = [ - self.chunk_pool.allocated_id_to_hash_map[chunk_id] - for chunk_id in chunk_ids - ] + try: + chunk_hashes = [ + self.chunk_pool.allocated_id_to_hash_map[chunk_id] + for chunk_id in chunk_ids + ] + except Exception as e: + raise ValueError(f' failed to retrieve chunk hashes: {e}') + chunk_hashes = [] unknown_chunk_ids = [] for chunk_id in chunk_ids: From 3d580cbc6a3b4ab8fc5801062d5120faadbbe8ab Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 20 Nov 2025 06:39:43 +0000 Subject: [PATCH 14/19] offload manager tests2 Signed-off-by: Juncheng Gu --- .../offload/tpu_offload_manager_test.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/distributed/offload/tpu_offload_manager_test.py b/tests/distributed/offload/tpu_offload_manager_test.py index 7d63e3cbc..d58b4f113 100644 --- a/tests/distributed/offload/tpu_offload_manager_test.py +++ b/tests/distributed/offload/tpu_offload_manager_test.py @@ -102,21 +102,6 @@ def test_free_non_existent_request(self): freed = manager.free(req_id, "load") assert freed == 0 - def test_get_usage(self): - manager = StagingBufferManager(num_blocks=100) - req_id1: ReqId = "req1" - req_id2: ReqId = "req2" - manager.allocate(req_id1, 10, 
"load") - manager.allocate(req_id2, 20, "save") - - usage_str = manager.get_usage() - expected_str = "Staging Buffer: total=100, free=70, used_for_load=10, used_for_save=20;" - assert usage_str == expected_str - - usage_str_details = manager.get_usage(with_details=True) - assert "save_details:{req2:20,}" in usage_str_details - assert "load_details:{req1:10,}" in usage_str_details - def test_complex_scenario(self): manager = StagingBufferManager(num_blocks=50) req1, req2, req3 = "req1", "req2", "req3" From 3767929aec59019648105a809144829dc4c6e709 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 21 Nov 2025 00:42:07 +0000 Subject: [PATCH 15/19] scheduler test, not fully ready yet Signed-off-by: Juncheng Gu --- .../offload/tpu_offload_accuracy_test.py | 1 + .../tpu_offload_connector_scheduler_test.py | 342 ++++++++++++++++++ 2 files changed, 343 insertions(+) create mode 100644 tests/distributed/offload/tpu_offload_connector_scheduler_test.py diff --git a/tests/distributed/offload/tpu_offload_accuracy_test.py b/tests/distributed/offload/tpu_offload_accuracy_test.py index 9c7705316..3f9897443 100644 --- a/tests/distributed/offload/tpu_offload_accuracy_test.py +++ b/tests/distributed/offload/tpu_offload_accuracy_test.py @@ -77,6 +77,7 @@ def _test_kv_cache_cpu_offloading_accuracy( out_texts2, out_tokens2 = parse_outputs(outputs) time.sleep(1) + # TODO(jcgu): check some internal states to verify save and load operations. # output1 and output2 should be idential assert len(out_texts1) == len(out_texts2) assert len(out_tokens1) == len(out_tokens2) diff --git a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py new file mode 100644 index 000000000..5bf264b7b --- /dev/null +++ b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from unittest.mock import MagicMock + +import pytest +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput +from vllm.v1.request import Request + +from tpu_inference.distributed.offload.tpu_offload_connector import ( + DEFAULT_TPU_OFFLOAD_CPU_CHUNKS, TPUOffloadConnectorScheduler) + +_DEFAULT_BLOCK_SIZE = 16 + + +class MockVllmConfig: + + def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): + self.model_config = self.Model() + self.cache_config = self.Cache(block_size) + + class Model: + model = "test-model" + + class Cache: + + def __init__(self, block_size): + self.block_size = block_size + + +def create_request( + request_id: str, + prompt_token_ids: list[int], + block_size: int, + num_computed_tokens: int = 0, +) -> Request: + """Creates a mock vLLM request object.""" + req = MagicMock(spec=Request) + req.request_id = request_id + req.req_id = request_id # for NewRequestData + req.prompt_token_ids = prompt_token_ids + req.all_token_ids = prompt_token_ids + req.num_computed_tokens = num_computed_tokens + req.block_size = block_size + req.block_ids = [[]] + # Mock the block_hashes property to return a list of mock hashes + req.block_hashes = [ + f"hash_{i}".encode() + for i in range(len(prompt_token_ids) // block_size) + ] + return req + + +@pytest.fixture +def scheduler_factory(): + """Provides a factory function for Scheduler instances.""" + + def _scheduler( + block_size: int = _DEFAULT_BLOCK_SIZE, + offload_decode_save: int = 0, + offload_partial_block_save_behavior: str = "drop", + 
offload_partial_block_dynamic_pad_lower_limit: int = 0, + offload_staging_buffer_tokens: int = -1, + offload_num_cpu_chunks: int = DEFAULT_TPU_OFFLOAD_CPU_CHUNKS, + ): + # update config + vllm_config = MockVllmConfig(block_size=block_size) + os.environ["TPU_OFFLOAD_DECODE_SAVE"] = str(offload_decode_save) + os.environ[ + "TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR"] = offload_partial_block_save_behavior + os.environ["TPU_OFFLOAD_PARTIAL_BLOCK_DYNAMIC_PAD_LOWER_LIMIT"] = str( + offload_partial_block_dynamic_pad_lower_limit) + if offload_staging_buffer_tokens >= 0: + os.environ["TPU_OFFLOAD_STAGING_BUFFER_TOKENS"] = str( + offload_staging_buffer_tokens) + if offload_num_cpu_chunks > 0: + os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str( + offload_num_cpu_chunks) + + return TPUOffloadConnectorScheduler(vllm_config) + + return _scheduler + + +class TestTPUOffloadConnectorScheduler: + + def test_get_num_new_matched_tokens_no_hit(self, scheduler_factory): + """ + Tests that get_num_new_matched_tokens returns 0 for a cache miss. + """ + scheduler = scheduler_factory() + request = create_request("req1", [1] * 32, scheduler.block_size) + + num_matched, _ = scheduler.get_num_new_matched_tokens(request, 0) + assert num_matched == 0 + assert "req1" not in scheduler.load_specs + + @pytest.mark.parametrize( + "num_computed_blocks, num_matched_blocks, num_prompt_blocks, num_staging_blocks", + [(0, 2, 4, 10), (1, 2, 4, 10), (0, 4, 4, 10), (1, 4, 4, 10), + (1, 4, 4, 1), (1, 4, 4, 0)]) + def test_get_num_new_matched_tokens_hit(self, scheduler_factory, + num_computed_blocks, + num_matched_blocks, + num_prompt_blocks, + num_staging_blocks): + """ + Tests correct identification of a prefix hit (partial and full). + test cases: + 1. no-skip + load 2 blocks + no staging buffer limit + 2. skip 1 block + load 1 block + no staging buffer limit + 3. no-skip + full-hit + no staging buffer limit + 4. skip 1 block + full-hit + no staging buffer limit + 5. skip 1 block + full-hit + only 1 staging block + 6. 
skip 1 block + full-hit + no staging block + """ + num_staging_tokens = num_staging_blocks * _DEFAULT_BLOCK_SIZE + scheduler = scheduler_factory( + offload_staging_buffer_tokens=num_staging_tokens) + prompt_len = scheduler.block_size * num_prompt_blocks + num_computed_tokens = scheduler.block_size * num_computed_blocks + num_blocks_to_load = num_matched_blocks - num_computed_blocks + # consider the case of limited staging blocks + num_blocks_to_load = min(num_blocks_to_load, num_staging_blocks) + num_matched_blocks = num_blocks_to_load + num_computed_blocks + num_matched_tokens = num_matched_blocks * scheduler.block_size + + request = create_request("req1", list(range(prompt_len)), + scheduler.block_size) + + # init offload_manager state + matched_block_hashes = request.block_hashes[:num_matched_blocks] + allocated_chunks, _ = scheduler.offload_manager.allocate_for_save( + matched_block_hashes) + scheduler.offload_manager.complete_save(matched_block_hashes) + + # call fn + num_external_matched_tokens, _ = scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + # check external_matched_tokens + if num_matched_blocks == num_prompt_blocks: + assert num_external_matched_tokens == num_blocks_to_load * scheduler.block_size - 1 + else: + assert num_external_matched_tokens == num_blocks_to_load * scheduler.block_size + + # check scheduler internal states + if num_blocks_to_load > 0: + # load_spec + assert "req1" in scheduler.load_specs + load_spec = scheduler.load_specs["req1"] + assert load_spec.num_matched_tokens == num_matched_tokens + assert not load_spec.can_load + allocated_chunk_ids = [ + chunk.chunk_id for chunk in allocated_chunks + ] + load_src_chunk_ids = allocated_chunk_ids[num_computed_blocks:] + assert load_spec.src_chunks == load_src_chunk_ids + assert load_spec.num_skip_leading_tokens == num_computed_tokens + assert len(load_spec.dst_blocks) == num_blocks_to_load + # cache_hits + assert "req1" in scheduler._external_cache_hits + assert scheduler._external_cache_hits["req1"] == num_matched_tokens + # staging_buffer + assert "req1" in scheduler.staging_buffer_manager._blocks_for_load + assert scheduler.staging_buffer_manager._blocks_for_load[ + "req1"] == num_blocks_to_load + assert scheduler.staging_buffer_manager.get_num_free_staging_blocks( + ) == num_staging_blocks - num_blocks_to_load + else: + assert "req1" not in scheduler.load_specs + assert "req1" not in scheduler._external_cache_hits + assert "req1" not in scheduler.staging_buffer_manager._blocks_for_load + + def test_update_state_after_alloc(self, scheduler_factory): + """ + Tests that a LoadSpec is correctly updated after block allocation. 
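+        The pending LoadSpec starts with placeholder dst_blocks ([-1] entries) and
+        can_load=False; after the KV blocks are allocated, update_state_after_alloc()
+        fills dst_blocks from the allocated block ids (skipping the already-computed
+        blocks) and flips can_load to True.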
+ """ + scheduler = scheduler_factory() + req_id = "req1" + num_prompt_blocks = 4 + num_matched_blocks = 3 + num_computed_blocks = 2 + num_blocks_to_load = num_matched_blocks - num_computed_blocks + num_prompt_tokens = num_prompt_blocks * scheduler.block_size + num_matched_tokens = num_matched_blocks * scheduler.block_size + num_tokens_to_load = scheduler.block_size * num_blocks_to_load + + request = create_request(req_id, [0] * num_prompt_tokens, + scheduler.block_size) + + # Setup a pending load + scheduler.load_specs[req_id] = MagicMock( + num_matched_tokens=num_matched_tokens, + num_skip_leading_tokens=num_computed_blocks * scheduler.block_size, + dst_blocks=[-1] * num_blocks_to_load, + src_chunks=[i for i in range(num_blocks_to_load)], + can_load=False) + + # Mock allocated blocks + allocated_blocks = MagicMock(spec=KVCacheBlocks) + allocated_block_ids = [i for i in range(num_prompt_blocks)] + allocated_blocks.get_block_ids.return_value = [allocated_block_ids] + + scheduler.update_state_after_alloc(request, allocated_blocks, + num_tokens_to_load) + + load_spec = scheduler.load_specs[req_id] + assert load_spec.can_load + assert load_spec.dst_blocks == allocated_block_ids[ + num_computed_blocks:num_matched_blocks] + assert req_id in scheduler._reqs_being_loaded + assert len(scheduler._reqs_being_loaded[req_id]) == num_blocks_to_load + + @pytest.mark.parametrize( + "num_computed_tokens, num_matched_tokens, num_prompt_tokens, num_staging_tokens", + [(0, 0, 64, 160), + (0, 32, 64, 160), (16, 32, 64, 160), (0, 64, 64, 160), + (16, 64, 64, 160), (0, 32, 64, 48), (0, 32, 64, 16)]) + def test_build_connector_meta_new_prefill(self, scheduler_factory, + num_computed_tokens, + num_matched_tokens, + num_prompt_tokens, + num_staging_tokens): + """ + Tests metadata generation for a new request (prefill) with no cache hit. + 1. no hit + save 4 blocks + 2. partial hit (no-skip + load 2 blocks) + save 2 blocks + 3. partial hit (skip 1 block + load 1 blocks) + save 2 blocks + 4. full hit (no-skip + load 4 blocks) + no-save + 5. full hit (skip 1 block + load 3 blocks) + no-save + 6. partial hit (no-skip + load 2 blocks) + save 2 blocks + 3 staging blocks limit + 7. 
partial hit (no-skip + load 2 blocks) + save 2 blocks + 1 staging blocks limit + """ + num_staging_blocks = num_staging_tokens // _DEFAULT_BLOCK_SIZE + scheduler = scheduler_factory( + offload_partial_block_save_behavior="drop", + offload_staging_buffer_tokens=num_staging_tokens, + offload_num_cpu_chunks=100) + + # calculate the groundtruth + num_computed_blocks = num_computed_tokens // scheduler.block_size + num_matched_blocks = num_matched_tokens // scheduler.block_size + num_prompt_blocks = (num_prompt_tokens + scheduler.block_size - + 1) // scheduler.block_size + num_blocks_to_load = num_matched_blocks - num_computed_blocks + # adjustment based on staging_block limitation + if num_blocks_to_load > num_staging_blocks: + num_blocks_to_load = num_staging_blocks + num_matched_blocks = num_blocks_to_load + num_computed_blocks + num_matched_tokens = num_matched_blocks * scheduler.block_size + + remaining_staging_blocks = num_staging_blocks - num_blocks_to_load + num_blocks_to_save = num_prompt_blocks - num_matched_blocks + if num_blocks_to_save > remaining_staging_blocks: + num_blocks_to_save = remaining_staging_blocks + # reconfig staging_buffer limit for save + scheduler.staging_buffer_manager._num_free_blocks = remaining_staging_blocks + num_tokens_in_cache = (num_matched_blocks + + num_blocks_to_save) * scheduler.block_size + + req_id = "req1" + request = create_request(req_id, + list(range(num_prompt_tokens)), + scheduler.block_size, + num_computed_tokens=num_computed_tokens) + request.block_ids = [[i for i in range(num_prompt_blocks)]] + + # init offload_manager state + if num_matched_blocks > 0: + matched_block_hashes = request.block_hashes[:num_matched_blocks] + allocated_chunks, _ = scheduler.offload_manager.allocate_for_save( + matched_block_hashes) + scheduler.offload_manager.complete_save(matched_block_hashes) + # allocated_chunk_ids = [chunk.chunk_id for chunk in allocated_chunks] + # load_src_chunk_ids = allocated_chunk_ids[num_computed_blocks:] + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[request], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={ + "req1": num_prompt_tokens - num_computed_tokens + }, + total_num_scheduled_tokens=num_prompt_tokens - num_computed_tokens, + finished_req_ids=set(), + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=0, + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + + # Mock that the scheduler has seen this request + scheduler._unfinished_requests["req1"] = request + scheduler._external_cache_hits["req1"] = num_matched_tokens + if num_blocks_to_load > 0: + scheduler.load_specs[req_id] = MagicMock( + num_matched_tokens=num_matched_tokens, + num_skip_leading_tokens=num_computed_tokens, + dst_blocks=[-1] * num_blocks_to_load, + src_chunks=[i for i in range(num_blocks_to_load)], + can_load=True) + + metadata = scheduler.build_connector_meta(scheduler_output) + + if num_blocks_to_load + num_blocks_to_save == 0: + # no load or store + assert len(metadata.requests_meta) == 0 + else: + req_meta = metadata.requests_meta[0] + assert req_meta.req_id == "req1" + if num_blocks_to_load == 0: + assert req_meta.load_spec is None + else: + # load + assert req_meta.load_spec is not None + # NOTE(jcgu): no need to check details, since they are + # generated by other functions. 
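+            # Worked example for the last parametrized case (0, 32, 64, 16): one
+            # staging block throttles the 2-block load down to 1 block, leaving no
+            # staging blocks for saving, so num_blocks_to_save becomes 0.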
+ if num_blocks_to_save == 0: + assert req_meta.save_spec is None + else: + # save + assert req_meta.save_spec is not None + assert req_meta.save_spec.num_total_tokens == num_tokens_in_cache + assert req_meta.save_spec.num_skip_leading_tokens == num_matched_blocks * scheduler.block_size + assert req_meta.save_spec.src_blocks == request.block_ids[0][ + num_matched_blocks:num_matched_blocks + num_blocks_to_save] + assert len(req_meta.save_spec.dst_chunks) == num_blocks_to_save + assert not req_meta.save_spec.is_final_save + assert "req1" in scheduler.staging_buffer_manager._blocks_for_save + assert scheduler.staging_buffer_manager._blocks_for_save[ + "req1"] == num_blocks_to_save + assert "req1" in scheduler._reqs_being_saved + assert len( + scheduler._reqs_being_saved["req1"]) == num_blocks_to_save + + assert "req1" in scheduler._request_trackers + tracker = scheduler._request_trackers["req1"] + # after creating SaveSpec, we also update tracker.save_watermark + assert tracker.save_watermark == num_tokens_in_cache From e2ac48c41faee126749e426774e7923ca4dd66ec Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 21 Nov 2025 04:21:45 +0000 Subject: [PATCH 16/19] scheduler test Signed-off-by: Juncheng Gu --- .../tpu_offload_connector_scheduler_test.py | 166 +++++++++++++++++- 1 file changed, 162 insertions(+), 4 deletions(-) diff --git a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py index 5bf264b7b..d6c1fa630 100644 --- a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py +++ b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py @@ -9,7 +9,8 @@ from vllm.v1.request import Request from tpu_inference.distributed.offload.tpu_offload_connector import ( - DEFAULT_TPU_OFFLOAD_CPU_CHUNKS, TPUOffloadConnectorScheduler) + DEFAULT_TPU_OFFLOAD_CPU_CHUNKS, RequestTracker, + TPUOffloadConnectorScheduler) _DEFAULT_BLOCK_SIZE = 16 @@ -34,20 +35,21 @@ def create_request( prompt_token_ids: list[int], block_size: int, num_computed_tokens: int = 0, + generated_token_ids: list[int] = [], ) -> Request: """Creates a mock vLLM request object.""" req = MagicMock(spec=Request) req.request_id = request_id req.req_id = request_id # for NewRequestData req.prompt_token_ids = prompt_token_ids - req.all_token_ids = prompt_token_ids - req.num_computed_tokens = num_computed_tokens + req.all_token_ids = prompt_token_ids + generated_token_ids + req.num_computed_tokens = num_computed_tokens + len(generated_token_ids) req.block_size = block_size req.block_ids = [[]] # Mock the block_hashes property to return a list of mock hashes req.block_hashes = [ f"hash_{i}".encode() - for i in range(len(prompt_token_ids) // block_size) + for i in range(len(req.all_token_ids) // block_size) ] return req @@ -340,3 +342,159 @@ def test_build_connector_meta_new_prefill(self, scheduler_factory, tracker = scheduler._request_trackers["req1"] # after creating SaveSpec, we also update tracker.save_watermark assert tracker.save_watermark == num_tokens_in_cache + + @pytest.mark.parametrize("prompt_len, seq_len, decode_save", [(63, 64, 1), + (18, 64, 1), + (18, 64, 0)]) + def test_build_connector_meta_decode_with_save(self, scheduler_factory, + prompt_len, seq_len, + decode_save): + """ + Tests metadata generation for a decode step that triggers a save. + 1. the first decode (hit block boundary) + decode_save (save one block) + 2. th N-th decode (hit block bounary) + decode_save (save one block) + 2. 
the N-th decode (hit block boundary) + not decode_save (no save) + """ + + scheduler = scheduler_factory( + offload_decode_save=decode_save, + offload_staging_buffer_tokens=_DEFAULT_BLOCK_SIZE * 10, + offload_num_cpu_chunks=10) + + prompt_tokens = list(range(prompt_len)) + generated_tokens = list(range(prompt_len, seq_len)) + req_id = "req1" + request = create_request(req_id, + prompt_token_ids=prompt_tokens, + block_size=scheduler.block_size, + num_computed_tokens=seq_len, + generated_token_ids=generated_tokens) + num_blocks = (seq_len + scheduler.block_size - + 1) // scheduler.block_size + request.block_ids = [i for i in range(num_blocks)] + + if decode_save == 1: + # the last token in seq hasn't been computed (kv) yet + num_saved_tokens = ( + (seq_len - 1) // scheduler.block_size) * scheduler.block_size + else: + num_saved_tokens = (prompt_len // + scheduler.block_size) * scheduler.block_size + + # Setup initial state + # request tracker only tracks the computed tokens + tracker = RequestTracker(req_id="req1", + prompt_len=prompt_len, + token_ids=request.all_token_ids[:-1], + block_ids=request.block_ids, + save_watermark=num_saved_tokens) + + scheduler._request_trackers["req1"] = tracker + scheduler._unfinished_requests["req1"] = request + + # Simulate a decode step + cached_req_data = CachedRequestData.make_empty() + cached_req_data.req_ids = ["req1"] + cached_req_data.new_block_ids = ([], ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={"req1": 1}, + total_num_scheduled_tokens=1, + finished_req_ids=set(), + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=0, + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + + metadata = scheduler.build_connector_meta(scheduler_output) + + if seq_len % scheduler.block_size != 0 or decode_save != 1: + # no save when there is no newly completed full block + assert len(metadata.requests_meta) == 0 + else: + req_meta = metadata.requests_meta[0] + # save spec + assert req_meta.req_id == "req1" + assert req_meta.load_spec is None + assert req_meta.save_spec is not None + assert req_meta.save_spec.num_total_tokens == seq_len + assert req_meta.save_spec.num_skip_leading_tokens == num_saved_tokens + assert req_meta.save_spec.src_blocks == [num_blocks - 1] + assert len(req_meta.save_spec.dst_chunks) == 1 + assert not req_meta.save_spec.is_final_save + # staging buffer + assert "req1" in scheduler.staging_buffer_manager._blocks_for_save + assert scheduler.staging_buffer_manager._blocks_for_save[ + "req1"] == 1 + # chunk_id for save + assert "req1" in scheduler._reqs_being_saved + assert len(scheduler._reqs_being_saved["req1"]) == 1 + + assert tracker.save_watermark == seq_len + + def test_build_connector_meta_finished_request(self, scheduler_factory): + """ + Tests metadata generation for a finished request. + When using the request's default block hashes (fully-computed blocks only), + a finished request either saves its last full block during its last + decode step, or gives up the last partial block; by the time it is + treated as a finished request, there are no blocks left to save.
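+        For example, in the setup below (block_size=16, prompt_len=20,
+        final_seq_len=35), tokens up to 32 were already saved during decode and
+        the trailing 3-token partial block is dropped, so the final save is a
+        skip_save no-op.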
+ + """ + + scheduler = scheduler_factory(offload_decode_save=1) + prompt_len = scheduler.block_size + 4 + final_seq_len = scheduler.block_size * 2 + 3 + prompt_tokens = list(range(prompt_len)) + generated_tokens = list(range(prompt_len, final_seq_len)) + req_id = "req1" + request = create_request(req_id, + prompt_token_ids=prompt_tokens, + block_size=scheduler.block_size, + num_computed_tokens=final_seq_len, + generated_token_ids=generated_tokens) + num_blocks = (final_seq_len + scheduler.block_size - + 1) // scheduler.block_size + request.block_ids = [i for i in range(num_blocks)] + + num_saved_tokens = (final_seq_len // + scheduler.block_size) * scheduler.block_size + + # Setup initial state + tracker = RequestTracker(req_id="req1", + prompt_len=prompt_len, + token_ids=request.all_token_ids[:-1], + block_ids=request.block_ids, + save_watermark=num_saved_tokens) + scheduler._request_trackers["req1"] = tracker + scheduler._unfinished_requests["req1"] = request + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + finished_req_ids={"req1"}, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=0, + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + + metadata = scheduler.build_connector_meta(scheduler_output) + + assert req_id not in scheduler._unfinished_requests + assert req_id not in scheduler._request_trackers + assert len(metadata.requests_meta) == 1 + req_meta = metadata.requests_meta[0] + assert req_meta.save_spec is not None + assert req_meta.save_spec.is_final_save + assert req_meta.save_spec.skip_save + assert req_meta.save_spec.src_blocks == [] + assert req_meta.save_spec.dst_chunks == [] From 64343d70d4d1c6647a16ae339db42ad456e5021f Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 21 Nov 2025 05:59:44 +0000 Subject: [PATCH 17/19] rename test files Signed-off-by: Juncheng Gu --- .../gke/pod_tpu_host_offload_unit_tests.yaml | 3 ++- ...est.py => tpu_offload_cpu_backend_test.py} | 4 ++-- .../offload/tpu_offload_connector.py | 21 +++++++++---------- 3 files changed, 14 insertions(+), 14 deletions(-) rename tests/distributed/offload/{cpu_backend_test.py => tpu_offload_cpu_backend_test.py} (96%) diff --git a/examples/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/gke/pod_tpu_host_offload_unit_tests.yaml index 8daf6c035..69f14cabd 100644 --- a/examples/gke/pod_tpu_host_offload_unit_tests.yaml +++ b/examples/gke/pod_tpu_host_offload_unit_tests.yaml @@ -17,8 +17,9 @@ spec: command: - /bin/bash - -c - - "pytest -sv tests/distributed/offload/cpu_backend_test.py" + - "pytest -sv tests/distributed/offload/tpu_offload_cpu_backend_test.py" - "pytest -sv tests/distributed/offload/tpu_offload_connector_worker_test.py" + - "pytest -sv tests/distributed/offload/tpu_offload_connector_scheduler_test.py" - "pytest -sv tests/distributed/offload/tpu_offload_utils_test.py" - "pytest -sv tests/distributed/offload/tpu_offload_manager_test.py" - "pytest -sv tests/distributed/offload/tpu_offload_accuracy_test.py" diff --git a/tests/distributed/offload/cpu_backend_test.py b/tests/distributed/offload/tpu_offload_cpu_backend_test.py similarity index 96% rename from tests/distributed/offload/cpu_backend_test.py rename to tests/distributed/offload/tpu_offload_cpu_backend_test.py index d74e4d2e1..e845ef688 100644 --- a/tests/distributed/offload/cpu_backend_test.py +++ 
b/tests/distributed/offload/tpu_offload_cpu_backend_test.py @@ -40,7 +40,7 @@ def test_add_and_get(self): assert retrieved_list_value == value_list assert backend.current_size_bytes == 50 + 20 + 30 - assert backend.num_occupied_cpu_chunks == 2 + assert backend.num_saved_cpu_chunks == 2 def test_add_invalid_chunk_id(self): """Verifies that adding a value with an invalid chunk_id raises a ValueError.""" @@ -50,7 +50,7 @@ def test_add_invalid_chunk_id(self): with pytest.raises(ValueError): backend.add(CpuChunkId(-1), value) - assert backend.num_occupied_cpu_chunks == 0 + assert backend.num_saved_cpu_chunks == 0 def test_reclaim_unoccupied_chunks(self): """Tests that unoccupied chunks are reclaimed correctly.""" diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index a8ce1e462..0b1fdb643 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -132,9 +132,9 @@ BLOCK_SIZE_BUCKETS = [1, 2, 4, 8, 16] # we keep our operations at vllm's block granularity, -# and provide the following three preferences when handling +# and want to provide the following three preferences when handling # the last partial block during save: -# 1. drop: drop the entire partial block +# 1. [supported] drop: drop the entire partial block # 2. pad: pad to a full block # 3. dynamic: keep the partial block as is. PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal["drop", "pad", "dynamic"] @@ -496,10 +496,11 @@ def __init__(self, vllm_config: "VllmConfig"): chunk_size=self.block_size) self.decode_save = os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0") == "1" - # NOTE(jcgu): currently, let's nail on chunk_size == block_size + # NOTE(jcgu): currently, let's make chunk_size == block_size # chunk_size == n * block_size lead to # 1. multi-size chunks - # 2. complicated resize (split, concatenate) operations based on real-chunk-size in save and load + # 2. complicated resize (split, concatenate) operations due to + # real-chunk-size in save and load self.cpu_chunk_size = self.block_size # define partial_block saving behavior @@ -515,6 +516,10 @@ def __init__(self, vllm_config: "VllmConfig"): self.partial_block_save_behavior == "drop" elif self.partial_block_dynamic_pad_lower_limit >= self.block_size: self.partial_block_save_behavior == "pad" + logger.info( + f" partial_block_save_behavior is configured as {self.partial_block_save_behavior}, but only 'drop' is supported for now."
+ ) + self.partial_block_save_behavior = "drop" # config staging buffer # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler @@ -541,13 +546,7 @@ def __init__(self, vllm_config: "VllmConfig"): def _get_request_block_hashes(self, req: "Request") -> list[BlockHash]: # request's original block_hashes do not include the last partial block - - # TODO(jcgu): switch back to self-hash function - # prompt_token_ids = req.prompt_token_ids - # request_keys = self.token_processor.process_tokens(prompt_token_ids) - # hashes = [hash for _, _, hash in request_keys] - # return hashes - + # TODO(jcgu): switch back to token_processor return req.block_hashes def get_num_new_matched_tokens( From 616958bf6dbc63212ead9286cbaf25ffe52c450b Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 21 Nov 2025 23:01:56 +0000 Subject: [PATCH 18/19] fix file path Signed-off-by: Juncheng Gu --- examples/gke/benchmarks/deploy-cpu-offload.yaml | 2 +- examples/gke/pod_tpu_commons_cpu_offload.yaml | 2 +- examples/gke/pod_tpu_commons_cpu_offload_verification.yaml | 2 +- tests/distributed/offload/tpu_offload_accuracy_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/gke/benchmarks/deploy-cpu-offload.yaml b/examples/gke/benchmarks/deploy-cpu-offload.yaml index a93ccd3fe..5bcd573b2 100644 --- a/examples/gke/benchmarks/deploy-cpu-offload.yaml +++ b/examples/gke/benchmarks/deploy-cpu-offload.yaml @@ -21,7 +21,7 @@ spec: imagePullPolicy: Always command: ["/bin/sh", "-c"] args: - - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config '{\"kv_connector\":\"TPUOffloadConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector_local\"}' --port 8000 --max_num_batched_tokens 2048 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" + - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config '{\"kv_connector\":\"TPUOffloadConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.distributed.offload.tpu_offload_connector\"}' --port 8000 --max_num_batched_tokens 2048 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/examples/gke/pod_tpu_commons_cpu_offload.yaml b/examples/gke/pod_tpu_commons_cpu_offload.yaml index 2cef3ccfa..49bb437dc 100644 --- a/examples/gke/pod_tpu_commons_cpu_offload.yaml +++ b/examples/gke/pod_tpu_commons_cpu_offload.yaml @@ -18,7 +18,7 @@ spec: - --tensor_parallel_size=8 - --max_model_len=1024 - --kv-transfer-config - - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.tpu_connector_local","kv_role":"kv_both"}' + - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.offload.tpu_offload_connector","kv_role":"kv_both"}' env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml b/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml index b2eb566c6..f9e7c7c41 100644 --- a/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml +++ b/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml @@ -25,7 +25,7 @@ spec: - --max_model_len=1024 - --seed=42 - --kv-transfer-config - - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.tpu_connector_local","kv_role":"kv_both"}' + - 
'{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.offload.tpu_offload_connector","kv_role":"kv_both"}' env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/tests/distributed/offload/tpu_offload_accuracy_test.py b/tests/distributed/offload/tpu_offload_accuracy_test.py index 3f9897443..0059c4bd9 100644 --- a/tests/distributed/offload/tpu_offload_accuracy_test.py +++ b/tests/distributed/offload/tpu_offload_accuracy_test.py @@ -35,7 +35,7 @@ def sampling_config(): @pytest.fixture def kv_transfer_config(): - """use from tpu_connector_local""" + """use TPUOffloadConnector""" return KVTransferConfig( kv_connector="TPUOffloadConnector", kv_role="kv_both", From 1be2866702d812b8f9689ebd40015e13d1b02d0b Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Sat, 22 Nov 2025 05:00:25 +0000 Subject: [PATCH 19/19] fix kv_transfer_stat reduce bug Signed-off-by: Juncheng Gu --- .../distributed/offload/tpu_offload_connector.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index 0b1fdb643..8e85d16ce 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -312,8 +312,12 @@ def reduce(self) -> dict[str, int | float]: "Num finished load blocks ": 0, } - finished_save_chunks = sum(self.data["finished_save_chunks"].values()) - finished_load_chunks = sum(self.data["finished_load_chunks"].values()) + finished_save_chunks = sum( + len(chunk_list) + for chunk_list in self.data["finished_save_chunks"].values()) + finished_load_chunks = sum( + len(chunk_list) + for chunk_list in self.data["finished_load_chunks"].values()) return { "Num finished save chunks ": finished_save_chunks,