diff --git a/src/scope/core/pipelines/wan2_1/utils.py b/src/scope/core/pipelines/wan2_1/utils.py index 6f98c18bb..7c0e412ee 100644 --- a/src/scope/core/pipelines/wan2_1/utils.py +++ b/src/scope/core/pipelines/wan2_1/utils.py @@ -131,3 +131,22 @@ def load_state_dict(weights_path: str) -> dict: ) return state_dict + + +def _is_float8_tensor(tensor: torch.Tensor) -> bool: + """Check if a tensor is a Float8Tensor (from torchao) that lacks as_strided support.""" + return type(tensor).__name__ == "Float8Tensor" + + +def safe_unflatten(tensor: torch.Tensor, dim: int, sizes) -> torch.Tensor: + """unflatten that works with Float8Tensor by casting to float if needed. + + Float8Tensor (torchao) does not implement aten.as_strided which is used + internally by unflatten/view/reshape. We cast the tensor to + torch.float32, perform the operation, and return the plain tensor. The surrounding + code already expects real-valued tensors for subsequent math, so this is + safe. + """ + if _is_float8_tensor(tensor): + tensor = tensor.to(torch.float32) + return tensor.unflatten(dim, sizes) diff --git a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py index 96ed07d29..fc4a50c8e 100644 --- a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py +++ b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py @@ -7,6 +7,8 @@ import torch import torch.nn as nn +from scope.core.pipelines.wan2_1.utils import safe_unflatten + from .attention_blocks import ( create_base_attention_block_class, create_vace_attention_block_class, @@ -403,10 +405,10 @@ def _forward_inference( self.causal_wan_model.freq_dim, t.flatten() ).type_as(x) ) - e0 = ( - self.causal_wan_model.time_projection(e) - .unflatten(1, (6, self.dim)) - .unflatten(dim=0, sizes=t.shape) + e0 = safe_unflatten( + safe_unflatten(self.causal_wan_model.time_projection(e), 1, (6, self.dim)), + 0, + t.shape, ) # Context @@ -510,7 +512,7 @@ def 
custom_forward(*inputs, **kwargs): ) x = self.causal_wan_model.head( - x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2) + x, safe_unflatten(e, 0, t.shape).unsqueeze(2) ) x = self.causal_wan_model.unpatchify(x, grid_sizes) return torch.stack(x)