Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/scope/core/pipelines/longlive/blocks/recache_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
num_recache_frames = min(
block_state.current_start_frame, components.config.local_attn_size
)

# Guard: skip recaching when there are zero frames to avoid
# KV cache tensor dimension mismatches (e.g. expand size 0 vs existing cache size)
if num_recache_frames == 0:
self.set_block_state(state, block_state)
return components, state

recache_start = block_state.current_start_frame - num_recache_frames
recache_frames = (
block_state.recache_buffer[:, -num_recache_frames:]
Expand Down
5 changes: 5 additions & 0 deletions src/scope/core/pipelines/longlive/modules/causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w

# Guard: if any grid dimension is zero, skip rope and pass through
if seq_len == 0:
output.append(x[i])
continue

# precompute multipliers
x_i = torch.view_as_complex(
x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
Expand Down
6 changes: 6 additions & 0 deletions src/scope/core/pipelines/wan2_1/blocks/clean_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState
generator_param = next(components.generator.parameters())

_, num_frames, _, _, _ = block_state.latents.shape

# Guard: skip cache cleaning when there are zero frames to avoid
# KV cache tensor dimension mismatches (e.g. expand size 0 vs existing cache size)
if num_frames == 0:
self.set_block_state(state, block_state)
return components, state
current_end_frame = block_state.current_start_frame + num_frames

# This is defined to give us timestep = 0 while matching shape expected by the generator.
Expand Down
8 changes: 8 additions & 0 deletions src/scope/core/pipelines/wan2_1/blocks/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState
noise = block_state.latents
batch_size = noise.shape[0]
num_frames = noise.shape[1]

# Guard: skip denoising when there are zero frames to avoid
# KV cache tensor dimension mismatches (e.g. expand size 0 vs existing cache size)
if num_frames == 0:
block_state.latents = noise
self.set_block_state(state, block_state)
return components, state

denoising_step_list = block_state.current_denoising_step_list.clone()

conditional_dict = {"prompt_embeds": block_state.conditioning_embeds}
Expand Down