
Commit 6f7583b

Merge branch 'wan_vae_dp' into dev
2 parents: d8b3818 + 7a5fab1

File tree: 1 file changed (+67, -47 lines)


src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 67 additions & 47 deletions
@@ -18,6 +18,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
+import torch.distributed as dist
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
@@ -1076,6 +1077,8 @@ def __init__(
 
         self.use_dp = False
 
+        self.use_dp = False
+
         # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
         self._cached_conv_counts = {
             "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
@@ -1163,6 +1166,53 @@ def enable_dp(
                 self.num_tiles_per_rank[rank_idx] += 1
                 rank_idx += 1
 
+    def enable_dp(
+        self,
+        world_size: Optional[int] = None,
+        hw_splits: Optional[Tuple[int, int]] = None,
+        overlap_ratio: Optional[float] = None,
+        overlap_pixels: Optional[int] = None
+    ) -> None:
+        r"""
+        """
+        if world_size is None:
+            world_size = dist.get_world_size()
+
+        if world_size <= 1 or world_size > dist.get_world_size():
+            return
+
+        if hw_splits is None:
+            hw_splits = (1, int(world_size))
+
+        assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}"
+
+        h_split, w_split = map(int, hw_splits)
+        num_tiles = h_split * w_split
+
+        # assert h_split * w_split == world_size, \
+        #     (f"world_size must be {w_split} * {h_split} = {w_split * h_split}, but got {world_size}")
+
+        self.use_dp = True
+        self.h_split, self.w_split = h_split, w_split
+        self.world_size = world_size
+        self.overlap_ratio = overlap_ratio
+        self.overlap_pixels = overlap_pixels
+
+        dp_ranks = list(range(0, world_size))
+        self.vae_dp_group = dist.new_group(ranks=dp_ranks)
+        self.rank = dist.get_rank()
+        # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)]
+        # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split)
+        self.tile_idxs_per_rank = [[] for _ in range(self.world_size)]
+        self.num_tiles_per_rank = [0] * self.world_size
+        rank_idx = 0
+        for h_idx in range(self.h_split):
+            for w_idx in range(self.w_split):
+                rank_idx %= self.world_size
+                self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
+                self.num_tiles_per_rank[rank_idx] += 1
+                rank_idx += 1
+
     def clear_cache(self):
         # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
         self._conv_num = self._cached_conv_counts["decoder"]
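
As a reading aid (not part of the commit): the round-robin tile assignment in the new enable_dp can be sketched standalone. The grid size and world size below are illustrative.

    # Round-robin assignment of (h_idx, w_idx) tiles to ranks, mirroring enable_dp above.
    # Illustrative values: a 2x3 tile grid spread over 4 ranks.
    h_split, w_split, world_size = 2, 3, 4

    tile_idxs_per_rank = [[] for _ in range(world_size)]
    num_tiles_per_rank = [0] * world_size
    rank_idx = 0
    for h_idx in range(h_split):
        for w_idx in range(w_split):
            rank_idx %= world_size
            tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
            num_tiles_per_rank[rank_idx] += 1
            rank_idx += 1

    print(tile_idxs_per_rank)  # [[(0, 0), (1, 1)], [(0, 1), (1, 2)], [(0, 2)], [(1, 0)]]
    print(num_tiles_per_rank)  # [2, 2, 1, 1]

Ranks 0 and 1 decode two tiles each while ranks 2 and 3 decode one, which is why num_tiles_per_rank is needed when gathering results later.
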
@@ -1232,6 +1282,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
+        if self.use_dp:
+            return self.tiled_decode_with_dp(z, return_dict=return_dict)
+
         if self.use_dp:
             return self.tiled_decode_with_dp(z, return_dict=return_dict)
 
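
A hypothetical end-to-end usage sketch for this data-parallel decode path (not taken from the commit): it assumes torchrun with one process per GPU, that this branch of diffusers is installed, and that the checkpoint name, latent shape, and dtype below are placeholders.

    # Hypothetical usage of enable_dp + decode; launch with e.g. torchrun --nproc-per-node=4.
    import os

    import torch
    import torch.distributed as dist
    from diffusers import AutoencoderKLWan

    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
    ).to("cuda")

    # Split the latent grid 2x2 so each of the 4 ranks decodes one tile.
    vae.enable_dp(world_size=4, hw_splits=(2, 2), overlap_pixels=32)

    latents = torch.randn(1, 16, 13, 60, 104, device="cuda")  # placeholder latent tensor
    with torch.no_grad():
        video = vae.decode(latents).sample  # with use_dp set, _decode routes to tiled_decode_with_dp
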
@@ -1470,8 +1523,8 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni
         tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split)
 
         # Calculate overlap in latent space
-        overlap_latent_height = 2
-        overlap_latent_width = 2
+        overlap_latent_height = 3
+        overlap_latent_width = 3
         if self.overlap_pixels is not None:
             overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
             overlap_latent_height = overlap_latent
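
A quick worked example of the pixel-to-latent overlap conversion above, assuming the usual Wan VAE spatial compression ratio of 8 (the ratio itself is not shown in this hunk):

    # Ceil division from pixel overlap to latent overlap, as in the hunk above.
    spatial_compression_ratio = 8   # assumed typical Wan VAE value
    overlap_pixels = 20
    overlap_latent = (overlap_pixels + spatial_compression_ratio - 1) // spatial_compression_ratio
    print(overlap_latent)  # 3 latent rows/columns of overlap
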
@@ -1512,12 +1565,11 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni
         num_tile_rows = self.h_split
         num_tile_cols = self.w_split
 
-        # Each rank computes only tiles assigned to it based on patch_ranks
-        # local_tiles = {} # Dictionary to store tiles computed by this rank: {(i_idx, j_idx): tile_tensor}
+        # Each rank computes only tiles assigned to it based on tile_idxs_per_rank
+        # local_tiles = [] # List to store tiles computed by this rank
         local_tiles = []
         local_hw_shapes = []
 
-        # h_idxs, w_idxs = torch.where(self.patch_ranks == self.rank)
         for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
             self.clear_cache()
             patch_height_start = h_idx * tile_latent_stride_height
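
To make the loop above concrete: each rank walks only its own (h_idx, w_idx) pairs and turns them into latent-space tile windows. The sketch below is illustrative only; the hunk shows just patch_height_start, so the end/overlap handling here is an assumption, not the commit's exact formula.

    # Illustrative tile windows for one rank (end/overlap handling assumed, not from the diff).
    height, width = 60, 104                      # placeholder latent size
    h_split, w_split = 2, 2
    overlap_latent_height = overlap_latent_width = 3

    tile_latent_stride_height = (height + h_split - 1) // h_split   # 30
    tile_latent_stride_width = (width + w_split - 1) // w_split     # 52

    for h_idx, w_idx in [(0, 0), (1, 1)]:        # e.g. the tiles assigned to one rank
        patch_height_start = h_idx * tile_latent_stride_height
        patch_width_start = w_idx * tile_latent_stride_width
        patch_height_end = min(patch_height_start + tile_latent_stride_height + overlap_latent_height, height)
        patch_width_end = min(patch_width_start + tile_latent_stride_width + overlap_latent_width, width)
        print((patch_height_start, patch_height_end), (patch_width_start, patch_width_end))
    # (0, 33) (0, 55)
    # (30, 60) (52, 104)
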
@@ -1534,17 +1586,20 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni
                 )
                 time.append(decoded)
             time = torch.cat(time, dim=2)
-            local_tiles.append(time.flatten(3, 4))
-            local_hw_shapes.append(torch.Tensor([*time.shape[3:5]]).to(device).int())
+            local_tiles.append(time.flatten(3, 4))  # flatten the h,w dims so all tiles on one rank can be concatenated
+            local_hw_shapes.append(torch.Tensor([*time.shape[3:5]]).to(device).int())  # record h,w for unflattening later
         self.clear_cache()
 
+        # concatenate all tiles held by the local rank
         local_tiles = torch.cat(local_tiles, dim=3)
         local_hw_shapes = torch.stack(local_hw_shapes)
 
+        # gather the h,w shapes from every rank (the last tile may have a different shape)
         gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
                                for num_tiles in self.num_tiles_per_rank]
         dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)
 
+        # gather the flattened tiles from all ranks
         b, c, n = local_tiles.shape[:3]
         gathered_tiles = [
             torch.empty(
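
The flatten-and-record-shapes bookkeeping in this hunk, together with the slicing in the next one, round-trips as follows; a single-process sketch with made-up tile sizes and no process group involved:

    import torch

    # Two fake decoded tiles from one rank, same (b, c, t) but different spatial sizes,
    # mimicking a smaller last tile. All sizes here are made up.
    tiles = [torch.randn(1, 3, 5, 33, 55), torch.randn(1, 3, 5, 30, 52)]

    # Producer side (this hunk): flatten h,w so tiles can be concatenated along one dim,
    # and record each tile's (h, w) so it can be unflattened after gathering.
    flat = torch.cat([t.flatten(3, 4) for t in tiles], dim=3)
    hw_shapes = torch.stack([torch.tensor(t.shape[3:5]) for t in tiles])

    # Consumer side (next hunk): walk the flattened dim using the recorded shapes.
    start, recovered = 0, []
    for hw in hw_shapes:
        end = start + hw.prod().item()
        recovered.append(flat[:, :, :, start:end].unflatten(3, hw.tolist()))
        start = end

    assert all(torch.equal(a, b) for a, b in zip(tiles, recovered))
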
@@ -1553,55 +1608,20 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni
         ]
         dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)
 
+        # place the gathered tiles into a rows grid based on tile_idxs_per_rank
         rows = [[None] * num_tile_cols for _ in range(num_tile_rows)]
         for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
             rank_tile_hw_shapes = gathered_shape_list[rank_idx]
             hw_start_idx = 0
+            # a rank may hold more than one tile; slice each one out using its recorded h,w shape
             for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
                 rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
-                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()
+                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()  # length of the flattened h,w dim
                 rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten(
-                    3, rank_tile_hw_shape.tolist())
+                    3, rank_tile_hw_shape.tolist())  # unflatten the h,w dim
                 hw_start_idx = hw_end_idx
 
-
-        # # Gather all tiles from all ranks
-        # # Prepare data for all_gather: each rank sends a dictionary of (position, tile) pairs
-        # gathered_tiles_list = [None] * self.world_size
-        # dist.all_gather_object(gathered_tiles_list, local_tiles, group=self.vae_dp_group)
-
-        # # Reconstruct the full rows structure from gathered tiles
-        # # First, find a reference tile to determine the expected tile shape
-        # all_tiles = {}
-        # reference_shape = None
-        # for rank_tiles in gathered_tiles_list:
-        #     if rank_tiles is not None and len(rank_tiles) > 0:
-        #         all_tiles.update(rank_tiles)
-        #         if reference_shape is None:
-        #             reference_shape = list(rank_tiles.values())[0].shape
-        # del gathered_tiles_list
-
-        # rows = []
-        # for i_idx in range(num_tile_rows):
-        #     row = []
-        #     for j_idx in range(num_tile_cols):
-        #         # Find the tile at position (i_idx, j_idx) from gathered results
-        #         tile = all_tiles.get((i_idx, j_idx), None)
-        #         # If tile not found (shouldn't happen if world_size matches tile count), use reference shape
-        #         if tile is None:
-        #             if reference_shape is not None:
-        #                 # Use reference shape but ensure it's on the correct device
-        #                 tile = torch.zeros(*reference_shape, device=z.device, dtype=z.dtype)
-        #             else:
-        #                 # Fallback: estimate shape (shouldn't happen in normal operation)
-        #                 batch_size, channels = z.shape[:2]
-        #                 estimated_h = tile_sample_min_height // (self.config.patch_size if self.config.patch_size is not None else 1)
-        #                 estimated_w = tile_sample_min_width // (self.config.patch_size if self.config.patch_size is not None else 1)
-        #                 tile = torch.zeros(batch_size, channels, num_frames, estimated_h, estimated_w,
-        #                                    device=z.device, dtype=z.dtype)
-        #         row.append(tile)
-        #     rows.append(row)
-
+        # combine all tiles, same as in tiled_decode
         result_rows = []
         for i, row in enumerate(rows):
             result_row = []
