From f93d2610adf72aec8d18d9692413a29e0194c9f8 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Wed, 20 Aug 2025 20:21:57 +0000
Subject: [PATCH 1/4] adds segment ids for masking.

---
 src/maxdiffusion/models/attention_flax.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index fcdb7cf6..851b5dac 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -189,34 +189,39 @@ def _tpu_flash_attention(
 
   num_fsdp_shards = mesh.shape["fsdp"]
   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+  key, _, key_seq_len = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
+  # To only attend to non-padded tokens.
+  segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, LENGTH))
+  segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH))
+  q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
+  q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
+  kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
+  kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
+
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(
-          q_axis_names,
-          kv_axis_names,
-          kv_axis_names,
-      ),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, segment_axis_names_q, segment_axis_names_kv),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value):
+  def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
     # make_splash_mha is wrapped around shardmap and seq and head is already
     # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # the sizes of the axis is sharding over heads
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value)
+    attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=segment_ids)
     return attention_output
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
@@ -227,7 +232,7 @@ def wrap_flash_attention(query, key, value):
         "Warning, batch dimension should be shardable among the devices in data and fsdp"
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
-  x = wrap_flash_attention(query, key, value)
+  x = wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids)
   x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)
 

From c7edfb041903b9b74b57a753d94224804a4ae0f3 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Thu, 21 Aug 2025 00:50:53 +0000
Subject: [PATCH 2/4] reduce padding by computing it inside sharded qkvs.
---
 src/maxdiffusion/models/attention_flax.py | 59 +++++++++++++----------
 1 file changed, 34 insertions(+), 25 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 851b5dac..6e689d5c 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -112,8 +112,7 @@ def _unflatten_heads(tensor, heads):
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
   return tensor
 
-
-def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+def _reshape_data_for_flash(tensor, heads):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
@@ -121,6 +120,15 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
   """
   if tensor.ndim != 4:
     tensor = _unflatten_heads(tensor, heads)
+  return tensor
+
+def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+  """
+  Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
+  Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
+  blocks is divisible by the number of shards.
+  """
+  tensor = _reshape_data_for_flash(tensor, heads)
 
   # Pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
@@ -148,8 +156,7 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
 
   if kv_size < 128 or seq_len_pad != 0:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
-    padded_tensor = jnp.pad(tensor, npad)
-    tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None))
+    tensor = jnp.pad(tensor, npad)
 
   return tensor, kv_size, seq_len
 
@@ -166,11 +173,13 @@ def _tpu_flash_attention(
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
   """TPU Flash Attention"""
+
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # Cross-attention where kv dims are much smaller due to encoder_hidden_states.
-  # If kv seq_len is padded too much, it causes issues in attention calculations.
+  # This is the case for cross-attn.
   if key.shape[1] != query.shape[1]:
+    assert key.shape[1] % 128 == 0
     kv_max_block_size = key.shape[1]
+    #q_max_block_size = kv_max_block_size
   else:
     kv_max_block_size = q_max_block_size
   if flash_block_sizes:
@@ -186,35 +195,36 @@ def _tpu_flash_attention(
         block_q_dq=min(q_max_block_size, query.shape[2]),
         block_kv_dq=min(kv_max_block_size, query.shape[2]),
     )
-
-  num_fsdp_shards = mesh.shape["fsdp"]
-  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, key_seq_len = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
-  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+
+  query = _reshape_data_for_flash(query, heads)
+  key = _reshape_data_for_flash(key, heads)
+  value = _reshape_data_for_flash(value, heads)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
-  # To only attend to non-padded tokens.
-  segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, LENGTH))
-  segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH))
-  q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
-  q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
-  kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
-  kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
-
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, segment_axis_names_q, segment_axis_names_kv),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
+  def wrap_flash_attention(query, key, value):
+
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute)
+    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute)
+
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
+    q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
+    kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
+    kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
     # make_splash_mha is wrapped around shardmap and seq and head is already
     # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # the sizes of the axis is sharding over heads
@@ -222,7 +232,7 @@ def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
         block_sizes=block_sizes,
     )
     attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=segment_ids)
-    return attention_output
+    return attention_output[:,:,:query_seq_len,:kv_size]
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
@@ -232,8 +242,7 @@ def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
         "Warning, batch dimension should be shardable among the devices in data and fsdp"
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
-  x = wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids)
-  x = x[:, :, :query_seq_len, :kv_size]
+  x = wrap_flash_attention(query, key, value)
   x = _reshape_heads_to_head_dim(x)
 
   return x

From ec61456ddfac912329f6e5091b7c6e6c81251c34 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Thu, 21 Aug 2025 06:31:19 +0000
Subject: [PATCH 3/4] scanned ring attn.
---
 src/maxdiffusion/models/attention_flax.py | 50 ++++++++++++++++++++---
 1 file changed, 45 insertions(+), 5 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 6e689d5c..48709665 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -195,7 +195,7 @@ def _tpu_flash_attention(
         block_q_dq=min(q_max_block_size, query.shape[2]),
         block_kv_dq=min(kv_max_block_size, query.shape[2]),
     )
-
+  num_fsdp_shards = mesh.shape["fsdp"]
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
@@ -218,9 +218,7 @@ def wrap_flash_attention(query, key, value):
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
     q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
-    q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
     kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
-    kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -230,9 +228,51 @@ def wrap_flash_attention(query, key, value):
         head_shards=1,  # the sizes of the axis is sharding over heads
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
+        save_residuals=True
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=segment_ids)
-    return attention_output[:,:,:query_seq_len,:kv_size]
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0,0,0, None))
+
+    def ring_scan_body(carry, _):
+      m, l, o, k_current, v_current = carry
+      perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+      k_next = jax.lax.ppermute(k_current, axis_name='fsdp', perm=perm)
+      v_next = jax.lax.ppermute(v_current, axis_name='fsdp', perm=perm)
+
+      out_chunk, (lse_chunk,) = vmapped_splash(
+          query, k_current, v_current, segment_ids
+      )
+
+      m_chunk = lse_chunk.astype(jnp.float32)
+      m_old = m
+      m = jnp.maximum(m_old, m_chunk)
+
+      exp_m_diff = jnp.exp(m_old - m)
+      exp_m_chunk_diff = jnp.exp(m_chunk - m)
+
+      l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+      o = o * exp_m_diff[..., None]
+      o += (exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32))
+
+      # Return the updated state for the next iteration
+      return (m, l, o, k_next, v_next), None
+
+    lse_shape = query.shape[:-1]
+    m_init = jnp.full(lse_shape, -jnp.inf, dtype=jnp.float32)
+    l_init = jnp.zeros(lse_shape, dtype=jnp.float32)
+    o_init = jnp.zeros_like(query, dtype=jnp.float32)
+
+    initial_carry = (m_init, l_init, o_init, key, value)
+
+    (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
+        ring_scan_body,
+        initial_carry,
+        None,
+        length=num_fsdp_shards
+    )
+
+    attention_output = o_final / l_final[..., None]
+
+    return attention_output[:,:,:query_seq_len,:kv_size].astype(query.dtype)
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops

From 56447408d8c049dc7216ed274be7af158934a4f1 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Thu, 21 Aug 2025 22:25:11 +0000
Subject: [PATCH 4/4] add ring attention - inference only.
---
 src/maxdiffusion/models/attention_flax.py | 86 +++++++++++++----------
 src/maxdiffusion/pyconfig.py              | 12 ++++
 src/maxdiffusion/trainers/wan_trainer.py  | 67 +++++++++---------
 3 files changed, 95 insertions(+), 70 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 48709665..25788fb6 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -112,6 +112,7 @@ def _unflatten_heads(tensor, heads):
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
   return tensor
 
+
 def _reshape_data_for_flash(tensor, heads):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
@@ -122,6 +123,7 @@ def _reshape_data_for_flash(tensor, heads):
     tensor = _unflatten_heads(tensor, heads)
   return tensor
 
+
 def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
@@ -171,6 +173,7 @@ def _tpu_flash_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
+    attention_kernel: str = "flash",
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -179,7 +182,6 @@ def _tpu_flash_attention(
   if key.shape[1] != query.shape[1]:
     assert key.shape[1] % 128 == 0
     kv_max_block_size = key.shape[1]
-    #q_max_block_size = kv_max_block_size
   else:
     kv_max_block_size = q_max_block_size
   if flash_block_sizes:
@@ -217,8 +219,14 @@ def wrap_flash_attention(query, key, value):
 
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-    q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
-    kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
+
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -228,51 +236,51 @@ def wrap_flash_attention(query, key, value):
         head_shards=1,  # the sizes of the axis is sharding over heads
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
-        save_residuals=True
+        save_residuals=True if attention_kernel == "ring" else False,
     )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0,0,0, None))
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
-    def ring_scan_body(carry, _):
-      m, l, o, k_current, v_current = carry
-      perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
-      k_next = jax.lax.ppermute(k_current, axis_name='fsdp', perm=perm)
-      v_next = jax.lax.ppermute(v_current, axis_name='fsdp', perm=perm)
+    if attention_kernel == "flash":
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+    else:
+      if num_fsdp_shards > 1:
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+        m = lse.astype(jnp.float32)
+        l = jnp.exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]
 
-      out_chunk, (lse_chunk,) = vmapped_splash(
-          query, k_current, v_current, segment_ids
-      )
+        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
 
-      m_chunk = lse_chunk.astype(jnp.float32)
-      m_old = m
-      m = jnp.maximum(m_old, m_chunk)
-
-      exp_m_diff = jnp.exp(m_old - m)
-      exp_m_chunk_diff = jnp.exp(m_chunk - m)
+        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
 
-      l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-      o = o * exp_m_diff[..., None]
-      o += (exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32))
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
 
-      # Return the updated state for the next iteration
-      return (m, l, o, k_next, v_next), None
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
 
-    lse_shape = query.shape[:-1]
-    m_init = jnp.full(lse_shape, -jnp.inf, dtype=jnp.float32)
-    l_init = jnp.zeros(lse_shape, dtype=jnp.float32)
-    o_init = jnp.zeros_like(query, dtype=jnp.float32)
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)
 
-    initial_carry = (m_init, l_init, o_init, key, value)
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)
 
-    (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
-        ring_scan_body,
-        initial_carry,
-        None,
-        length=num_fsdp_shards
-    )
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+
+          # Return the updated state for the next iteration
+          return (m, l, o, k_next, v_next), None
+
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
 
-    attention_output = o_final / l_final[..., None]
+        attention_output = o_final / l_final[..., None]
 
-    return attention_output[:,:,:query_seq_len,:kv_size].astype(query.dtype)
+    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
@@ -433,6 +441,10 @@ def _apply_attention(
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
     )
+  elif attention_kernel == "ring":
+    return _tpu_flash_attention(
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py
index 182a427b..33fc62f8 100644
--- a/src/maxdiffusion/pyconfig.py
+++ b/src/maxdiffusion/pyconfig.py
@@ -27,6 +27,7 @@
 from . import max_logging
 from . import max_utils
 from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
+from maxdiffusion.common_types import LENGTH, KV_LENGTH
 
 
 def string_to_bool(s: str) -> bool:
@@ -175,6 +176,17 @@ def user_init(raw_keys):
   max_utils.write_config_raw_keys_for_gcs(raw_keys)
 
   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
+  # Verify qkv is sharded across sequence.
+  if raw_keys["attention"] == "ring":
+    logical_axis_rules = list(raw_keys["logical_axis_rules"])
+    q_seq_sharding = (LENGTH, "fsdp")
+    kv_seq_sharding = (KV_LENGTH, "fsdp")
+    if q_seq_sharding not in logical_axis_rules:
+      logical_axis_rules.append(q_seq_sharding)
+    if kv_seq_sharding not in logical_axis_rules:
+      logical_axis_rules.append(kv_seq_sharding)
+    raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
+
   raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])
 
   if raw_keys["learning_rate_schedule_steps"] == -1:
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index a267e065..37615b07 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -255,7 +255,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
         eval_data_iterator = self.load_dataset(mesh, is_training=False)
         eval_rng = jax.random.key(self.config.seed + step)
         eval_metrics = []
-        # Loop indefinitely until the iterator is exhausted 
+        # Loop indefinitely until the iterator is exhausted
         while True:
           try:
             with mesh:
@@ -329,6 +329,7 @@ def loss_fn(params):
   metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
   return new_state, scheduler_state, metrics, new_rng
 
+
 def eval_step(state, data, rng, scheduler_state, scheduler, config):
   """
   Computes the evaluation loss for a single batch without updating model weights.
@@ -338,44 +339,44 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
   # This ensures the batch size is consistent, though it might be redundant
   # if the evaluation dataloader is already configured correctly.
   for k, v in data.items():
-      data[k] = v[: config.global_batch_size_to_train_on, :]
+    data[k] = v[: config.global_batch_size_to_train_on, :]
 
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
   def loss_fn(params):
-      # Reconstruct the model from its definition and parameters
-      model = nnx.merge(state.graphdef, params, state.rest_of_state)
-
-      # Prepare inputs
-      latents = data["latents"].astype(config.weights_dtype)
-      encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
-      bsz = latents.shape[0]
-
-      # Sample random timesteps and noise, just as in a training step
-      timesteps = jax.random.randint(
-          timestep_rng,
-          (bsz,),
-          0,
-          scheduler.config.num_train_timesteps,
-      )
-      noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
-      noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
-
-      # Get the model's prediction
-      model_pred = model(
-          hidden_states=noisy_latents,
-          timestep=timesteps,
-          encoder_hidden_states=encoder_hidden_states,
-      )
+    # Reconstruct the model from its definition and parameters
+    model = nnx.merge(state.graphdef, params, state.rest_of_state)
+
+    # Prepare inputs
+    latents = data["latents"].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    bsz = latents.shape[0]
+
+    # Sample random timesteps and noise, just as in a training step
+    timesteps = jax.random.randint(
+        timestep_rng,
+        (bsz,),
+        0,
+        scheduler.config.num_train_timesteps,
+    )
+    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
+    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
 
-      # Calculate the loss against the target
-      training_target = scheduler.training_target(latents, noise, timesteps)
-      training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
-      loss = (training_target - model_pred) ** 2
-      loss = loss * training_weight
-      loss = jnp.mean(loss)
+    # Get the model's prediction
+    model_pred = model(
+        hidden_states=noisy_latents,
+        timestep=timesteps,
+        encoder_hidden_states=encoder_hidden_states,
+    )
 
-      return loss
+    # Calculate the loss against the target
+    training_target = scheduler.training_target(latents, noise, timesteps)
+    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
+    loss = (training_target - model_pred) ** 2
+    loss = loss * training_weight
+    loss = jnp.mean(loss)
+
+    return loss
 
   # --- Key Difference from train_step ---
   # Directly compute the loss without calculating gradients.