
[Feature][KVCache] Implement Cache Manager V1 with multi-GPU support and MTP KV Cache#2

Closed
kevincheng2 wants to merge 11 commits into develop from feature/cache_manager_v1_rebase

Conversation

@kevincheng2
Owner

Motivation

Implement the new Cache Manager V1 three-tier cache management architecture (fastdeploy/cache_manager/v1/), which supports asynchronous swapping of the KV Cache between GPU memory, CPU memory, and persistent storage, improving the KV Cache hit rate and overall throughput.

Also fixes a CuPy stream/event device-binding error in multi-GPU scenarios, as well as a multimodal hash boundary-check bug.

Modifications

New core modules (fastdeploy/cache_manager/v1/)

| Module | Description |
| --- | --- |
| base.py | `KVCacheBase` abstract base class; unifies initialization of `model_config`, `cache_config`, `quant_config`, and `parallel_config`, eliminating duplicated code in subclasses |
| block_pool.py | `BlockPool`; manages allocation and release of device/host blocks, with LRU eviction support |
| cache_controller.py | `CacheController`; three-tier cache block state machine (DEVICE / HOST / SWAP_TO_HOST / SWAP_TO_DEVICE / LOADING_FROM_STORAGE / DELETING) that coordinates swap/backup/storage operations |
| cache_manager.py | `CacheManager`; the unified external interface, wrapping prefix-cache matching, block allocation/release, and swap scheduling |
| cache_utils.py | `get_block_hash_extra_keys`: computes block hash extra keys covering multimodal tokens; `LayerDoneCounter`: tracks per-layer transfer completion via CUDA events, enabling compute-transfer overlap |
| metadata.py | Full metadata definitions: `BlockNode`, `MatchResult`, `CacheSwapMetadata`, `PDTransferMetadata`, `StorageMetadata`, `TransferResult`; all comments translated into English |
| radix_tree.py | `RadixTree`; radix-tree-based prefix cache index, supporting the `write_through_selective` backup policy (host backup triggered by a `hit_count` threshold) |
| transfer_manager.py | `CacheTransferManager`; manages asynchronous host↔device swaps on a dedicated CuPy stream for compute-transfer overlap |
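To make the radix_tree.py role concrete, here is a minimal sketch of prefix-cache matching over per-block hashes. All names (`PrefixIndex`, `Node`, `match`) are illustrative stand-ins, not the actual `RadixTree` API; the real implementation additionally handles eviction and backup policies.

```python
# Hypothetical sketch of prefix-cache matching over per-block hashes.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    children: dict = field(default_factory=dict)  # block_hash -> Node
    block_id: Optional[int] = None
    hit_count: int = 0  # basis for write_through_selective-style policies


class PrefixIndex:
    def __init__(self):
        self.root = Node()

    def insert(self, block_hashes, block_ids):
        """Record a cached sequence of blocks keyed by their hashes."""
        node = self.root
        for h, bid in zip(block_hashes, block_ids):
            node = node.children.setdefault(h, Node())
            node.block_id = bid

    def match(self, block_hashes):
        """Walk down the tree; return cached block ids for the longest prefix."""
        node, matched = self.root, []
        for h in block_hashes:
            nxt = node.children.get(h)
            if nxt is None:
                break
            nxt.hit_count += 1
            matched.append(nxt.block_id)
            node = nxt
        return matched


idx = PrefixIndex()
idx.insert(["h0", "h1", "h2"], [10, 11, 12])
print(idx.match(["h0", "h1", "hX"]))  # longest shared prefix -> [10, 11]
```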

CUDA op optimizations (custom_ops/gpu_ops/)

  • swap_cache_optimized.cu: provides three swap ops: swap_cache_per_layer (single layer, synchronous), swap_cache_per_layer_async (single layer, asynchronous, no cudaStreamSynchronize), and swap_cache_all_layers (multi-layer swap). Contiguous blocks take a cudaMemcpyAsync fast path; non-contiguous blocks use warp-level PTX non-temporal memory access. Removes swap_cache_all_layers_batch (the batch upload op, replaced by the cupy device context).
  • swap_cache_batch.cu: removes the cudaSetDevice calls; device correctness is now guaranteed uniformly by the upper-level cupy context.
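The "contiguous fast path" idea above amounts to grouping the block ids of a swap into contiguous runs, so each run can be issued as one large copy while scattered blocks fall back to the per-block kernel path. A minimal sketch with a hypothetical helper name:

```python
# Illustrative sketch: split a list of block ids into contiguous runs.
# Each (start, length) run could then map to a single cudaMemcpyAsync,
# while singleton runs fall back to the scatter/gather kernel path.
def split_contiguous_runs(block_ids):
    """Return (start, length) runs over block ids in the given order."""
    runs = []
    for bid in block_ids:
        if runs and bid == runs[-1][0] + runs[-1][1]:
            # Extends the current run by one block.
            runs[-1] = (runs[-1][0], runs[-1][1] + 1)
        else:
            # Gap found: start a new run.
            runs.append((bid, 1))
    return runs


print(split_contiguous_runs([4, 5, 6, 10, 11, 20]))  # -> [(4, 3), (10, 2), (20, 1)]
```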

MTP KV Cache support

Supports MTP (Multi-Token Prediction) KV Cache initialization under cache_manager_v1, as well as block hash computation for multimodal tokens.

Bug fixes

Multi-GPU device-binding fix (transfer_manager.py)

In multi-GPU scenarios, CuPy streams and events were created without specifying the target GPU, so streams were bound to the wrong device and swap operations failed. Fix: wrap every CuPy stream/event creation and use site in a `with cp.cuda.Device(device_id):` context, covering `__init__`, `_swap_all_layers_async`, `_swap_per_layer_async`, `record_input_stream_event`, and other methods.

Multimodal hash boundary fix (cache_utils.py)

get_block_hash_extra_keys used strict less-than (<) to check whether a multimodal item ends before the block's start position, so items ending exactly at the block start were wrongly included. Fixed to <=.
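A minimal illustration of the boundary condition; `item_end` and `block_start` are hypothetical token-offset names, not the actual function signature. An item occupying tokens `[start, end)` that ends exactly at the block start contributes nothing to that block's hash:

```python
# Sketch of the corrected predicate: does this multimodal item end
# before (or exactly at) the block start, i.e. should it be excluded?
def item_precedes_block(item_end, block_start):
    # The buggy version used `item_end < block_start`, so an item with
    # item_end == block_start overlapped "by zero tokens" and was
    # wrongly included in the block's extra keys.
    return item_end <= block_start


print(item_precedes_block(64, 64))  # ends exactly at block start -> True (excluded)
print(item_precedes_block(65, 64))  # overlaps into the block    -> False (included)
```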

Test coverage

New/extended unit tests:

  • tests/cache_manager/v1/test_cache_utils.py: full-scenario tests for get_block_hash_extra_keys (boundaries, multimodal overlap, cross-block cases, 20+ cases)
  • tests/cache_manager/v1/test_cache_controller.py: CacheController block state machine tests
  • tests/cache_manager/v1/test_cache_manager.py: CacheManager allocation/release/prefix-matching end-to-end tests
  • tests/cache_manager/v1/test_radix_tree.py: RadixTree backup candidate and LRU eviction tests
  • tests/cache_manager/v1/test_transfer_manager.py: TransferManager multi-GPU scenario tests
  • tests/engine/test_request.py: tests for the Request cache fields (block_hasher, prompt_hashes, cache_swap_metadata)

Usage or Command

Cache Manager V1 is controlled through FDConfig.cache_config. The relevant configuration parameters:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| cache_config.total_block_num | int | - | Total number of GPU device blocks |
| cache_config.num_cpu_blocks | int | 0 | Total number of CPU host blocks; 0 disables the host cache |
| cache_config.block_size | int | 64 | Number of tokens per block |
| cache_config.enable_prefix_caching | bool | False | Enable prefix caching (RadixTree index) |
| cache_config.enable_host_cache | bool | False | Enable the host second-tier cache (swap out to CPU memory) |
| cache_config.write_through_policy | str | "write_through_all" | Host backup policy: write_through_all backs up everything; write_through_selective backs up by threshold |
| cache_config.backup_hit_threshold | int | 2 | Minimum hit count that triggers a backup in write_through_selective mode |
| cache_config.enable_storage_cache | bool | False | Enable the storage third-tier cache (persistent storage) |
| cache_config.enable_compute_transfer_overlap | bool | False | Enable compute-transfer overlap (per-layer transfers) |
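The compute-transfer overlap flag relies on per-layer completion tracking (the role LayerDoneCounter plays with CUDA events): the transfer stream marks each layer done as its copy finishes, and the compute stream waits only for the layer it is about to use. A toy sketch with `threading.Event` standing in for CUDA events; the class name and API here are illustrative only:

```python
# Toy sketch of per-layer completion tracking for compute-transfer overlap.
import threading


class LayerDoneTracker:
    def __init__(self, num_layers):
        self._events = [threading.Event() for _ in range(num_layers)]

    def mark_done(self, layer):
        """Transfer side: layer's KV blocks finished copying."""
        self._events[layer].set()

    def wait_layer(self, layer, timeout=None):
        """Compute side: block until this layer's data is ready."""
        return self._events[layer].wait(timeout)


tracker = LayerDoneTracker(num_layers=2)
tracker.mark_done(0)
print(tracker.wait_layer(0, timeout=0.01))  # -> True (layer 0 transferred)
print(tracker.wait_layer(1, timeout=0.01))  # -> False (layer 1 still in flight)
```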

Example service launch (host cache and prefix caching enabled):

```bash
# Build the custom ops (needed if .cu files were modified)
bash build.sh 1 python false [90] 1

# Launch the inference service with host cache enabled
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --enable-prefix-caching \
    --num-cpu-blocks 2048 \
    --tensor-parallel-size 4
```

Run the unit tests:

```bash
source .venv/py310/bin/activate
cd baidu/FastDeploy
python -m pytest tests/cache_manager/v1/ -vv -s
```

Checklist

  • Add at least a tag in the PR title.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

kevincheng2 and others added 11 commits March 30, 2026 21:04
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When a node transitions from SWAP_TO_DEVICE to DEVICE via
complete_swap_to_device, it was not being added to the
_evictable_device set. This caused nodes with ref_count=0 to
become "orphaned" - not appearing in any evictable set despite
having cache_status=DEVICE.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
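The orphaned-node fix described in the commit message above can be sketched as follows. The names mirror the description (`complete_swap_to_device`, `_evictable_device`), but this is a minimal reconstruction, not the actual CacheController code:

```python
# Sketch of the fix: a node finishing SWAP_TO_DEVICE with no references
# must re-enter the evictable set, or its blocks can never be reclaimed.
class Node:
    def __init__(self):
        self.cache_status = "SWAP_TO_DEVICE"
        self.ref_count = 0


class Controller:
    def __init__(self):
        self._evictable_device = set()

    def complete_swap_to_device(self, node):
        node.cache_status = "DEVICE"
        # The fix: previously this step was missing, leaving ref_count=0
        # DEVICE nodes orphaned outside every evictable set.
        if node.ref_count == 0:
            self._evictable_device.add(node)


ctrl, node = Controller(), Node()
ctrl.complete_swap_to_device(node)
print(node in ctrl._evictable_device)  # -> True
```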
- Add new cache_manager.py with cache management functionality
- Add radix_tree.py for prefix caching
- Update block_pool.py and metadata.py
- Update request.py and resource_manager_v1.py for scheduling
- Update gpu_model_runner.py for GPU model execution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add CacheController class for cache management
- Update config.py with cache related configurations
- Refactor gpu_model_runner.py for improved cache handling
…Meta

## Motivation

Fix swap_cache_optimized.cu using the wrong src/dst block_ids in the H2D direction, and clean up the deprecated cache_controller field in ForwardMeta.

## Modifications

- fix: in swap_cache_optimized.cu, select src/dst block_ids correctly based on the D2H template parameter, fixing the src/dst inversion bug in the H2D direction (fixes both SwapCachePerLayerImpl and SwapCacheAllLayersBatchImpl)
- refactor: cache_manager/v1/__init__.py now imports LayerSwapTimeoutError from cache_utils (its correct source) instead of cache_controller
- refactor: remove the deprecated cache_controller field from ForwardMeta
- refactor: remove the corresponding cache_controller assignments from gpu_model_runner.py
- test: add unit tests in tests/cache_manager/v1/test_swap_cache_ops.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
## Motivation

Refactor and optimize cache manager v1, streamlining the code structure and improving maintainability.

## Modifications

- Refactor transfer_manager.py, significantly simplifying its logic
- Optimize the swap_cache_optimized.cu GPU op implementation
- Adjust the cache_manager.py and cache_controller.py logic; fix the missing free_device_blocks method
- Update block_pool.py, cache_utils.py, metadata.py, and radix_tree.py
- Streamline the related calls in gpu_model_runner.py, forward_meta.py, and attention.py
- Update the corresponding unit tests (test_cache_controller, test_swap_cache_ops, test_transfer_manager)
- Adjust the related configuration items in config.py
## Motivation

Under the enable_cache_manager_v1 path, the MTP (speculative decode) KV Cache needs to be managed uniformly by CacheController so it can reuse the swap/transfer capabilities. Also fix block hashes not carrying the multimodal extra_keys in multimodal scenarios.

## Modifications

- `cache_controller.py`
  - Add `initialize_mtp_kv_cache`: initializes the MTP KV Cache through CacheController and registers it in cache_kvs_map, so transfer_manager automatically covers the MTP layers
  - `num_layers` in `initialize_host_cache` now includes the extra MTP cache layers, ensuring the host cache also allocates enough space for MTP
  - Rename `_free_gpu_cache` to `free_gpu_cache` (now externally callable)

- `cache_utils.py`
  - Add `get_block_hash_extra_keys`: extracts the multimodal hash info within a single block, aligned with PrefixCacheManager's multimodal extra_keys logic
  - `get_request_block_hasher` now passes extra_keys to hash_block_tokens, fixing inaccurate prefix-cache hit rates in multimodal scenarios

- `spec_decode/mtp.py`
  - `update_mtp_block_num` gains a `skip_cache_init` parameter to avoid re-initializing the MTP KV Cache on the v1 cache-manager path

- `gpu_model_runner.py`
  - `initialize_kv_cache` (v1) path: after the main model's cache is initialized, call `cache_controller.initialize_mtp_kv_cache` to create the MTP cache
  - `clear_cache` / `wakeup` / `reset` paths: respect the `enable_cache_manager_v1` flag and skip the duplicate proposer.initialize_kv_cache call

## Usage or Command

```bash
# Launch an inference service with MTP + cache_manager_v1 (example)
bash run.sh
```
…atch ops

1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations
   with cp.cuda.Device(device_id) context to ensure streams/events are bound
   to the correct device, preventing cross-device errors in multi-GPU setups.

2. Remove cudaSetDevice from SwapCacheAllLayers (handled by cupy context now).

3. Remove swap_cache_all_layers_batch op: simplified the implementation by
   removing the batch upload variant; all-layer transfers now use the standard
   swap_cache_all_layers with cupy device context.

4. Fix mm hash boundary comparison in get_block_hash_extra_keys: change
   strict less-than (<) to less-than-or-equal (<=) so that multimodal items
   ending exactly at block start are correctly excluded.

5. Extract config fields to KVCacheBase: model_config, cache_config,
   quant_config, parallel_config are now set in the base class __init__ to
   avoid duplication in CacheController and CacheManager subclasses.

6. Translate metadata.py docstrings from Chinese to English for broader
   contributor accessibility.

7. Add test_cache_utils.py: comprehensive unit tests for
   get_block_hash_extra_keys covering all boundary and overlap scenarios.

8. Expand test suite: test_request.py cache fields tests, test_radix_tree.py
   backup candidate tests, test_transfer_manager.py and test_cache_manager.py
   multi-GPU and concurrent operation tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@kevincheng2
Owner Author

Closing: PR should target upstream PaddlePaddle/FastDeploy, not this fork.
