diff --git a/src/scope/core/pipelines/longlive/blocks/recache_frames.py b/src/scope/core/pipelines/longlive/blocks/recache_frames.py index 229b666e7..adcb23d41 100644 --- a/src/scope/core/pipelines/longlive/blocks/recache_frames.py +++ b/src/scope/core/pipelines/longlive/blocks/recache_frames.py @@ -165,6 +165,13 @@ def __call__(self, components, state: PipelineState) -> PipelineState: num_recache_frames = min( block_state.current_start_frame, components.config.local_attn_size ) + + # Guard: skip recaching when there are zero frames to avoid + # KV cache tensor dimension mismatches (e.g. expand size 0 vs existing cache size) + if num_recache_frames == 0: + self.set_block_state(state, block_state) + return components, state + recache_start = block_state.current_start_frame - num_recache_frames recache_frames = ( block_state.recache_buffer[:, -num_recache_frames:] diff --git a/src/scope/core/pipelines/longlive/modules/causal_model.py b/src/scope/core/pipelines/longlive/modules/causal_model.py index 8607855bc..b20492dd2 100644 --- a/src/scope/core/pipelines/longlive/modules/causal_model.py +++ b/src/scope/core/pipelines/longlive/modules/causal_model.py @@ -43,6 +43,11 @@ def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): for i, (f, h, w) in enumerate(grid_sizes.tolist()): seq_len = f * h * w + # Guard: if any grid dimension is zero, skip rope and pass through + if seq_len == 0: + output.append(x[i]) + continue + # precompute multipliers x_i = torch.view_as_complex( x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2) diff --git a/src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py b/src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py index c5941368c..349156bb0 100644 --- a/src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py +++ b/src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py @@ -107,6 +107,12 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState generator_param = next(components.generator.parameters()) _, num_frames, _, _, _ = block_state.latents.shape + + # Guard: skip cache cleaning when there are zero frames to avoid + # KV cache tensor dimension mismatches (e.g. expand size 0 vs existing cache size) + if num_frames == 0: + self.set_block_state(state, block_state) + return components, state current_end_frame = block_state.current_start_frame + num_frames # This is defined to give us timestep = 0 while matching shape expected by the generator. diff --git a/src/scope/core/pipelines/wan2_1/blocks/denoise.py b/src/scope/core/pipelines/wan2_1/blocks/denoise.py index 2c344680b..84f647391 100644 --- a/src/scope/core/pipelines/wan2_1/blocks/denoise.py +++ b/src/scope/core/pipelines/wan2_1/blocks/denoise.py @@ -150,6 +150,14 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState noise = block_state.latents batch_size = noise.shape[0] num_frames = noise.shape[1] + + # Guard: skip denoising when there are zero frames to avoid + # KV cache tensor dimension mismatches (e.g. expand size 0 vs existing cache size) + if num_frames == 0: + block_state.latents = noise + self.set_block_state(state, block_state) + return components, state + denoising_step_list = block_state.current_denoising_step_list.clone() conditional_dict = {"prompt_embeds": block_state.conditioning_embeds}