From e58d9f6dc4dc83b7c05959777dc93504a06c14d8 Mon Sep 17 00:00:00 2001 From: livepeer-robot Date: Fri, 20 Feb 2026 21:53:05 +0000 Subject: [PATCH] fix: handle zero-token KV cache edge case Add guard clauses to skip processing when the token/frame count is zero, preventing tensor dimension mismatches during KV cache expansion (e.g. 'expanded size (0) must match existing size'). Guards added in: - DenoiseBlock: skip denoising loop when num_frames == 0 - CleanKVCacheBlock: skip cache cleaning when num_frames == 0 - RecacheFramesBlock: skip recaching when num_recache_frames == 0 - causal_rope_apply: pass through when seq_len == 0 Fixes daydreamlive/scope#500 --- .../core/pipelines/longlive/blocks/recache_frames.py | 7 +++++++ src/scope/core/pipelines/longlive/modules/causal_model.py | 5 +++++ src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py | 6 ++++++ src/scope/core/pipelines/wan2_1/blocks/denoise.py | 8 ++++++++ 4 files changed, 26 insertions(+) diff --git a/src/scope/core/pipelines/longlive/blocks/recache_frames.py b/src/scope/core/pipelines/longlive/blocks/recache_frames.py index 229b666e7..adcb23d41 100644 --- a/src/scope/core/pipelines/longlive/blocks/recache_frames.py +++ b/src/scope/core/pipelines/longlive/blocks/recache_frames.py @@ -165,6 +165,13 @@ def __call__(self, components, state: PipelineState) -> PipelineState: num_recache_frames = min( block_state.current_start_frame, components.config.local_attn_size ) + + # Guard: skip recaching when there are zero frames to avoid + # KV cache tensor dimension mismatches (e.g. expand size 0 vs existing cache size) + if num_recache_frames == 0: + self.set_block_state(state, block_state) + return components, state + recache_start = block_state.current_start_frame - num_recache_frames recache_frames = ( block_state.recache_buffer[:, -num_recache_frames:] diff --git a/src/scope/core/pipelines/longlive/modules/causal_model.py b/src/scope/core/pipelines/longlive/modules/causal_model.py index 8607855bc..b20492dd2 100644 --- a/src/scope/core/pipelines/longlive/modules/causal_model.py +++ b/src/scope/core/pipelines/longlive/modules/causal_model.py @@ -43,6 +43,11 @@ def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): for i, (f, h, w) in enumerate(grid_sizes.tolist()): seq_len = f * h * w + # Guard: if any grid dimension is zero, skip rope and pass through + if seq_len == 0: + output.append(x[i]) + continue + # precompute multipliers x_i = torch.view_as_complex( x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2) ) diff --git a/src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py b/src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py index c5941368c..349156bb0 100644 --- a/src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py +++ b/src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py @@ -107,6 +107,12 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState generator_param = next(components.generator.parameters()) _, num_frames, _, _, _ = block_state.latents.shape + + # Guard: skip cache cleaning when there are zero frames to avoid + # KV cache tensor dimension mismatches (e.g. expand size 0 vs existing cache size) + if num_frames == 0: + self.set_block_state(state, block_state) + return components, state current_end_frame = block_state.current_start_frame + num_frames # This is defined to give us timestep = 0 while matching shape expected by the generator. 
diff --git a/src/scope/core/pipelines/wan2_1/blocks/denoise.py b/src/scope/core/pipelines/wan2_1/blocks/denoise.py index 2c344680b..84f647391 100644 --- a/src/scope/core/pipelines/wan2_1/blocks/denoise.py +++ b/src/scope/core/pipelines/wan2_1/blocks/denoise.py @@ -150,6 +150,14 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState noise = block_state.latents batch_size = noise.shape[0] num_frames = noise.shape[1] + + # Guard: skip denoising when there are zero frames to avoid + # KV cache tensor dimension mismatches (e.g. expand size 0 vs existing cache size) + if num_frames == 0: + block_state.latents = noise + self.set_block_state(state, block_state) + return components, state + denoising_step_list = block_state.current_denoising_step_list.clone() conditional_dict = {"prompt_embeds": block_state.conditioning_embeds}