Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/scope/core/pipelines/wan2_1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,22 @@ def load_state_dict(weights_path: str) -> dict:
)

return state_dict


def _is_float8_tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Float8Tensor (from torchao) that lacks as_strided support."""
return type(tensor).__name__ == "Float8Tensor"


def safe_unflatten(tensor: torch.Tensor, dim: int, sizes) -> torch.Tensor:
    """``unflatten`` that also accepts torchao ``Float8Tensor`` inputs.

    ``Float8Tensor`` does not implement ``aten.as_strided``, which
    ``unflatten``/``view``/``reshape`` use internally. For such inputs we
    first upcast to ``float32`` and operate on the resulting plain tensor;
    the surrounding code performs real-valued math on the result anyway,
    so returning a dense float tensor is safe.

    Args:
        tensor: Input tensor (regular tensor or torchao Float8Tensor).
        dim: Dimension to unflatten.
        sizes: Target shape for the unflattened dimension, as accepted by
            ``torch.Tensor.unflatten`` (e.g. ``(6, hidden_dim)``).

    Returns:
        The unflattened tensor; ``float32`` if the input was a Float8Tensor.
    """
    # Upcast only when needed: float8 storage cannot be strided/viewed,
    # and float32 represents every float8 value exactly.
    source = tensor.to(torch.float32) if _is_float8_tensor(tensor) else tensor
    return source.unflatten(dim, sizes)
12 changes: 7 additions & 5 deletions src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import torch
import torch.nn as nn

from scope.core.pipelines.wan2_1.utils import safe_unflatten

from .attention_blocks import (
create_base_attention_block_class,
create_vace_attention_block_class,
Expand Down Expand Up @@ -403,10 +405,10 @@ def _forward_inference(
self.causal_wan_model.freq_dim, t.flatten()
).type_as(x)
)
e0 = (
self.causal_wan_model.time_projection(e)
.unflatten(1, (6, self.dim))
.unflatten(dim=0, sizes=t.shape)
e0 = safe_unflatten(
safe_unflatten(self.causal_wan_model.time_projection(e), 1, (6, self.dim)),
0,
t.shape,
)

# Context
Expand Down Expand Up @@ -510,7 +512,7 @@ def custom_forward(*inputs, **kwargs):
)

x = self.causal_wan_model.head(
x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2)
x, safe_unflatten(e, 0, t.shape).unsqueeze(2)
)
x = self.causal_wan_model.unpatchify(x, grid_sizes)
return torch.stack(x)
Expand Down