diff --git a/src/scope/core/pipelines/wan2_1/blocks/prepare_video_latents.py b/src/scope/core/pipelines/wan2_1/blocks/prepare_video_latents.py
index ad6ba0240..e35d27f52 100644
--- a/src/scope/core/pipelines/wan2_1/blocks/prepare_video_latents.py
+++ b/src/scope/core/pipelines/wan2_1/blocks/prepare_video_latents.py
@@ -73,13 +73,32 @@ def intermediate_outputs(self) -> list[OutputParam]:
             OutputParam("generator", description="Random number generator"),
         ]
 
+    # Maximum number of frames to encode at once to avoid CUDA OOM.
+    VAE_ENCODE_CHUNK_SIZE = 8
+
     @torch.no_grad()
     def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState]:
         block_state = self.get_block_state(state)
 
-        # Encode frames to latents using VAE
-        # VAE returns [B, F, C, H, W] which is what DenoiseBlock/Generator expect
-        latents = components.vae.encode_to_latent(block_state.video)
+        # Encode frames to latents using VAE in chunks to prevent CUDA OOM.
+        # VAE expects [B, C, F, H, W] and returns [B, F, C, H, W].
+        video = block_state.video
+        # video shape: [B, C, F, H, W]
+        num_frames = video.shape[2]
+        chunk_size = self.VAE_ENCODE_CHUNK_SIZE
+
+        if num_frames <= chunk_size:
+            latents = components.vae.encode_to_latent(video)
+        else:
+            latent_chunks = []
+            for start in range(0, num_frames, chunk_size):
+                end = min(start + chunk_size, num_frames)
+                chunk = video[:, :, start:end, :, :]
+                latent_chunk = components.vae.encode_to_latent(chunk)
+                latent_chunks.append(latent_chunk)
+                torch.cuda.empty_cache()
+            # Concatenate along frames dim (dim=1 in latent space [B, F, C, H, W])
+            latents = torch.cat(latent_chunks, dim=1)
 
         # The default param for InputParam does not work right now
         # The workaround is to set the default values here
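
For reference, here is a minimal standalone sketch of the chunked-encode pattern this hunk introduces. `FakeVAE` and its latent shapes are hypothetical stand-ins for `components.vae.encode_to_latent` (the real Wan 2.1 VAE's channel count and compression factors may differ); only the chunking logic mirrors the diff.

```python
# Standalone sketch of the chunked VAE-encode pattern, under assumed shapes.
# FakeVAE is a hypothetical stand-in for components.vae; its 4-channel,
# 8x-downsampled latent layout is illustrative, not the real Wan 2.1 VAE.
import torch

VAE_ENCODE_CHUNK_SIZE = 8  # frames per encode call, mirroring the diff


class FakeVAE:
    def encode_to_latent(self, video: torch.Tensor) -> torch.Tensor:
        # video: [B, C, F, H, W] -> latents: [B, F, C', H/8, W/8] (illustrative)
        b, c, f, h, w = video.shape
        return torch.randn(b, f, 4, h // 8, w // 8)


def encode_in_chunks(
    vae, video: torch.Tensor, chunk_size: int = VAE_ENCODE_CHUNK_SIZE
) -> torch.Tensor:
    num_frames = video.shape[2]
    if num_frames <= chunk_size:
        return vae.encode_to_latent(video)
    chunks = [
        # Slicing past the end is safe: the last chunk just gets the remainder.
        vae.encode_to_latent(video[:, :, start : start + chunk_size])
        for start in range(0, num_frames, chunk_size)
    ]
    # Latents come back as [B, F, C, H, W], so frames concatenate on dim=1.
    return torch.cat(chunks, dim=1)


# Usage: 33 frames -> 5 encode calls of at most 8 frames each.
latents = encode_in_chunks(FakeVAE(), torch.randn(1, 3, 33, 64, 64))
assert latents.shape[1] == 33
```

One caveat worth noting on the design: chunking trades peak memory for the loss of any cross-chunk temporal context, so it is only safe if the VAE encodes frames (or small frame groups) independently along the time axis.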