adds segment ids for masking. #236

Merged: 4 commits merged on Aug 22, 2025
100 changes: 83 additions & 17 deletions src/maxdiffusion/models/attention_flax.py
@@ -113,14 +113,24 @@ def _unflatten_heads(tensor, heads):
return tensor


def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
def _reshape_data_for_flash(tensor, heads):
"""
Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
blocks is divisible by the number of shards.
"""
if tensor.ndim != 4:
tensor = _unflatten_heads(tensor, heads)
return tensor


def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
"""
Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
blocks is divisible by the number of shards.
"""
tensor = _reshape_data_for_flash(tensor, heads)

# Pad head_dim to 128 if less than that.
kv_size = tensor.shape[-1]
@@ -148,8 +158,7 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1

if kv_size < 128 or seq_len_pad != 0:
npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
padded_tensor = jnp.pad(tensor, npad)
tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None))
tensor = jnp.pad(tensor, npad)

return tensor, kv_size, seq_len

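For reference, a minimal sketch of the padding arithmetic that `_pad_data_for_flash` describes: seq_len is rounded up to a multiple of the flash block size and head_dim is padded up to 128. The helper name and values below are illustrative only, and the shard-divisibility adjustment is omitted.

```python
# Hypothetical helper; not the PR's code, just the padding arithmetic it describes.
def flash_padding_amounts(seq_len: int, head_dim: int, flash_block_size: int):
    # Pad seq_len up to the next multiple of the flash block size.
    seq_len_pad = (-seq_len) % flash_block_size
    # Pad head_dim up to 128 when it is smaller than that.
    head_dim_pad = max(0, 128 - head_dim)
    return seq_len_pad, head_dim_pad

# e.g. a 75-token sequence with head_dim 64 and block size 512 gets 437
# padding tokens and 64 padding lanes, which are masked and sliced off later.
print(flash_padding_amounts(75, 64, 512))  # (437, 64)
```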
@@ -164,12 +173,14 @@ def _tpu_flash_attention(
axis_names_kv: AxisNames,
flash_block_sizes: BlockSizes,
dtype: jnp.dtype = jnp.float32,
attention_kernel: str = "flash",
) -> jax.Array:
"""TPU Flash Attention"""

q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
# Cross-attention where kv dims are much smaller due to encoder_hidden_states.
# If kv seq_len is padded too much, it causes issues in attention calculations.
# This is the case for cross-attn.
if key.shape[1] != query.shape[1]:
assert key.shape[1] % 128 == 0
kv_max_block_size = key.shape[1]
else:
kv_max_block_size = q_max_block_size
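A toy illustration of the kv block-size choice above, with made-up lengths: in cross-attention the keys and values come from encoder_hidden_states and are much shorter than the query sequence, so the kv block is capped at the kv length itself instead of the query max block size.

```python
# Toy numbers only; the real lengths come from query.shape[1] and key.shape[1].
q_max_block_size = 1024                 # bfloat16 path
q_seq_len, kv_seq_len = 75_600, 512     # hypothetical video latents vs. text tokens

if kv_seq_len != q_seq_len:             # cross-attention
    assert kv_seq_len % 128 == 0
    kv_max_block_size = kv_seq_len      # avoid padding kv up to the query block size
else:                                   # self-attention
    kv_max_block_size = q_max_block_size
```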
@@ -186,38 +197,90 @@
block_q_dq=min(q_max_block_size, query.shape[2]),
block_kv_dq=min(kv_max_block_size, query.shape[2]),
)

num_fsdp_shards = mesh.shape["fsdp"]
query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
query = _reshape_data_for_flash(query, heads)
key = _reshape_data_for_flash(key, heads)
value = _reshape_data_for_flash(value, heads)
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

@functools.partial(
shard_map.shard_map,
mesh=mesh,
in_specs=(
q_axis_names,
kv_axis_names,
kv_axis_names,
),
in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
out_specs=q_axis_names,
check_rep=False,
)
def wrap_flash_attention(query, key, value):

query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute)
value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute)

mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])

q_padded_len = query.shape[2]
q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)

kv_padded_len = key.shape[2]
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

# make_splash_mha is wrapped in shard_map, and the seq and head dims are already
# sharded according to in_specs, so head_shards=1 and q_seq_shards=1 here.
splash_kernel = splash_attention_kernel.make_splash_mha(
mask=multi_head_mask,
head_shards=1,  # size of the axis sharding over heads; already handled by shard_map
q_seq_shards=1,  # size of the axis sharding over seq_len; already handled by shard_map
block_sizes=block_sizes,
save_residuals=True if attention_kernel == "ring" else False,
)
attention_output = jax.vmap(splash_kernel)(query, key, value)
return attention_output
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

if attention_kernel == "flash":
attention_output = vmapped_splash(query, key, value, segment_ids)
else:
if num_fsdp_shards > 1:
out, (lse,) = vmapped_splash(query, key, value, segment_ids)
m = lse.astype(jnp.float32)
l = jnp.exp(lse - m)
o = out.astype(jnp.float32) * l[..., None]

perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]

k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)

def ring_scan_body(carry, _):
m, l, o, k_current, v_current = carry
k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)

out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)

m_chunk = lse_chunk.astype(jnp.float32)
m_old = m
m = jnp.maximum(m_old, m_chunk)

exp_m_diff = jnp.exp(m_old - m)
exp_m_chunk_diff = jnp.exp(m_chunk - m)

l = l * exp_m_diff + jnp.exp(lse_chunk - m)
o = o * exp_m_diff[..., None]
o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)

# Return the updated state for the next iteration
return (m, l, o, k_next, v_next), None

initial_carry = (m, l, o, k1, v1)
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)

attention_output = o_final / l_final[..., None]

return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
# This warning might show up when doing model eval for example, when calculating model flops
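The core of this change is the `SegmentIds` construction inside `wrap_flash_attention`: positions that hold real tokens get segment id 1 and padded positions get 0, and splash attention only attends where the query and key segment ids match, so the padding added by `_pad_data_for_flash` is excluded from the softmax. A self-contained sketch of that construction with toy lengths (not the PR's call sites):

```python
import jax
import jax.numpy as jnp

def padding_segment_ids(real_len: int, padded_len: int) -> jnp.ndarray:
    # 1 for real tokens, 0 for the padding appended by _pad_data_for_flash.
    idx = jax.lax.broadcasted_iota(jnp.int32, (padded_len,), 0)
    return (idx < real_len).astype(jnp.int32)

q_segment_ids = padding_segment_ids(real_len=75, padded_len=512)
kv_segment_ids = padding_segment_ids(real_len=75, padded_len=512)
# These 1-D arrays are what the diff wraps in
# splash_attention_kernel.SegmentIds(q=..., kv=...); attention is only computed
# where the q and kv ids match, so real queries never attend to padded keys,
# and the padded query rows are sliced off after the kernel returns.
```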
@@ -228,7 +291,6 @@ def wrap_flash_attention(query, key, value):
f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
)
x = wrap_flash_attention(query, key, value)
x = x[:, :, :query_seq_len, :kv_size]
x = _reshape_heads_to_head_dim(x)

return x
Expand Down Expand Up @@ -379,6 +441,10 @@ def _apply_attention(
return _tpu_flash_attention(
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
)
elif attention_kernel == "ring":
return _tpu_flash_attention(
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
)
elif attention_kernel == "cudnn_flash_te":
return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
else:
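For the ring branch, each shard's splash call returns its locally normalized output together with the log-sum-exp (LSE) of its scores (via save_residuals), and ring_scan_body merges shards with the standard streaming-softmax recurrence: with m = max_i lse_i and w_i = exp(lse_i - m), the combined output is (sum_i w_i * o_i) / (sum_i w_i). A small numerical check of that identity in plain jnp, using toy shapes rather than the splash kernel:

```python
import jax
import jax.numpy as jnp

def chunk_attention(q, k, v):
    # Plain attention over one chunk of keys/values, returning the locally
    # normalized output and the log-sum-exp of the scores (the "residual").
    scores = q @ k.T
    lse = jax.scipy.special.logsumexp(scores, axis=-1)
    out = jax.nn.softmax(scores, axis=-1) @ v
    return out, lse

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (16, 8))

# Two "ring" chunks, merged with the same LSE recurrence as ring_scan_body.
o1, lse1 = chunk_attention(q, k[:8], v[:8])
o2, lse2 = chunk_attention(q, k[8:], v[8:])
m = jnp.maximum(lse1, lse2)
w1, w2 = jnp.exp(lse1 - m), jnp.exp(lse2 - m)
merged = (w1[:, None] * o1 + w2[:, None] * o2) / (w1 + w2)[:, None]

reference, _ = chunk_attention(q, k, v)             # attention over all 16 keys
print(jnp.allclose(merged, reference, atol=1e-5))   # True
```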
12 changes: 12 additions & 0 deletions src/maxdiffusion/pyconfig.py
@@ -27,6 +27,7 @@
from . import max_logging
from . import max_utils
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
from maxdiffusion.common_types import LENGTH, KV_LENGTH


def string_to_bool(s: str) -> bool:
@@ -175,6 +176,17 @@ def user_init(raw_keys):
max_utils.write_config_raw_keys_for_gcs(raw_keys)

raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
# Ensure q and kv are sharded across the sequence (fsdp) axis for ring attention.
if raw_keys["attention"] == "ring":
logical_axis_rules = list(raw_keys["logical_axis_rules"])
q_seq_sharding = (LENGTH, "fsdp")
kv_seq_sharding = (KV_LENGTH, "fsdp")
if q_seq_sharding not in logical_axis_rules:
logical_axis_rules.append(q_seq_sharding)
if kv_seq_sharding not in logical_axis_rules:
logical_axis_rules.append(kv_seq_sharding)
raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)

raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])

if raw_keys["learning_rate_schedule_steps"] == -1:
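The guard in user_init only appends the sequence-sharding rules when they are missing, so configs that already shard the sequence axes are left unchanged. A hedged example of the effect; the base rules below are invented and the placeholder strings stand in for the LENGTH and KV_LENGTH constants imported from maxdiffusion.common_types:

```python
# Placeholder values; the real constants live in maxdiffusion.common_types.
LENGTH, KV_LENGTH = "length_axis", "kv_length_axis"

logical_axis_rules = [("batch", "data"), ("heads", "tensor")]  # invented base config
attention = "ring"

if attention == "ring":
    for rule in ((LENGTH, "fsdp"), (KV_LENGTH, "fsdp")):
        if rule not in logical_axis_rules:
            logical_axis_rules.append(rule)
logical_axis_rules = tuple(logical_axis_rules)
# -> the two sequence rules are appended, so q and kv are sharded over "fsdp"
#    as the ring pass in attention_flax.py expects.
```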
67 changes: 34 additions & 33 deletions src/maxdiffusion/trainers/wan_trainer.py
@@ -255,7 +255,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
eval_data_iterator = self.load_dataset(mesh, is_training=False)
eval_rng = jax.random.key(self.config.seed + step)
eval_metrics = []
# Loop indefinitely until the iterator is exhausted
# Loop indefinitely until the iterator is exhausted
while True:
try:
with mesh:
@@ -329,6 +329,7 @@ def loss_fn(params):
metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
return new_state, scheduler_state, metrics, new_rng


def eval_step(state, data, rng, scheduler_state, scheduler, config):
"""
Computes the evaluation loss for a single batch without updating model weights.
Expand All @@ -338,44 +339,44 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
# This ensures the batch size is consistent, though it might be redundant
# if the evaluation dataloader is already configured correctly.
for k, v in data.items():
data[k] = v[: config.global_batch_size_to_train_on, :]
data[k] = v[: config.global_batch_size_to_train_on, :]

# The loss function logic is identical to training. We are evaluating the model's
# ability to perform its core training objective (e.g., denoising).
def loss_fn(params):
# Reconstruct the model from its definition and parameters
model = nnx.merge(state.graphdef, params, state.rest_of_state)

# Prepare inputs
latents = data["latents"].astype(config.weights_dtype)
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
bsz = latents.shape[0]

# Sample random timesteps and noise, just as in a training step
timesteps = jax.random.randint(
timestep_rng,
(bsz,),
0,
scheduler.config.num_train_timesteps,
)
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

# Get the model's prediction
model_pred = model(
hidden_states=noisy_latents,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
)
# Reconstruct the model from its definition and parameters
model = nnx.merge(state.graphdef, params, state.rest_of_state)

# Prepare inputs
latents = data["latents"].astype(config.weights_dtype)
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
bsz = latents.shape[0]

# Sample random timesteps and noise, just as in a training step
timesteps = jax.random.randint(
timestep_rng,
(bsz,),
0,
scheduler.config.num_train_timesteps,
)
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

# Calculate the loss against the target
training_target = scheduler.training_target(latents, noise, timesteps)
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
loss = (training_target - model_pred) ** 2
loss = loss * training_weight
loss = jnp.mean(loss)
# Get the model's prediction
model_pred = model(
hidden_states=noisy_latents,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
)

return loss
# Calculate the loss against the target
training_target = scheduler.training_target(latents, noise, timesteps)
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
loss = (training_target - model_pred) ** 2
loss = loss * training_weight
loss = jnp.mean(loss)

return loss

# --- Key Difference from train_step ---
# Directly compute the loss without calculating gradients.
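As the comments above note, eval_step reuses the training loss but never computes gradients or updates parameters. A schematic contrast under assumed signatures (the real train_step is elsewhere in wan_trainer.py; this is not its code):

```python
import jax

def train_step_outline(state, loss_fn):
    # Training: compute loss and gradients, then update the parameters.
    # Assumes a Flax-style train state with an apply_gradients method.
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

def eval_step_outline(state, loss_fn):
    # Evaluation: the same loss_fn, but no gradients and no parameter update.
    return loss_fn(state.params)
```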