From 65e5ae1a14386c71e4116f33a6d2e709933b00f4 Mon Sep 17 00:00:00 2001 From: livepeer-robot Date: Fri, 20 Feb 2026 21:52:05 +0000 Subject: [PATCH] fix: Float8Tensor as_strided workaround Float8Tensor (torchao) does not implement aten.as_strided which is used internally by unflatten/view/reshape. Add safe_unflatten helper that detects Float8Tensor and casts to float32 before the operation. Applied to causal_vace_model.py where the error occurs in DenoiseBlock. Fixes daydreamlive/scope#501 --- src/scope/core/pipelines/wan2_1/utils.py | 19 +++++++++++++++++++ .../wan2_1/vace/models/causal_vace_model.py | 12 +++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/scope/core/pipelines/wan2_1/utils.py b/src/scope/core/pipelines/wan2_1/utils.py index 6f98c18bb..7c0e412ee 100644 --- a/src/scope/core/pipelines/wan2_1/utils.py +++ b/src/scope/core/pipelines/wan2_1/utils.py @@ -131,3 +131,22 @@ def load_state_dict(weights_path: str) -> dict: ) return state_dict + + +def _is_float8_tensor(tensor: torch.Tensor) -> bool: + """Check if a tensor is a Float8Tensor (from torchao) that lacks as_strided support.""" + return type(tensor).__name__ == "Float8Tensor" + + +def safe_unflatten(tensor: torch.Tensor, dim: int, sizes) -> torch.Tensor: + """unflatten that works with Float8Tensor by casting to float if needed. + + Float8Tensor (torchao) does not implement aten.as_strided which is used + internally by unflatten/view/reshape. We cast to the tensor's nominal + dtype, perform the operation, and return the plain tensor. The surrounding + code already expects real-valued tensors for subsequent math, so this is + safe. + """ + if _is_float8_tensor(tensor): + tensor = tensor.to(torch.float32) + return tensor.unflatten(dim, sizes) diff --git a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py index 96ed07d29..fc4a50c8e 100644 --- a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py +++ b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py @@ -7,6 +7,8 @@ import torch import torch.nn as nn +from scope.core.pipelines.wan2_1.utils import safe_unflatten + from .attention_blocks import ( create_base_attention_block_class, create_vace_attention_block_class, @@ -403,10 +405,10 @@ def _forward_inference( self.causal_wan_model.freq_dim, t.flatten() ).type_as(x) ) - e0 = ( - self.causal_wan_model.time_projection(e) - .unflatten(1, (6, self.dim)) - .unflatten(dim=0, sizes=t.shape) + e0 = safe_unflatten( + safe_unflatten(self.causal_wan_model.time_projection(e), 1, (6, self.dim)), + 0, + t.shape, ) # Context @@ -510,7 +512,7 @@ def custom_forward(*inputs, **kwargs): ) x = self.causal_wan_model.head( - x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2) + x, safe_unflatten(e, 0, t.shape).unsqueeze(2) ) x = self.causal_wan_model.unpatchify(x, grid_sizes) return torch.stack(x)