
Commit 66b5840

Flink-ddd and tjtanaa authored
[Bugfix][sleepmode][fp8 kv cache]: Fix FP8 KV cache + sleep(level=2) gibberish output (#28783)
Signed-off-by: vensen <vensenmu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
1 parent 82c795d commit 66b5840
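
For context, a minimal sketch (not part of this commit) of the user-facing flow that previously produced gibberish: an engine started with an FP8 KV cache and sleep mode enabled, put into level-2 sleep, then woken up and asked to generate again. It only uses APIs that the new test below also exercises.

from vllm import LLM, SamplingParams

llm = LLM("Qwen/Qwen2-0.5B", enable_sleep_mode=True, kv_cache_dtype="fp8")
params = SamplingParams(temperature=0, max_tokens=10)
before = llm.generate("How are you?", params)

llm.sleep(level=2)                    # release weights and KV cache memory
llm.wake_up()                         # re-allocate buffers
llm.collective_rpc("reload_weights")  # restore weights

after = llm.generate("How are you?", params)
# Before this fix, `after` did not match `before`: the re-allocated KV cache
# and its FP8 scaling factors were never re-initialized.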

File tree

4 files changed (+94 -2 lines)

tests/basic_correctness/test_cumem.py
tests/utils.py
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_worker.py

tests/basic_correctness/test_cumem.py

Lines changed: 32 additions & 1 deletion
@@ -11,7 +11,7 @@
 from vllm.platforms import current_platform
 from vllm.utils.mem_constants import GiB_bytes
 
-from ..utils import create_new_process_for_each_test
+from ..utils import create_new_process_for_each_test, requires_fp8
 
 
 @create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
@@ -243,3 +243,34 @@ async def test():
         assert output.outputs[0].text == output2.outputs[0].text
 
     asyncio.run(test())
+
+
+@requires_fp8
+def test_deep_sleep_fp8_kvcache():
+    GiB_bytes = 1 << 30
+    model = "Qwen/Qwen2-0.5B"
+    used_bytes_baseline = current_platform.get_current_memory_usage()
+
+    llm = LLM(model, enable_sleep_mode=True, kv_cache_dtype="fp8")
+    prompt = "How are you?"
+    sampling_params = SamplingParams(temperature=0, max_tokens=10)
+    output = llm.generate(prompt, sampling_params)
+
+    # Put the engine to deep sleep
+    llm.sleep(level=2)
+
+    used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
+    assert used_bytes < 3 * GiB_bytes
+
+    llm.wake_up(tags=["weights"])
+    llm.collective_rpc("reload_weights")
+
+    used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
+    assert used_bytes < 4 * GiB_bytes
+
+    # Now allocate KV cache and CUDA graph memory
+    llm.wake_up(tags=["kv_cache"])
+    output2 = llm.generate(prompt, sampling_params)
+
+    # Compare outputs: generation after waking from deep sleep must match
+    assert output[0].outputs[0].text == output2[0].outputs[0].text

tests/utils.py

Lines changed: 7 additions & 0 deletions
@@ -1075,6 +1075,13 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
     )
 
 
+requires_fp8 = pytest.mark.skipif(
+    not current_platform.supports_fp8(),
+    reason="FP8 is not supported on this GPU (requires Hopper or "
+    "Ada architecture, compute capability 8.9+)",
+)
+
+
 def large_gpu_test(*, min_gb: int):
     """
     Decorate a test to be skipped if no GPU is available or it does not have
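
The new marker delegates to current_platform.supports_fp8(). A rough standalone equivalent, assuming (as the reason string suggests) that FP8 support maps to compute capability 8.9 or newer, might look like this hypothetical sketch; _has_fp8_gpu and requires_fp8_local are illustrative names, not part of vLLM:

import pytest
import torch

def _has_fp8_gpu() -> bool:
    # Hypothetical stand-in for current_platform.supports_fp8(): assume FP8
    # KV cache support needs an Ada (8.9) or Hopper (9.0+) class GPU.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 9)

requires_fp8_local = pytest.mark.skipif(
    not _has_fp8_gpu(), reason="FP8 requires compute capability 8.9+"
)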

vllm/v1/worker/gpu_model_runner.py

Lines changed: 45 additions & 1 deletion
@@ -25,7 +25,7 @@
     AttentionType,
     MultipleOf,
 )
-from vllm.attention.layer import Attention
+from vllm.attention.layer import Attention, MLAAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -602,6 +602,50 @@ def reset_mm_cache(self) -> None:
         if self.mm_budget:
             self.mm_budget.reset_cache()
 
+    @torch.inference_mode()
+    def init_fp8_kv_scales(self) -> None:
+        """
+        Re-initialize the KV cache and FP8 scales after waking from sleep.
+        1. Zero out the KV cache tensors to remove garbage data from re-allocation.
+        2. Reset Attention layer scaling factors (_k_scale, _v_scale) to 1.0.
+           If these are left at 0.0 (the default after wake_up), all KV cache
+           values become effectively zero, causing gibberish output.
+        """
+        if not self.cache_config.cache_dtype.startswith("fp8"):
+            return
+
+        kv_caches = getattr(self, "kv_caches", [])
+        for cache_tensor in kv_caches:
+            if cache_tensor is not None:
+                cache_tensor.zero_()
+
+        k_attr_names = ("_k_scale", "k_scale")
+        v_attr_names = ("_v_scale", "v_scale")
+
+        attn_layers = self.compilation_config.static_forward_context
+        for name, module in attn_layers.items():
+            if isinstance(module, (Attention, MLAAttention)):
+                # TODO: Generally, the scale is 1.0 when the user enables
+                # on-the-fly FP8 KV cache quantization. However, to get better
+                # accuracy, compression frameworks like llm-compressor let
+                # users tune the scale. We may need to restore the specific
+                # calibrated scales here in the future.
+                k_scale_val, v_scale_val = 1.0, 1.0
+
+                # Process the K scale
+                for attr in k_attr_names:
+                    if hasattr(module, attr):
+                        param = getattr(module, attr)
+                        if isinstance(param, torch.Tensor):
+                            param.fill_(k_scale_val)
+
+                # Process the V scale
+                for attr in v_attr_names:
+                    if hasattr(module, attr):
+                        param = getattr(module, attr)
+                        if isinstance(param, torch.Tensor):
+                            param.fill_(v_scale_val)
+
     def _get_positions(self, num_tokens: Any):
         if isinstance(num_tokens, int):
             if self.uses_mrope:
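
To see why a zero scale matters, here is a small illustrative snippet. It assumes PyTorch 2.1+ (for torch.float8_e4m3fn) and per-tensor scaled quantization of the form q = x / scale, x_hat = q * scale, which is the behavior the docstring above describes; it is not vLLM code.

import torch

x = torch.randn(4)                       # pretend these are KV cache entries
scale = torch.tensor(1.0)                # the value init_fp8_kv_scales restores
q = (x / scale).to(torch.float8_e4m3fn)  # stored FP8 values
print(q.to(torch.float32) * scale)       # dequantized: close to x

bad_scale = torch.tensor(0.0)            # a scale left at 0.0 after wake_up
print(q.to(torch.float32) * bad_scale)   # every KV value collapses to 0.0 -> gibberish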

vllm/v1/worker/gpu_worker.py

Lines changed: 10 additions & 0 deletions
@@ -141,6 +141,16 @@ def wake_up(self, tags: list[str] | None = None) -> None:
                 buffer.data.copy_(self._sleep_saved_buffers[name].data)
             self._sleep_saved_buffers = {}
 
+        # If the KV cache has just been woken up,
+        # the internal state of the cache engine must be reset,
+        # especially the FP8 scaling factors.
+        if (
+            (tags is None or "kv_cache" in tags)
+            and self.cache_config.cache_dtype.startswith("fp8")
+            and hasattr(self.model_runner, "init_fp8_kv_scales")
+        ):
+            self.model_runner.init_fp8_kv_scales()
+
     def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
         if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator
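
With this guard, the reset only runs when the KV cache is actually being restored and the cache dtype is FP8. A sketch of the resulting behavior, using the same tag values as the test above (assumes `llm` was created with kv_cache_dtype="fp8" and enable_sleep_mode=True and is currently in level-2 sleep):

llm.wake_up(tags=["weights"])    # weights only -> init_fp8_kv_scales() is not called
llm.wake_up(tags=["kv_cache"])   # KV cache restored -> scales reset to 1.0, cache zeroed
# llm.wake_up() with no tags restores everything and also triggers the reset.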
