 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
+import torch.distributed as dist
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
@@ -1076,6 +1077,8 @@ def __init__(
 
         self.use_dp = False
 
+        self.use_dp = False
+
         # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
         self._cached_conv_counts = {
             "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
@@ -1163,6 +1166,53 @@ def enable_dp(
                 self.num_tiles_per_rank[rank_idx] += 1
                 rank_idx += 1
 
+    def enable_dp(
+        self,
+        world_size: Optional[int] = None,
+        hw_splits: Optional[Tuple[int, int]] = None,
+        overlap_ratio: Optional[float] = None,
+        overlap_pixels: Optional[int] = None
+    ) -> None:
+        r"""
+        Enable data-parallel tiled decoding: the spatial tile grid used by `tiled_decode_with_dp` is split across
+        the ranks of a dedicated process group and the decoded tiles are gathered back on every rank.
+
+        Args:
+            world_size (`int`, *optional*):
+                Number of ranks that participate in the decode. Defaults to `torch.distributed.get_world_size()`.
+            hw_splits (`Tuple[int, int]`, *optional*):
+                Number of tile splits along height and width, i.e. `(h_split, w_split)`. Defaults to
+                `(1, world_size)`.
+            overlap_ratio (`float`, *optional*):
+                Overlap between neighboring tiles, given as a ratio.
+            overlap_pixels (`int`, *optional*):
+                Overlap between neighboring tiles, given in output pixels and converted to latent units internally.
+        """
+        if world_size is None:
+            world_size = dist.get_world_size()
+
+        if world_size <= 1 or world_size > dist.get_world_size():
+            return
+
+        if hw_splits is None:
+            hw_splits = (1, int(world_size))
+
+        assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 ints, but got length {len(hw_splits)}"
+
+        h_split, w_split = map(int, hw_splits)
+        num_tiles = h_split * w_split
+
+        # assert h_split * w_split == world_size, \
+        #     (f"world_size must be {w_split} * {h_split} = {w_split * h_split}, but got {world_size}")
+
+        self.use_dp = True
+        self.h_split, self.w_split = h_split, w_split
+        self.world_size = world_size
+        self.overlap_ratio = overlap_ratio
+        self.overlap_pixels = overlap_pixels
+
+        dp_ranks = list(range(0, world_size))
+        self.vae_dp_group = dist.new_group(ranks=dp_ranks)
+        self.rank = dist.get_rank()
+        # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)]
+        # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split)
+        # Assign tiles to ranks round-robin over the (h_idx, w_idx) tile grid
+        self.tile_idxs_per_rank = [[] for _ in range(self.world_size)]
+        self.num_tiles_per_rank = [0] * self.world_size
+        rank_idx = 0
+        for h_idx in range(self.h_split):
+            for w_idx in range(self.w_split):
+                rank_idx %= self.world_size
+                self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
+                self.num_tiles_per_rank[rank_idx] += 1
+                rank_idx += 1
+
     def clear_cache(self):
         # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
         self._conv_num = self._cached_conv_counts["decoder"]
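
For reference, the round-robin loop at the end of `enable_dp` distributes the `h_split * w_split` tile grid over the participating ranks. Below is a minimal standalone sketch of that assignment, using assumed example values for `h_split`, `w_split`, and `world_size` (no process group needed):

```python
# Round-robin tile-to-rank assignment, mirroring the loop in enable_dp.
# h_split, w_split and world_size are illustrative values, not taken from the commit.
h_split, w_split, world_size = 2, 4, 4

tile_idxs_per_rank = [[] for _ in range(world_size)]
num_tiles_per_rank = [0] * world_size
rank_idx = 0
for h_idx in range(h_split):
    for w_idx in range(w_split):
        rank_idx %= world_size
        tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
        num_tiles_per_rank[rank_idx] += 1
        rank_idx += 1

# With a 2x4 tile grid on 4 ranks, each rank gets one tile per row:
# rank 0 -> [(0, 0), (1, 0)], rank 1 -> [(0, 1), (1, 1)], rank 2 -> [(0, 2), (1, 2)], rank 3 -> [(0, 3), (1, 3)]
print(tile_idxs_per_rank, num_tiles_per_rank)
```
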
@@ -1232,6 +1282,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
+        if self.use_dp:
+            return self.tiled_decode_with_dp(z, return_dict=return_dict)
+
         if self.use_dp:
             return self.tiled_decode_with_dp(z, return_dict=return_dict)
 
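
Usage-wise, once `enable_dp` has been called on every rank, a normal `decode` call is routed to `tiled_decode_with_dp` by the check above. A hedged sketch of how this could be driven under `torchrun`; the checkpoint id, latent shape, and launch details are illustrative assumptions, not taken from the commit:

```python
# Hypothetical multi-GPU decode script, e.g. launched with:
#   torchrun --nproc_per_node=4 decode_dp.py
import torch
import torch.distributed as dist
from diffusers import AutoencoderKLWan

dist.init_process_group(backend="nccl")
device = torch.device(f"cuda:{dist.get_rank()}")

# Example checkpoint; any weights loaded through AutoencoderKLWan should work the same way.
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae").to(device)
vae.enable_dp(world_size=dist.get_world_size(), hw_splits=(1, dist.get_world_size()))

latents = torch.randn(1, 16, 13, 60, 104, device=device)  # illustrative latent shape
with torch.no_grad():
    video = vae.decode(latents).sample  # dispatched to tiled_decode_with_dp because use_dp is set

dist.destroy_process_group()
```
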
@@ -1470,8 +1523,8 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni
         tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split)
 
         # Calculate overlap in latent space
-        overlap_latent_height = 2
-        overlap_latent_width = 2
+        overlap_latent_height = 3
+        overlap_latent_width = 3
         if self.overlap_pixels is not None:
             overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
             overlap_latent_height = overlap_latent
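
The default latent overlap changes from 2 to 3 here, and when `overlap_pixels` is set it is converted to latent units with a ceiling division by the spatial compression ratio. A small worked sketch with assumed example numbers:

```python
# Ceiling division of a pixel-space overlap into latent units, mirroring the branch above.
# spatial_compression_ratio=8 and overlap_pixels=24 are example values only.
spatial_compression_ratio = 8
overlap_pixels = 24

overlap_latent = (overlap_pixels + spatial_compression_ratio - 1) // spatial_compression_ratio
print(overlap_latent)  # 3 -> used for both overlap_latent_height and overlap_latent_width
```
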
@@ -1512,12 +1565,11 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni
         num_tile_rows = self.h_split
         num_tile_cols = self.w_split
 
-        # Each rank computes only tiles assigned to it based on patch_ranks
-        # local_tiles = {} # Dictionary to store tiles computed by this rank: {(i_idx, j_idx): tile_tensor}
+        # Each rank computes only tiles assigned to it based on tile_idxs_per_rank
+        # local_tiles = [] # List to store tiles computed by this rank
         local_tiles = []
         local_hw_shapes = []
 
-        # h_idxs, w_idxs = torch.where(self.patch_ranks == self.rank)
         for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
             self.clear_cache()
             patch_height_start = h_idx * tile_latent_stride_height
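
For orientation, each rank walks only its own `(h_idx, w_idx)` entries and slices the latent with a ceiling-division stride plus the overlap; the clamping of the end index is outside the shown hunks, so the sketch below assumes a simple `min(..., width)`:

```python
# Per-tile latent bounds from stride + overlap; all numbers are illustrative.
width, w_split, overlap_latent_width = 104, 4, 3

tile_latent_stride_width = (width + w_split - 1) // w_split  # 26, same as int((width + w_split - 1) / w_split)
for w_idx in range(w_split):
    start = w_idx * tile_latent_stride_width
    end = min(start + tile_latent_stride_width + overlap_latent_width, width)  # assumed clamping
    print(w_idx, start, end)  # (0, 0, 29), (1, 26, 55), (2, 52, 81), (3, 78, 104)
```
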
@@ -1534,17 +1586,20 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni
                 )
                 time.append(decoded)
             time = torch.cat(time, dim=2)
-            local_tiles.append(time.flatten(3, 4))
-            local_hw_shapes.append(torch.Tensor([*time.shape[3:5]]).to(device).int())
+            local_tiles.append(time.flatten(3, 4))  # flatten the h, w dims so all tiles on this rank can be concatenated
+            local_hw_shapes.append(torch.Tensor([*time.shape[3:5]]).to(device).int())  # record (h, w) for the later unflatten
         self.clear_cache()
 
+        # concatenate all tiles decoded on the local rank
         local_tiles = torch.cat(local_tiles, dim=3)
         local_hw_shapes = torch.stack(local_hw_shapes)
 
+        # gather every rank's tile (h, w) shapes (the last tile may have a different shape)
         gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
                                for num_tiles in self.num_tiles_per_rank]
         dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)
 
+        # gather the flattened tiles from all ranks
         b, c, n = local_tiles.shape[:3]
         gathered_tiles = [
             torch.empty(
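
The two `all_gather` calls exist because `torch.distributed.all_gather` needs correctly sized receive buffers while edge tiles can have smaller spatial extents: each rank flattens the h and w dims of its tiles into one axis, the per-tile `(h, w)` shapes are exchanged first, and the tile buffers are then sized from `num_tiles_per_rank`. A single-process sketch of just the flatten/unflatten bookkeeping (no process group; shapes are illustrative):

```python
import torch

# Two tiles of different spatial size on one "rank", as happens for edge tiles.
# Layout is (batch, channels, frames, height, width); the numbers are made up.
tiles = [torch.randn(1, 3, 5, 64, 64), torch.randn(1, 3, 5, 64, 40)]

# Flatten h, w and record the shapes, mirroring local_tiles / local_hw_shapes.
flat = torch.cat([t.flatten(3, 4) for t in tiles], dim=3)             # (1, 3, 5, 64*64 + 64*40)
hw_shapes = torch.stack([torch.tensor(t.shape[3:5]) for t in tiles])  # [[64, 64], [64, 40]]

# Split the flattened buffer back into tiles, mirroring the unflatten step after all_gather.
recovered, start = [], 0
for hw in hw_shapes:
    end = start + hw.prod().item()
    recovered.append(flat[:, :, :, start:end].unflatten(3, hw.tolist()))
    start = end

assert all(torch.equal(a, b) for a, b in zip(tiles, recovered))
```
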
@@ -1553,55 +1608,20 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni
         ]
         dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)
 
+        # place the gathered tiles into a 2D grid (rows) according to tile_idxs_per_rank
         rows = [[None] * num_tile_cols for _ in range(num_tile_rows)]
         for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
             rank_tile_hw_shapes = gathered_shape_list[rank_idx]
             hw_start_idx = 0
+            # a rank may hold more than one tile; split them apart using the gathered hw shapes
             for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
                 rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
-                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()
+                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()  # length of the flattened h*w slice
                 rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten(
-                    3, rank_tile_hw_shape.tolist())
+                    3, rank_tile_hw_shape.tolist())  # unflatten the h, w dims
                 hw_start_idx = hw_end_idx
 
-
-        # # Gather all tiles from all ranks
-        # # Prepare data for all_gather: each rank sends a dictionary of (position, tile) pairs
-        # gathered_tiles_list = [None] * self.world_size
-        # dist.all_gather_object(gathered_tiles_list, local_tiles, group=self.vae_dp_group)
-
-        # # Reconstruct the full rows structure from gathered tiles
-        # # First, find a reference tile to determine the expected tile shape
-        # all_tiles = {}
-        # reference_shape = None
-        # for rank_tiles in gathered_tiles_list:
-        #     if rank_tiles is not None and len(rank_tiles) > 0:
-        #         all_tiles.update(rank_tiles)
-        #         if reference_shape is None:
-        #             reference_shape = list(rank_tiles.values())[0].shape
-        # del gathered_tiles_list
-
-        # rows = []
-        # for i_idx in range(num_tile_rows):
-        #     row = []
-        #     for j_idx in range(num_tile_cols):
-        #         # Find the tile at position (i_idx, j_idx) from gathered results
-        #         tile = all_tiles.get((i_idx, j_idx), None)
-        #         # If tile not found (shouldn't happen if world_size matches tile count), use reference shape
-        #         if tile is None:
-        #             if reference_shape is not None:
-        #                 # Use reference shape but ensure it's on the correct device
-        #                 tile = torch.zeros(*reference_shape, device=z.device, dtype=z.dtype)
-        #             else:
-        #                 # Fallback: estimate shape (shouldn't happen in normal operation)
-        #                 batch_size, channels = z.shape[:2]
-        #                 estimated_h = tile_sample_min_height // (self.config.patch_size if self.config.patch_size is not None else 1)
-        #                 estimated_w = tile_sample_min_width // (self.config.patch_size if self.config.patch_size is not None else 1)
-        #                 tile = torch.zeros(batch_size, channels, num_frames, estimated_h, estimated_w,
-        #                                    device=z.device, dtype=z.dtype)
-        #         row.append(tile)
-        #     rows.append(row)
-
+        # combine all the tiles, same as in tiled_decode
         result_rows = []
         for i, row in enumerate(rows):
             result_row = []
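
From here the gathered grid is blended and stitched exactly as in the existing `tiled_decode` (the rest of that loop is outside this hunk). As a generic illustration of the kind of overlap blending involved, here is a standalone cross-fade for two horizontally adjacent tiles; this is a sketch, not the model's own blending helper:

```python
import torch

def blend_horizontal(left: torch.Tensor, right: torch.Tensor, overlap: int) -> torch.Tensor:
    """Linearly cross-fade the right tile's first `overlap` columns with the left tile's last ones."""
    weight = torch.linspace(0, 1, overlap, device=right.device)
    blended = right.clone()
    blended[..., :overlap] = left[..., -overlap:] * (1 - weight) + right[..., :overlap] * weight
    return blended

left = torch.randn(1, 3, 5, 64, 64)
right = torch.randn(1, 3, 5, 64, 64)
stitched_edge = blend_horizontal(left, right, overlap=16)
```
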