diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1a7a883c..2756ac23 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,12 +8,14 @@
# 'pre-commit run --all'
repos:
-- repo: https://github.com/astral-sh/ruff-pre-commit
- # Ruff version.
- rev: v0.14.1
- hooks:
- # Run the linter.
- - id: ruff-check
- args: [ --fix ]
- # Run the formatter.
- - id: ruff-format
\ No newline at end of file
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ # Ruff version.
+ rev: v0.14.1
+ hooks:
+ # Run the linter.
+ - id: ruff-check
+ args: [--fix]
+ exclude: bonsai/models/wan2/tests
+ # Run the formatter.
+ - id: ruff-format
+ exclude: bonsai/models/wan2/tests
diff --git a/bonsai/models/umt5/README.md b/bonsai/models/umt5/README.md
new file mode 100644
index 00000000..6e97832e
--- /dev/null
+++ b/bonsai/models/umt5/README.md
@@ -0,0 +1,30 @@
+# UMT5 in JAX
+
+This directory contains a pure JAX implementation of the [UMT5 model](https://arxiv.org/abs/2304.09151), using the [Flax NNX](https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html) API.
+
+
+## Model Configuration Support Status
+
+| Model Name | Config Support Status |
+| :--- | :--- |
+| **Dense Models** | |
+| [umt5-small](https://huggingface.co/google/umt5-small) | **✅ Supported** |
+| [umt5-base](https://huggingface.co/google/umt5-base) | **✅ Supported** |
+| [umt5-xl](https://huggingface.co/google/umt5-xl) | **✅ Supported** |
+| [umt5-xxl](https://huggingface.co/google/umt5-xxl) | **✅ Supported** |
+
+
+### Running this model
+
+See UMT5 in action, implemented in roughly 700 [lines of code](modeling.py) in JAX.
+
+```sh
+python3 -m bonsai.models.umt5.tests.run_model
+```
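+
+You can also load the checkpoint and call the model directly. The snippet below is a minimal sketch that mirrors [tests/run_model.py](tests/run_model.py); it assumes the HuggingFace checkpoint `google/umt5-base` and a tokenizer from `transformers`.
+
+```python
+import jax.numpy as jnp
+from flax import nnx
+from huggingface_hub import snapshot_download
+from transformers import AutoTokenizer
+
+from bonsai.models.umt5.modeling import UMT5Model, forward
+from bonsai.models.umt5.params import create_model, load_model_config
+
+ckpt = snapshot_download("google/umt5-base")
+cfg = load_model_config(ckpt)
+model = create_model(UMT5Model, file_dir=ckpt, cfg=cfg)
+graphdef, state = nnx.split(model)
+
+tok = AutoTokenizer.from_pretrained("google/umt5-base")
+batch = tok(["translate to French: I love cat"], return_tensors="np")
+decoder_ids = jnp.full((1, 1), cfg.decoder_start_token_id, dtype=jnp.int32)
+hidden = forward(
+    graphdef,
+    state,
+    input_ids=jnp.array(batch.input_ids),
+    attention_mask=jnp.array(batch.attention_mask),
+    decoder_input_ids=decoder_ids,
+)
+print(hidden.shape)  # (batch, decoder_len, d_model)
+```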
+
+
+## How to contribute to this model
+
+We welcome contributions! You can contribute to this model via the following:
+* Add a model config variant marked `🟡 Not started` in the table above to `class UMT5Config` in [modeling.py](modeling.py). Make sure your code runs on at least one hardware platform before creating a PR.
+* Got some hardware? Run [run_model.py](tests/run_model.py) with the existing configs above on hardware marked `❔ Needs check`, then mark them as `✅ Runs` or `⛔️ Not supported`.
diff --git a/bonsai/models/umt5/modeling.py b/bonsai/models/umt5/modeling.py
new file mode 100644
index 00000000..0082a281
--- /dev/null
+++ b/bonsai/models/umt5/modeling.py
@@ -0,0 +1,734 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import dataclasses
+import math
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.typing import DTypeLike
+from jax.lax import Precision
+
+ACT_FN = {
+ "gelu": nnx.gelu,
+ "relu": nnx.relu,
+}
+
+
+def fp16_clamp(x: jax.Array):
+    # Clamp fp16 activations away from the +/-inf overflow range. The clamp is applied
+    # unconditionally for fp16 inputs so the function stays traceable under jax.jit.
+    if x.dtype == jnp.float16:
+        clamp = jnp.finfo(x.dtype).max - 1000
+        x = jax.lax.clamp(min=-clamp, x=x, max=clamp)
+    return x
+
+
+@dataclasses.dataclass
+class UMT5Config:
+ """Configuration for UMT5 model."""
+
+    vocab_size: int = 250112
+    d_model: int = 512
+    d_kv: int = 64
+    d_ff: int = 1024
+    num_layers: int = 8
+    num_decoder_layers: Optional[int] = None
+    num_heads: int = 6
+    relative_attention_num_buckets: int = 32
+    relative_attention_max_distance: int = 128
+    dropout_rate: float = 0.1
+    layer_norm_epsilon: float = 1e-6
+    initializer_factor: float = 1.0
+    feed_forward_proj: str = "gated-gelu"
+    is_encoder_decoder: bool = True
+    use_cache: bool = True
+    tokenizer_class: str = "T5Tokenizer"
+    tie_word_embeddings: bool = True
+    pad_token_id: int = 0
+    eos_token_id: int = 1
+    decoder_start_token_id: int = 0
+    is_decoder: bool = False
+    dtype: DTypeLike = jnp.float32
+
+ def __post_init__(self):
+ self.num_decoder_layers = (
+ self.num_decoder_layers if self.num_decoder_layers is not None else self.num_layers
+ ) # default = symmetry
+
+ act_info = self.feed_forward_proj.split("-")
+ self.dense_act_fn = act_info[-1]
+ self.is_gated_act = act_info[0] == "gated"
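+        # e.g. "gated-gelu" -> dense_act_fn="gelu", is_gated_act=True; "relu" -> dense_act_fn="relu", is_gated_act=False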
+
+ if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
+ raise ValueError(
+ f"`feed_forward_proj`: {self.feed_forward_proj} is not a valid activation function of the dense layer. "
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+ "'gated-gelu' or 'relu'"
+ )
+
+ if self.dense_act_fn not in ACT_FN:
+ raise ValueError(
+ f"`feed_forward_proj`: {self.feed_forward_proj} is not a valid activation function of the dense layer. "
+ f"Supported activation functions are: {', '.join(ACT_FN.keys())}"
+ )
+
+
+class T5LayerNorm(nnx.Module):
+ def __init__(
+ self,
+ dim: int,
+ *,
+ eps=1e-6,
+ param_dtype: jnp.dtype | None = jnp.float32,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+        self.scale = nnx.Param(jnp.ones(dim, dtype=param_dtype))
+
+ def __call__(self, hidden_states: jax.Array):
+ # RMS normalization: hidden_states / sqrt(mean(hidden_states^2))
+ variance = jnp.mean(hidden_states.astype(jnp.float32) ** 2, axis=-1, keepdims=True)
+ hidden_states = hidden_states * jax.lax.rsqrt(variance + self.eps)
+ weight_dtype = self.scale.get_value().dtype
+ if weight_dtype in [jnp.float16, jnp.bfloat16]:
+ hidden_states = hidden_states.astype(weight_dtype)
+ return self.scale.get_value() * hidden_states
+
+
+class UMT5DenseActDense(nnx.Module):
+ def __init__(self, config: UMT5Config, *, param_dtype: jnp.dtype | None = jnp.float32, rngs: nnx.Rngs):
+ super().__init__()
+ self.param_dtype = param_dtype
+ self.wi = nnx.Linear(
+ config.d_model, config.d_ff, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs
+ )
+ self.wo = nnx.Linear(
+ config.d_ff, config.d_model, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs
+ )
+ self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
+ self.act = ACT_FN[config.dense_act_fn]
+
+ def __call__(self, hidden_states):
+ hidden_states = self.wi(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+class UMT5DenseGatedActDense(nnx.Module):
+ def __init__(self, config: UMT5Config, *, param_dtype: jnp.dtype | None = jnp.float32, rngs: nnx.Rngs):
+ super().__init__()
+ self.param_dtype = param_dtype
+ self.wi_0 = nnx.Linear(
+ config.d_model, config.d_ff, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs
+ )
+ self.wi_1 = nnx.Linear(
+ config.d_model, config.d_ff, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs
+ )
+ self.wo = nnx.Linear(
+ config.d_ff, config.d_model, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs
+ )
+ self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
+ self.act = ACT_FN[config.dense_act_fn]
+
+ def __call__(self, hidden_states: jax.Array):
+ hidden_gelu = self.act(self.wi_0(hidden_states))
+ hidden_linear = self.wi_1(hidden_states)
+ hidden_states = hidden_gelu * hidden_linear
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+class UMT5LayerFF(nnx.Module):
+ def __init__(self, config: UMT5Config, *, param_dtype: jnp.dtype | None = jnp.float32, rngs: nnx.Rngs):
+ super().__init__()
+ if config.is_gated_act:
+ self.DenseReluDense = UMT5DenseGatedActDense(config, param_dtype=param_dtype, rngs=rngs)
+ else:
+ self.DenseReluDense = UMT5DenseActDense(config, param_dtype=param_dtype, rngs=rngs)
+
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, param_dtype=param_dtype)
+ self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
+
+ def __call__(self, hidden_states: jax.Array):
+ forwarded_states = self.layer_norm(hidden_states)
+ forwarded_states = self.DenseReluDense(forwarded_states)
+ hidden_states = hidden_states + self.dropout(forwarded_states)
+ return hidden_states
+
+
+class UMT5Attention(nnx.Module):
+ """
+ T5's attention using relative_attention_bias.
+ """
+
+ def __init__(
+ self,
+ config: UMT5Config,
+ *,
+ has_relative_attention_bias=False,
+ layer_idx: Optional[int] = None,
+ param_dtype: jnp.dtype | None = jnp.float32,
+ rngs: nnx.Rngs,
+ ):
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
+ self.relative_attention_max_distance = config.relative_attention_max_distance
+ self.d_model = config.d_model
+ self.key_value_proj_dim = config.d_kv
+ self.n_heads = config.num_heads
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ self.q = nnx.Linear(
+ self.d_model,
+ self.inner_dim,
+ precision=Precision.HIGHEST,
+ param_dtype=param_dtype,
+ use_bias=False,
+ rngs=rngs,
+ )
+ self.k = nnx.Linear(
+ self.d_model,
+ self.inner_dim,
+ precision=Precision.HIGHEST,
+ param_dtype=param_dtype,
+ use_bias=False,
+ rngs=rngs,
+ )
+ self.v = nnx.Linear(
+ self.d_model,
+ self.inner_dim,
+ precision=Precision.HIGHEST,
+ param_dtype=param_dtype,
+ use_bias=False,
+ rngs=rngs,
+ )
+ self.o = nnx.Linear(
+ self.inner_dim,
+ self.d_model,
+ precision=Precision.HIGHEST,
+ param_dtype=param_dtype,
+ use_bias=False,
+ rngs=rngs,
+ )
+
+ self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nnx.Embed(
+ self.relative_attention_num_buckets, self.n_heads, param_dtype=param_dtype, rngs=rngs
+ )
+
+ def _relative_position_bucket(self, rel_pos):
+ """Convert relative positions to bucket indices."""
+ if not self.is_decoder:
+ num_buckets = self.relative_attention_num_buckets // 2
+ rel_buckets = (rel_pos > 0).astype(jnp.int32) * num_buckets
+ rel_pos = jnp.abs(rel_pos)
+ else:
+ num_buckets = self.relative_attention_num_buckets
+ rel_buckets = 0
+ rel_pos = -jnp.minimum(rel_pos, jnp.zeros_like(rel_pos))
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = rel_pos < max_exact
+
+ # Logarithmic bucketing for large positions
+ rel_pos_large = max_exact + (
+ jnp.log(rel_pos.astype(jnp.float32) / max_exact)
+ / math.log(self.relative_attention_max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).astype(jnp.int32)
+ rel_pos_large = jnp.minimum(rel_pos_large, num_buckets - 1)
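+        # Worked example (encoder, num_buckets=32 -> 16 per direction, max_distance=128):
+        #   rel_pos = +3  -> 16 (forward offset) + 3 = bucket 19 (exact region, |rel_pos| < 8)
+        #   rel_pos = -70 -> 0 + 8 + floor(log(70/8) / log(128/8) * 8) = bucket 14 (log region)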
+
+ rel_buckets = rel_buckets + jnp.where(is_small, rel_pos, rel_pos_large)
+ return rel_buckets
+
+ def compute_bias(self, q_len, k_len):
+ """Compute binned relative position bias"""
+ ctx_pos = jnp.arange(q_len, dtype=jnp.int32)[:, None]
+ mem_pos = jnp.arange(k_len, dtype=jnp.int32)[None, :]
+ rel_pos = mem_pos - ctx_pos # shape (query_length, key_length)
+ rel_pos_bkt = self._relative_position_bucket(rel_pos)
+ values = self.relative_attention_bias(rel_pos_bkt) # shape (query_length, key_length, num_heads)
+ # shape (1, num_heads, query_length, key_length)
+ values = values.transpose(2, 0, 1)[None, :, :, :]
+ return values
+
+ def __call__(
+ self,
+ hidden_states: jax.Array,
+ encoder_hidden_states: Optional[jax.Array] = None,
+ attention_mask: Optional[jax.Array] = None,
+ ):
+ b, n, c = hidden_states.shape[0], self.n_heads, self.key_value_proj_dim
+
+ # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder
+ is_cross_attention = encoder_hidden_states is not None
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
+
+ q = self.q(hidden_states).reshape(b, -1, n, c)
+ k = self.k(current_states).reshape(b, -1, n, c)
+ v = self.v(current_states).reshape(b, -1, n, c)
+
+ # Attention bias
+ q_len, k_len = q.shape[1], k.shape[1]
+ if not self.has_relative_attention_bias:
+ position_bias = jnp.zeros((1, n, q_len, k_len), dtype=q.dtype)
+ else:
+ position_bias = self.compute_bias(q_len, k_len)
+ position_bias = position_bias[:, :, -q_len:, :]
+
+ if attention_mask is not None:
+ position_bias = position_bias + attention_mask
+
+ attn = (
+ jnp.einsum(
+ "binc,bjnc->bnij",
+ q,
+ k,
+ precision=Precision.HIGHEST,
+ )
+ + position_bias
+ )
+
+ attn = jax.nn.softmax(attn.astype(jnp.float32), axis=-1).astype(attn.dtype)
+
+ attn = self.dropout(attn)
+
+ o_attn = jnp.einsum(
+ "bnij,bjnc->binc",
+ attn,
+ v,
+ precision=Precision.HIGHEST,
+ )
+
+ o_attn = o_attn.reshape(b, -1, n * c)
+ o_attn = self.o(o_attn)
+ o_attn = self.dropout(o_attn)
+ return o_attn
+
+
+class UMT5LayerSelfAttention(nnx.Module):
+ def __init__(
+ self,
+ config: UMT5Config,
+ *,
+ layer_idx: Optional[int] = None,
+ param_dtype: jnp.dtype | None = jnp.float32,
+ rngs: nnx.Rngs,
+ ):
+ super().__init__()
+ self.SelfAttention = UMT5Attention(
+ config,
+ has_relative_attention_bias=True,
+ layer_idx=layer_idx,
+ param_dtype=param_dtype,
+ rngs=rngs,
+ )
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, param_dtype=param_dtype)
+
+ def __call__(
+ self,
+ hidden_states: jax.Array,
+ attention_mask=None,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.SelfAttention(
+ normed_hidden_states,
+ attention_mask=attention_mask,
+ )
+ outputs = hidden_states + attention_output
+ return outputs
+
+
+class UMT5LayerCrossAttention(nnx.Module):
+ def __init__(
+ self,
+ config: UMT5Config,
+ *,
+ layer_idx: Optional[int] = None,
+ param_dtype: jnp.dtype | None = jnp.float32,
+ rngs: nnx.Rngs,
+ ) -> None:
+ super().__init__()
+ self.EncDecAttention = UMT5Attention(
+ config, has_relative_attention_bias=False, layer_idx=layer_idx, param_dtype=param_dtype, rngs=rngs
+ )
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, param_dtype=param_dtype)
+ self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
+
+ def __call__(
+ self,
+ hidden_states: jax.Array,
+ encoder_hidden_states: jax.Array = None,
+ attention_mask: jax.Array = None,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.EncDecAttention(
+ normed_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ )
+ return hidden_states + self.dropout(attention_output)
+
+
+class UMT5Block(nnx.Module):
+ def __init__(
+ self,
+ config: UMT5Config,
+ *,
+ layer_idx: Optional[int] = None,
+ param_dtype: jnp.dtype | None = jnp.float32,
+ rngs: nnx.Rngs,
+ ):
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ self.layer = nnx.List()
+ self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx, param_dtype=param_dtype, rngs=rngs))
+ if self.is_decoder:
+ self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx, param_dtype=param_dtype, rngs=rngs))
+
+ self.layer.append(UMT5LayerFF(config, param_dtype=param_dtype, rngs=rngs))
+
+ def __call__(
+ self,
+ hidden_states: jax.Array,
+ attention_mask: jax.Array = None,
+ encoder_hidden_states: jax.Array = None,
+ encoder_attention_mask: jax.Array = None,
+ ):
+ # Apply self-attention layer
+ hidden_states = fp16_clamp(
+ self.layer[0](
+ hidden_states,
+ attention_mask=attention_mask,
+ )
+ )
+ # Cross-Attention Block
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
+ if do_cross_attention:
+ hidden_states = fp16_clamp(
+ self.layer[1](
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ )
+ )
+
+ # Apply Feed Forward layer
+ hidden_states = fp16_clamp(self.layer[-1](hidden_states))
+
+ return (hidden_states,)
+
+
+class UMT5Stack(nnx.Module):
+ def __init__(self, config: UMT5Config, *, param_dtype: jnp.dtype | None = jnp.float32, rngs: nnx.Rngs):
+ super().__init__()
+ self.embed_tokens = nnx.Embed(config.vocab_size, config.d_model, param_dtype=param_dtype, rngs=rngs)
+ self.is_decoder = config.is_decoder
+ self.block = nnx.List(
+ [UMT5Block(config, layer_idx=i, param_dtype=param_dtype, rngs=rngs) for i in range(config.num_layers)]
+ )
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, param_dtype=param_dtype)
+ self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs)
+
+ def _prepare_4d_causal_attention_mask_for_decoder(
+ self,
+ attention_mask: jax.Array,
+ batch_size: int,
+ q_len: int,
+ dtype: DTypeLike,
+ ):
+ if attention_mask is None:
+ causal_mask = jnp.tril(jnp.ones((batch_size, 1, q_len, q_len)))
+ elif attention_mask.ndim == 4:
+ causal_mask = attention_mask
+ elif attention_mask.ndim == 2:
+ # shape of causal_mask is (batch_size, q_len), expand to (batch_size, 1, q_len, q_len)
+ causal_mask = jnp.tril(attention_mask[:, None, None, :].repeat(q_len, axis=2))
+ else:
+ raise ValueError(f"Invalid attention mask ndim, expected ndim: 2, actual ndim: {attention_mask.ndim}")
+ causal_mask = (1.0 - causal_mask) * jnp.finfo(dtype).min
+ return causal_mask
+
+ def _prepare_padding_mask(
+ self,
+ padding_mask: jax.Array,
+ dtype: DTypeLike,
+ ):
+ """
+ For decoder, padding mask is encoder attention mask. For encoder, padding mask is input mask from tokenizer.
+ """
+ assert padding_mask.ndim in [2, 3], f"Invalid padding mask ndim: {padding_mask.ndim}"
+
+ # expand dim if needed
+ if padding_mask.ndim == 3:
+ # shape of encoder_attention_mask is (batch_size, seq_len, seq_len), expand to (batch_size, 1, seq_len, seq_len)
+ padding_mask = padding_mask[:, None, :, :]
+ elif padding_mask.ndim == 2:
+ # shape of encoder_attention_mask is (batch_size, seq_len), expand to (batch_size, 1, 1, seq_len)
+ padding_mask = padding_mask[:, None, None, :]
+
+ # convert to additive biases
+ padding_mask = (1.0 - padding_mask) * jnp.finfo(dtype).min
+ return padding_mask
+
+ def _prepare_attention_mask(
+ self,
+ input_ids,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ dtype: DTypeLike,
+ ):
+ """
+ Prepare attention mask. Only SelfAttention mask is needed in encoder. Both SelfAttention mask and CrossAttention mask are needed in decoder.
+ Args:
+ input_ids: Indices of input sequence tokens in the vocabulary. shape: (batch_size, seq_len).
+ attention_mask: causal mask of input_ids.
+ encoder_hidden_states: The last hidden states of encoder output.
+ encoder_attention_mask: The causal mask of encoder_hidden_states
+ """
+ if self.is_decoder:
+ b, s = input_ids.shape
+ # prepare self-attention causal mask for decoder
+ causal_mask = self._prepare_4d_causal_attention_mask_for_decoder(attention_mask, b, s, dtype)
+ # prepare cross-attention causal mask for decoder
+ if encoder_hidden_states is not None:
+ b, s, _ = encoder_hidden_states.shape
+ # new mask if not provided
+ if encoder_attention_mask is None:
+ encoder_attention_mask = jnp.ones((b, s), dtype=jnp.int32)
+ encoder_attention_mask = self._prepare_padding_mask(encoder_attention_mask, dtype)
+ else:
+ encoder_attention_mask = None
+ elif attention_mask is not None:
+ # prepare padding mask for encoder
+ causal_mask = self._prepare_padding_mask(attention_mask, dtype)
+ else:
+ causal_mask = None
+
+ return causal_mask, encoder_attention_mask
+
+ def __call__(
+ self,
+ input_ids: jax.Array = None,
+ attention_mask: jax.Array = None,
+ encoder_hidden_states: jax.Array = None,
+ encoder_attention_mask: jax.Array = None,
+ ):
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = self.dropout(inputs_embeds)
+
+ # prepare attention mask for encoder and decoder
+ causal_mask, encoder_attention_mask = self._prepare_attention_mask(
+ input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds.dtype
+ )
+
+ for _, layer_module in enumerate(self.block):
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=causal_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ hidden_states = layer_outputs[0]
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class UMT5EncoderModel(nnx.Module):
+ def __init__(
+ self,
+ config: UMT5Config,
+ *,
+ param_dtype: jnp.dtype | None = jnp.float32,
+ rngs: nnx.Rngs,
+ ):
+ super().__init__()
+ config.is_decoder = False
+ self.encoder = UMT5Stack(config, param_dtype=param_dtype, rngs=rngs)
+
+ def __call__(
+ self,
+ input_ids: jax.Array = None,
+ attention_mask: jax.Array = None,
+    ) -> jax.Array:
+ r"""
+ input_ids (`jax.Array` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so you
+ should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`AutoTokenizer`].
+
+            To know more on how to prepare `input_ids` for pretraining take a look at [UMT5 Training](./umt5#training).
+        """
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )
+
+ return encoder_outputs
+
+
+class UMT5Model(nnx.Module):
+ def __init__(
+ self,
+ config: UMT5Config,
+ *,
+ param_dtype: jnp.dtype | None = jnp.float32,
+ rngs: nnx.Rngs,
+ ):
+ super().__init__()
+ self.config = config
+ self.model_dim = config.d_model
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.is_decoder = False
+ self.encoder = UMT5Stack(encoder_config, param_dtype=param_dtype, rngs=rngs)
+
+ decoder_config = copy.deepcopy(config)
+ decoder_config.is_decoder = True
+ decoder_config.num_layers = config.num_decoder_layers
+ self.decoder = UMT5Stack(decoder_config, param_dtype=param_dtype, rngs=rngs)
+
+ self.lm_head = nnx.Linear(
+ config.d_model,
+ config.vocab_size,
+ use_bias=False,
+ precision=Precision.HIGHEST,
+ param_dtype=param_dtype,
+ rngs=rngs,
+ )
+
+ def __call__(
+ self,
+ input_ids: Optional[jax.Array] = None,
+ attention_mask: Optional[jax.Array] = None,
+ decoder_input_ids: Optional[jax.Array] = None,
+ decoder_attention_mask: Optional[jax.Array] = None,
+ encoder_outputs: Optional[jax.Array] = None,
+ ) -> jax.Array:
+ # Encode if needed (training, first prediction pass)
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )
+
+ hidden_states = encoder_outputs
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ )
+
+ return decoder_outputs
+
+ def generate(
+ self,
+ input_ids: jax.Array,
+ attention_mask: jax.Array = None,
+ max_tokens: Optional[int] = None,
+ max_new_tokens: Optional[int] = None,
+ ) -> jax.Array:
+ """Generate sequences using greedy decoding.
+
+ Args:
+ input_ids: Encoder input ids from tokenizer, shape (batch_size, seq_length)
+ attention_mask: Encoder attention mask, shape (batch_size, seq_length)
+ max_tokens: Maximum total length of decoder sequence (including start token).
+ Takes precedence over max_new_tokens if both are provided.
+ max_new_tokens: Maximum number of new tokens to generate (excluding start token)
+
+ Returns:
+ Generated token ids, shape (batch_size, generated_length)
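+
+        Example (assuming `input_ids` and `attention_mask` come from a HuggingFace tokenizer):
+            out_ids = model.generate(input_ids, attention_mask, max_new_tokens=32)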
+ """
+ # Determine maximum generation length
+ if max_tokens is not None:
+ max_length = max_tokens
+ elif max_new_tokens is not None:
+ max_length = max_new_tokens + 1 # +1 for decoder_start_token
+ else:
+ max_length = 512 # default value
+
+ # Encode input
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )
+
+ # Initialize decoder input with start token
+ batch_size = input_ids.shape[0]
+ decoder_input_ids = jnp.full((batch_size, 1), self.config.decoder_start_token_id, dtype=jnp.int32)
+
+ # Autoregressive generation loop
+ for _ in range(max_length - 1):
+ # Decoder forward pass
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ encoder_hidden_states=encoder_outputs,
+ )
+
+ # Get logits and select next token (greedy)
+ logits = self.lm_head(decoder_outputs)
+            # simple greedy decoding is used here; beam search would generally give better quality
+ next_token = jnp.argmax(logits[:, -1, :], axis=-1, keepdims=True)
+
+ # Append to decoder input
+ decoder_input_ids = jnp.concatenate([decoder_input_ids, next_token], axis=1)
+
+ # Stop if all sequences generated EOS
+ if jnp.all(next_token == self.config.eos_token_id):
+ break
+
+ return decoder_input_ids
+
+
+@jax.jit
+def forward(
+ graphdef: nnx.GraphDef,
+ state: nnx.State,
+ *,
+ input_ids: Optional[jax.Array] = None,
+ attention_mask: Optional[jax.Array] = None,
+ decoder_input_ids: Optional[jax.Array] = None,
+ decoder_attention_mask: Optional[jax.Array] = None,
+ encoder_outputs: Optional[jax.Array] = None,
+) -> jax.Array:
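+    """Jitted forward pass; obtain `graphdef, state` via `nnx.split(model)` (see tests/run_model.py)."""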
+ model = nnx.merge(graphdef, state)
+ return model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ )
+
+
+__all__ = ["UMT5EncoderModel", "UMT5Model"]
diff --git a/bonsai/models/umt5/params.py b/bonsai/models/umt5/params.py
new file mode 100644
index 00000000..1df18582
--- /dev/null
+++ b/bonsai/models/umt5/params.py
@@ -0,0 +1,395 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import gc
+import json
+import logging
+import re
+from enum import Enum, auto
+
+import jax
+import jax.numpy as jnp
+import safetensors
+import torch
+from etils import epath
+from flax import nnx
+
+from bonsai.models.umt5 import modeling as model_lib
+
+
+def _get_key_and_transform_mapping(cls, cfg: model_lib.UMT5Config):
+ """Define mapping from HuggingFace UMT5 keys to JAX UMT5 keys."""
+
+ class Transform(Enum):
+ """Transformations for UMT5 parameters"""
+
+ NONE = None
+ # For linear layers: (out, in) -> (in, out)
+ TRANSPOSE = ((1, 0), None, False)
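+        # Each non-None value is (permutation, reshape, reshape_first), consumed by _assign_weights().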
+
+ # T5/UMT5 uses standard HuggingFace naming
+ encoder_mapping = {
+ r"encoder\.embed_tokens\.weight": ("encoder.embed_tokens.embedding", Transform.NONE),
+ # Shared token embeddings
+ r"shared\.weight": ("encoder.embed_tokens.embedding", Transform.NONE),
+ # Encoder
+ # Encoder blocks - Self attention
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.q\.weight": (
+ r"encoder.block.\1.layer.0.SelfAttention.q.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.k\.weight": (
+ r"encoder.block.\1.layer.0.SelfAttention.k.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.v\.weight": (
+ r"encoder.block.\1.layer.0.SelfAttention.v.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.o\.weight": (
+ r"encoder.block.\1.layer.0.SelfAttention.o.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.relative_attention_bias\.weight": (
+ r"encoder.block.\1.layer.0.SelfAttention.relative_attention_bias.embedding",
+ Transform.NONE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.layer_norm\.weight": (
+ r"encoder.block.\1.layer.0.layer_norm.scale",
+ Transform.NONE,
+ ),
+ # Encoder blocks - Feed forward
+ r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wi_0\.weight": (
+ r"encoder.block.\1.layer.1.DenseReluDense.wi_0.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wi_1\.weight": (
+ r"encoder.block.\1.layer.1.DenseReluDense.wi_1.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wo\.weight": (
+ r"encoder.block.\1.layer.1.DenseReluDense.wo.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.1\.layer_norm\.weight": (
+ r"encoder.block.\1.layer.1.layer_norm.scale",
+ Transform.NONE,
+ ),
+ # Encoder Final layer norm
+ r"encoder\.final_layer_norm\.weight": ("encoder.final_layer_norm.scale", Transform.NONE),
+ }
+ decoder_mapping = {
+ # Decoder
+ # Decoder embedding
+ r"decoder\.embed_tokens\.weight": ("decoder.embed_tokens.embedding", Transform.NONE),
+ # Decoder blocks - Self attention
+ r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.q\.weight": (
+ r"decoder.block.\1.layer.0.SelfAttention.q.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.k\.weight": (
+ r"decoder.block.\1.layer.0.SelfAttention.k.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.v\.weight": (
+ r"decoder.block.\1.layer.0.SelfAttention.v.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.o\.weight": (
+ r"decoder.block.\1.layer.0.SelfAttention.o.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.relative_attention_bias\.weight": (
+ r"decoder.block.\1.layer.0.SelfAttention.relative_attention_bias.embedding",
+ Transform.NONE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.0\.layer_norm\.weight": (
+ r"decoder.block.\1.layer.0.layer_norm.scale",
+ Transform.NONE,
+ ),
+ # Decoder blocks - Cross attention
+ r"decoder\.block\.([0-9]+)\.layer\.1\.EncDecAttention\.q\.weight": (
+ r"decoder.block.\1.layer.1.EncDecAttention.q.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.1\.EncDecAttention\.k\.weight": (
+ r"decoder.block.\1.layer.1.EncDecAttention.k.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.1\.EncDecAttention\.v\.weight": (
+ r"decoder.block.\1.layer.1.EncDecAttention.v.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.1\.EncDecAttention\.o\.weight": (
+ r"decoder.block.\1.layer.1.EncDecAttention.o.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.1\.layer_norm\.weight": (
+ r"decoder.block.\1.layer.1.layer_norm.scale",
+ Transform.NONE,
+ ),
+ # Decoder blocks - Feed forward
+ r"decoder\.block\.([0-9]+)\.layer\.2\.DenseReluDense\.wi_0\.weight": (
+ r"decoder.block.\1.layer.2.DenseReluDense.wi_0.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.2\.DenseReluDense\.wi_1\.weight": (
+ r"decoder.block.\1.layer.2.DenseReluDense.wi_1.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.2\.DenseReluDense\.wo\.weight": (
+ r"decoder.block.\1.layer.2.DenseReluDense.wo.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"decoder\.block\.([0-9]+)\.layer\.2\.layer_norm\.weight": (
+ r"decoder.block.\1.layer.2.layer_norm.scale",
+ Transform.NONE,
+ ),
+ # Decoder Final layer norm
+ r"decoder\.final_layer_norm\.weight": ("decoder.final_layer_norm.scale", Transform.NONE),
+ # lm head
+ r"lm_head\.weight": ("lm_head.kernel", Transform.TRANSPOSE),
+ }
+
+ if cls == model_lib.UMT5EncoderModel:
+ return encoder_mapping
+
+ full_mapping = encoder_mapping.copy()
+ full_mapping.update(decoder_mapping)
+ return full_mapping
+
+
+def _torch_key_to_jax_key(mapping, source_key):
+ subs = [
+ (re.sub(pat, repl, source_key), reshape)
+ for pat, (repl, reshape) in mapping.items()
+ if re.match(pat, source_key)
+ ]
+ if len(subs) > 1:
+ raise ValueError(f"Only one key should be found: {subs[0]}")
+ if len(subs) == 0:
+ return (None, None)
+ return subs[0]
+
+
+def _assign_weights(keys, tensor, state_dict, st_key, transform, sharding_dict, dtype):
+ """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor."""
+ key, *rest = keys
+ if not rest:
+ if transform is not None:
+ permute, reshape, reshape_first = transform
+ if reshape_first and reshape is not None:
+ tensor = tensor.reshape(reshape)
+ if permute:
+ tensor = tensor.transpose(permute)
+ if not reshape_first and reshape is not None:
+ tensor = tensor.reshape(reshape)
+ if tensor.shape != state_dict[key].shape:
+ raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}")
+ # Only apply sharding if sharding_dict is provided
+ if sharding_dict is not None:
+ state_dict[key] = jax.device_put(tensor, sharding_dict[key]).astype(dtype)
+ else:
+ state_dict[key] = jax.device_put(tensor).astype(dtype)
+ else:
+ next_sharding = sharding_dict[key] if sharding_dict is not None else None
+ _assign_weights(rest, tensor, state_dict[key], st_key, transform, next_sharding, dtype)
+
+
+def _stoi(s):
+ try:
+ return int(s)
+ except ValueError:
+ return s
+
+
+class WeightFileType(Enum):
+ ST = auto() # safetensor
+ BIN = auto()
+ PT = auto()
+ PTH = auto()
+
+
+def search_available_weight_file(file_path):
+ p = epath.Path(file_path).expanduser()
+ if p.is_file():
+ # if file_path is a file, return [p], type
+ file_ext = p.suffix.lower()
+ if file_ext == ".safetensors":
+ print(f"Using {p} to load weight")
+ return [p], WeightFileType.ST
+ elif file_ext == ".bin":
+ print(f"Using {p} to load weight")
+ return [p], WeightFileType.BIN
+ elif file_ext == ".pt":
+ print(f"Using {p} to load weight")
+ return [p], WeightFileType.PT
+ elif file_ext == ".pth":
+ print(f"Using {p} to load weight")
+ return [p], WeightFileType.PTH
+ else:
+ raise ValueError(f"Unsupported file extension: {file_ext}")
+ else:
+ # if file_path is dir, return all matching files
+ files = list(p.glob("*.safetensors"))
+ if not files:
+ logging.warning(f"No *.safetensors found in {file_path}, try to search others")
+ else:
+ print("Using *.safetensors to load weight")
+ return files, WeightFileType.ST
+
+ files = list(p.glob("*.bin"))
+ if not files:
+ logging.warning(f"No *.bin found in {file_path}, try to search others")
+ else:
+ print("Using *.bin to load weight")
+ return files, WeightFileType.BIN
+
+ files = list(p.glob("*.pth"))
+ if not files:
+ logging.warning(f"No *.pth found in {file_path}, try to search others")
+ else:
+ print("Using *.pth to load weight")
+ return files, WeightFileType.PTH
+
+ raise ValueError(f"No weight file found in {file_path}")
+
+
+def open_weight_file(f, file_type):
+ if file_type in [WeightFileType.BIN, WeightFileType.PTH]:
+ sf = torch.load(f, map_location="cpu")
+ elif file_type == WeightFileType.ST:
+ sf = safetensors.safe_open(f, framework="numpy")
+ else:
+ raise ValueError(f"invalid file type: {file_type}")
+ return sf
+
+
+def get_tensor(sf, torch_key, file_type):
+ if file_type in [WeightFileType.BIN, WeightFileType.PTH]:
+ tensor = sf[torch_key]
+ elif file_type == WeightFileType.ST:
+ tensor = sf.get_tensor(torch_key)
+ else:
+ raise ValueError(f"invalid file type: {file_type}")
+
+ return tensor
+
+
+def create_model(
+ cls,
+ file_dir: str,
+ cfg: model_lib.UMT5Config,
+ key_mapping=None,
+ param_dtype: jnp.dtype | None = jnp.float32,
+ mesh: jax.sharding.Mesh | None = None,
+) -> model_lib.UMT5Model | model_lib.UMT5EncoderModel:
+ """Load weight and create a UMT5Encoder model (memory-optimized).
+
+ Args:
+ cls: model class. UMT5Model and UMT5EncoderModel is available.
+ file_dir: model weight path.
+ cfg: model config. Use 'load_model_config' to get it.
+ key_mapping: model weight key map. Used in unofficial umt5 model, such as Wan2.1
+ param_dtype: model weight dtype.
+ mesh: model weight mesh.
+ Returns:
+ The instance of model defined in cls.
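+
+    Example (hypothetical local checkpoint directory):
+        cfg = load_model_config("~/checkpoints/umt5-base")
+        model = create_model(model_lib.UMT5EncoderModel, "~/checkpoints/umt5-base", cfg)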
+ """
+ files, file_type = search_available_weight_file(file_dir)
+
+ umt5 = nnx.eval_shape(lambda: cls(cfg, param_dtype=param_dtype, rngs=nnx.Rngs(params=0, dropout=0)))
+ graph_def, abs_state = nnx.split(umt5)
+ state_dict = abs_state.to_pure_dict()
+ # Only use sharding if mesh is provided
+ sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None
+
+ if not key_mapping:
+ key_mapping = _get_key_and_transform_mapping(cls, cfg)
+ conversion_errors = []
+
+ print(f"Loading Weight: {cfg.dtype=}")
+ for f in files:
+ sf = open_weight_file(f, file_type)
+
+ for torch_key in sf.keys():
+ ts = get_tensor(sf, torch_key, file_type)
+            if isinstance(ts, torch.Tensor):
+                npy = ts.numpy() if ts.dtype != torch.bfloat16 else ts.to(dtype=torch.float32).numpy()
+            else:
+                npy = ts  # safetensors opened with framework="numpy" already yields numpy arrays
+
+ jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key)
+ if jax_key is None:
+ continue
+
+ keys = [_stoi(k) for k in jax_key.split(".")]
+ try:
+ _assign_weights(keys, npy, state_dict, torch_key, transform.value, sharding, cfg.dtype)
+ except Exception as e:
+ full_jax_key = ".".join([str(k) for k in keys])
+ conversion_errors.append(f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}")
+ gc.collect()
+
+ if conversion_errors:
+ full_error_log = "\n".join(conversion_errors)
+ raise RuntimeError(f"Encountered {len(conversion_errors)} weight conversion errors. Log:\n{full_error_log}")
+
+ if cls == model_lib.UMT5Model:
+ state_dict["decoder"]["embed_tokens"]["embedding"] = state_dict["encoder"]["embed_tokens"]["embedding"]
+
+ gc.collect()
+ m = nnx.merge(graph_def, state_dict)
+ m.eval()
+ return m
+
+
+def get_weight_dtype_from_config(conf_dict):
+ def get_dtype(dtype_str):
+ if dtype_str in ["float32", "fp32"]:
+ return jnp.float32
+ elif dtype_str in ["bloat16", "bf16"]:
+ return jnp.bfloat16
+ elif dtype_str in ["float16", "fp16"]:
+ return jnp.float16
+ else:
+ logging.warning(f"Unrecognized dtype: {dtype_str}")
+ return jnp.float32
+
+ if "dtype" in conf_dict:
+ return get_dtype(conf_dict["dtype"])
+ elif "torch_dtype" in conf_dict:
+ return get_dtype(conf_dict["torch_dtype"])
+ else:
+ logging.warning("No 'dtype' config found in config file")
+ return jnp.float32
+
+
+def load_model_config(model_path: str) -> model_lib.UMT5Config:
+ """Load the model config from the model path."""
+ model_dir = epath.Path(model_path).expanduser()
+ config_path = model_dir / "config.json"
+
+ if not config_path.exists():
+ raise FileNotFoundError(f"Config file not found at {config_path}")
+
+ with config_path.open("r") as f:
+ config_dict = json.load(f)
+
+ dtype = get_weight_dtype_from_config(config_dict)
+
+ # Filter config_dict to only include fields defined in UMT5Config
+ config_fields = {f.name for f in dataclasses.fields(model_lib.UMT5Config)}
+ filtered_config = {k: v for k, v in config_dict.items() if k in config_fields}
+
+ return model_lib.UMT5Config(**filtered_config, dtype=dtype)
diff --git a/bonsai/models/umt5/tests/__init__.py b/bonsai/models/umt5/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/bonsai/models/umt5/tests/run_model.py b/bonsai/models/umt5/tests/run_model.py
new file mode 100644
index 00000000..84d081a4
--- /dev/null
+++ b/bonsai/models/umt5/tests/run_model.py
@@ -0,0 +1,75 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Demo script to run UMT5 model inference."""
+
+import jax.numpy as jnp
+from flax import nnx
+from huggingface_hub import snapshot_download
+from transformers import AutoTokenizer
+
+from bonsai.models.umt5.modeling import UMT5Model, forward
+from bonsai.models.umt5.params import create_model, load_model_config
+
+
+def main():
+ """Run UMT5 model inference demo."""
+ print("=" * 80)
+ print("UMT5 Model Demo - JAX Implementation")
+ print("=" * 80)
+
+ # Model configuration
+ model_name = "google/umt5-base"
+ model_ckpt_path = snapshot_download(model_name)
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # Load model config and create model
+ model_conf = load_model_config(model_ckpt_path)
+
+ jax_model = create_model(
+ UMT5Model,
+ file_dir=model_ckpt_path,
+ cfg=model_conf,
+ )
+ graphdef, state = nnx.split(jax_model)
+
+ # Prepare input
+ prompts = [
+ "A beautiful sunset over the ocean with waves crashing on the shore",
+ "translate to French: I love cat",
+ ]
+
+ # Tokenize input
+ inputs = tokenizer(prompts, padding=True, return_tensors="np")
+ input_ids = jnp.array(inputs.input_ids)
+ attention_mask = jnp.array(inputs.attention_mask)
+
+ # forward
+ bs = len(prompts)
+ decoder_input_ids = jnp.full((bs, 1), model_conf.decoder_start_token_id, dtype=jnp.int32)
+ decoder_output = forward(
+ graphdef,
+ state,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ )
+ print(f"Decoder output shape: {decoder_output.shape}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/bonsai/models/umt5/tests/test_umt5.py b/bonsai/models/umt5/tests/test_umt5.py
new file mode 100644
index 00000000..5903fa9b
--- /dev/null
+++ b/bonsai/models/umt5/tests/test_umt5.py
@@ -0,0 +1,160 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test accuracy of jax impelement vs torch transformer impelement."""
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import torch
+from absl.testing import absltest
+from huggingface_hub import snapshot_download
+from transformers import AutoTokenizer
+from transformers import UMT5EncoderModel as TorchUMT5EncoderModel
+from transformers import UMT5Model as TorchUMT5Model
+
+from bonsai.models.umt5.modeling import UMT5EncoderModel, UMT5Model
+from bonsai.models.umt5.params import create_model, load_model_config
+
+
+def compare_outputs(jax_output: jax.Array, torch_output, name: str, rtol: float = 1e-3, atol: float = 1e-5):
+ """Compare JAX and PyTorch outputs and report differences.
+ Args:
+ jax_output: Output from JAX model
+ torch_output: Output from PyTorch model (torch.Tensor)
+ name: Name of the output being compared
+ rtol: Relative tolerance
+ atol: Absolute tolerance
+ """
+ if torch_output.dtype == torch.bfloat16:
+ torch_output = torch_output.float()
+ if jax_output.dtype == jnp.bfloat16:
+ jax_output = jax_output.astype(jnp.float32)
+
+ # Convert PyTorch to numpy
+ if isinstance(torch_output, torch.Tensor):
+ torch_np = torch_output.detach().cpu().numpy()
+ else:
+ torch_np = np.array(torch_output)
+
+ # Convert JAX to numpy
+ jax_np = np.array(jax_output)
+
+    # Check shapes and dtypes match
+    if jax_np.shape != torch_np.shape:
+        print(f"❌ [{name}] Shape mismatch! jax shape: {jax_np.shape}, torch shape: {torch_np.shape}")
+        return False
+
+    if jax_np.dtype != torch_np.dtype:
+        print(f"❌ [{name}] Dtype mismatch! jax dtype: {jax_np.dtype}, torch dtype: {torch_np.dtype}")
+        return False
+
+ # Check if within tolerance
+ close = np.allclose(jax_np, torch_np, rtol=rtol, atol=atol)
+
+ if not close:
+ print(f"\nβ Outputs do NOT match (rtol={rtol}, atol={atol})")
+ # Show some mismatched locations
+ mismatch_mask = ~np.isclose(jax_np, torch_np, rtol=rtol, atol=atol)
+ n_mismatches = np.sum(mismatch_mask)
+ print(f" Number of mismatches: {n_mismatches} / {jax_np.size} ({100 * n_mismatches / jax_np.size:.2f}%)")
+
+ return close
+
+
+class UMT5Test(absltest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.model_name = "google/umt5-base"
+ self.tokenizer_name = "google/umt5-base"
+
+ self.model_ckpt_path = snapshot_download(self.model_name)
+ self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
+ self.model_config = load_model_config(self.model_ckpt_path)
+
+ def test_t5_encoder_accuracy(self):
+ jax_t5 = create_model(
+ UMT5EncoderModel,
+ file_dir=self.model_ckpt_path,
+ cfg=self.model_config,
+ )
+
+ hf_t5 = TorchUMT5EncoderModel.from_pretrained(self.model_ckpt_path)
+
+ prompts = [
+ "A beautiful sunset over the ocean with waves crashing on the shore",
+ "translate to French: I love cat",
+ ]
+ torch_inputs = self.tokenizer(prompts, padding=True, return_tensors="pt")
+ jax_inputs = self.tokenizer(prompts, padding=True, return_tensors="np")
+
+ # test encoder accuracy
+ pytorch_output = hf_t5.encoder(input_ids=torch_inputs.input_ids, attention_mask=torch_inputs.attention_mask)
+ jax_output = jax_t5.encoder(
+ input_ids=jnp.array(jax_inputs.input_ids), attention_mask=jnp.array(jax_inputs.attention_mask)
+ )
+
+ torch_embeddings = pytorch_output.last_hidden_state
+
+ seq_lens = torch_inputs.attention_mask.gt(0).sum(dim=1).long()
+ for i in range(jax_output.shape[0]):
+ self.assertTrue(
+ compare_outputs(
+ jax_output[i, : seq_lens[i], :],
+ torch_embeddings[i, : seq_lens[i], :],
+ f"UMT5 Encoder For Prompt: {prompts[i]}",
+ )
+ )
+
+ def test_t5_decoder_accuracy(self):
+ jax_t5 = create_model(
+ UMT5Model,
+ file_dir=self.model_ckpt_path,
+ cfg=self.model_config,
+ )
+
+ hf_t5 = TorchUMT5Model.from_pretrained(self.model_ckpt_path)
+
+ prompts = [
+ "A beautiful sunset over the ocean with waves crashing on the shore",
+ "translate to French: I love cat",
+ ]
+ torch_inputs = self.tokenizer(prompts, padding=True, return_tensors="pt")
+
+ # test encoder accuracy
+ pytorch_output = hf_t5.encoder(input_ids=torch_inputs.input_ids, attention_mask=torch_inputs.attention_mask)
+ encoder_hidden_states = pytorch_output.last_hidden_state
+
+ # test decoder accuracy
+        # use torch encoder output as decoder input
+ bs = encoder_hidden_states.shape[0]
+ decoder_input_ids = [0, 1, 2, 3]
+ torch_input_ids = torch.tensor(decoder_input_ids, dtype=torch.int32).unsqueeze(0).repeat(bs, 1)
+ pytorch_decoder_output = hf_t5.decoder(
+ input_ids=torch_input_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=torch_inputs.attention_mask,
+ )
+ decoder_hidden_states = pytorch_decoder_output.last_hidden_state
+
+ jax_decoder_output = jax_t5.decoder(
+ input_ids=jnp.array(torch_input_ids.numpy()),
+ encoder_hidden_states=jnp.array(encoder_hidden_states.detach().numpy()),
+ encoder_attention_mask=jnp.array(torch_inputs.attention_mask.detach().numpy()),
+ )
+ self.assertTrue(compare_outputs(jax_decoder_output, decoder_hidden_states, "UMT5 Decoder"))
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/bonsai/models/wan2/__init__.py b/bonsai/models/wan2/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/bonsai/models/wan2/params.py b/bonsai/models/wan2/params.py
new file mode 100644
index 00000000..957fa6e5
--- /dev/null
+++ b/bonsai/models/wan2/params.py
@@ -0,0 +1,706 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Weight loading utilities for Wan2.1-T2V-1.3B model."""
+
+import gc
+import re
+from enum import Enum
+
+import jax
+import jax.numpy as jnp
+import safetensors
+from etils import epath
+from flax import nnx
+
+from bonsai.models.wan2 import transformer_wan as model_lib
+from bonsai.models.wan2 import umt5 as t5_lib
+from bonsai.models.wan2 import vae_wan as vae_lib
+
+
+def cast_with_exclusion(path, x, dtype_to_cast):
+ """
+    Casts arrays to dtype_to_cast, keeping normalization, time/text conditioning, and scale_shift_table params in float32.
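+
+    Typical (hypothetical) usage over a pytree of parameters:
+        params = jax.tree_util.tree_map_with_path(
+            lambda path, x: cast_with_exclusion(path, x, jnp.bfloat16), params
+        )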
+ """
+
+ exclusion_keywords = [
+ "norm", # For all LayerNorm/GroupNorm layers
+ "time_embed", # The entire time conditioning module
+ "text_proj", # The entire text conditioning module
+ "scale_shift_table", # Catches both the final and the AdaLN tables
+ ]
+
+ path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path)
+
+ if any(keyword in path_str.lower() for keyword in exclusion_keywords):
+ # Keep LayerNorm/GroupNorm weights and biases in full precision
+ return x.astype(jnp.float32)
+ else:
+ # Cast everything else to dtype_to_cast
+ return x.astype(dtype_to_cast)
+
+
+def _get_dit_mapping(cfg: model_lib.TransformerWanModelConfig):
+ class Transform(Enum):
+ """Transformations for model parameters"""
+
+ NONE = None
+ TRANSPOSE = ((1, 0), None) # For linear layers: (out, in) -> (in, out)
+ TRANSPOSE_CONV = ((2, 3, 4, 1, 0), None) # For 3D conv: (out, in, t, h, w) -> (t, h, w, in, out)
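+        # Each non-None value is (permutation, reshape) describing how the PyTorch tensor is rearranged for the JAX layout.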
+
+ mapping = {
+ # Patch embedding (input projection)
+ r"patch_embedding\.weight": ("patch_embed.kernel", Transform.TRANSPOSE_CONV),
+ r"patch_embedding\.bias": ("patch_embed.bias", Transform.NONE),
+ # Time embedder - Sequential uses integer indices (0, 1, 2), not layers_0
+ r"condition_embedder\.time_embedder\.linear_1\.weight": (
+ "time_embed.time_embedding.layers.0.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"condition_embedder\.time_embedder\.linear_1\.bias": (
+ "time_embed.time_embedding.layers.0.bias",
+ Transform.NONE,
+ ),
+ r"condition_embedder\.time_embedder\.linear_2\.weight": (
+ "time_embed.time_embedding.layers.2.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"condition_embedder\.time_embedder\.linear_2\.bias": (
+ "time_embed.time_embedding.layers.2.bias",
+ Transform.NONE,
+ ),
+ r"condition_embedder\.time_proj\.weight": ("time_embed.time_projection.layers.1.kernel", Transform.TRANSPOSE),
+ r"condition_embedder\.time_proj\.bias": ("time_embed.time_projection.layers.1.bias", Transform.NONE),
+ # Text embedder (projects UMT5 embeddings to hidden dim)
+ r"condition_embedder\.text_embedder\.linear_1\.weight": ("text_proj.layers.0.kernel", Transform.TRANSPOSE),
+ r"condition_embedder\.text_embedder\.linear_1\.bias": ("text_proj.layers.0.bias", Transform.NONE),
+ r"condition_embedder\.text_embedder\.linear_2\.weight": ("text_proj.layers.2.kernel", Transform.TRANSPOSE),
+ r"condition_embedder\.text_embedder\.linear_2\.bias": ("text_proj.layers.2.bias", Transform.NONE),
+ # Transformer blocks - Self attention (attn1)
+ r"blocks\.([0-9]+)\.attn1\.norm_q\.weight": (r"blocks.\1.self_attn.q_norm.scale", Transform.NONE),
+ r"blocks\.([0-9]+)\.attn1\.norm_k\.weight": (r"blocks.\1.self_attn.k_norm.scale", Transform.NONE),
+ r"blocks\.([0-9]+)\.attn1\.to_q\.weight": (r"blocks.\1.self_attn.q_proj.kernel", Transform.TRANSPOSE),
+ r"blocks\.([0-9]+)\.attn1\.to_q\.bias": (r"blocks.\1.self_attn.q_proj.bias", Transform.NONE),
+ r"blocks\.([0-9]+)\.attn1\.to_k\.weight": (r"blocks.\1.self_attn.k_proj.kernel", Transform.TRANSPOSE),
+ r"blocks\.([0-9]+)\.attn1\.to_k\.bias": (r"blocks.\1.self_attn.k_proj.bias", Transform.NONE),
+ r"blocks\.([0-9]+)\.attn1\.to_v\.weight": (r"blocks.\1.self_attn.v_proj.kernel", Transform.TRANSPOSE),
+ r"blocks\.([0-9]+)\.attn1\.to_v\.bias": (r"blocks.\1.self_attn.v_proj.bias", Transform.NONE),
+ r"blocks\.([0-9]+)\.attn1\.to_out\.0\.weight": (r"blocks.\1.self_attn.out_proj.kernel", Transform.TRANSPOSE),
+ r"blocks\.([0-9]+)\.attn1\.to_out\.0\.bias": (r"blocks.\1.self_attn.out_proj.bias", Transform.NONE),
+ # Transformer blocks - Cross attention (attn2)
+        # Note: both norm_q and norm_k are mapped here; to_k/to_v are fused separately (see below).
+ r"blocks\.([0-9]+)\.attn2\.norm_q\.weight": (r"blocks.\1.cross_attn.q_norm.scale", Transform.NONE),
+ r"blocks\.([0-9]+)\.attn2\.norm_k\.weight": (r"blocks.\1.cross_attn.k_norm.scale", Transform.NONE),
+ r"blocks\.([0-9]+)\.attn2\.to_q\.weight": (r"blocks.\1.cross_attn.q_proj.kernel", Transform.TRANSPOSE),
+ r"blocks\.([0-9]+)\.attn2\.to_q\.bias": (r"blocks.\1.cross_attn.q_proj.bias", Transform.NONE),
+ # Note: to_k and to_v need special handling - they're fused into kv_proj in JAX
+ # See _load_fused_kv_weights() below
+ r"blocks\.([0-9]+)\.attn2\.to_out\.0\.weight": (r"blocks.\1.cross_attn.out_proj.kernel", Transform.TRANSPOSE),
+ r"blocks\.([0-9]+)\.attn2\.to_out\.0\.bias": (r"blocks.\1.cross_attn.out_proj.bias", Transform.NONE),
+ # Transformer blocks - Feed forward (Sequential creates 'layers' dict with 0, 2 keys)
+ r"blocks\.([0-9]+)\.ffn\.net\.0\.proj\.weight": (r"blocks.\1.mlp.layers.0.kernel", Transform.TRANSPOSE),
+ r"blocks\.([0-9]+)\.ffn\.net\.0\.proj\.bias": (r"blocks.\1.mlp.layers.0.bias", Transform.NONE),
+ r"blocks\.([0-9]+)\.ffn\.net\.2\.weight": (r"blocks.\1.mlp.layers.2.kernel", Transform.TRANSPOSE),
+ r"blocks\.([0-9]+)\.ffn\.net\.2\.bias": (r"blocks.\1.mlp.layers.2.bias", Transform.NONE),
+ # Transformer blocks - Norm and modulation
+ r"blocks\.([0-9]+)\.norm2\.weight": (r"blocks.\1.norm2.scale", Transform.NONE),
+ r"blocks\.([0-9]+)\.norm2\.bias": (r"blocks.\1.norm2.bias", Transform.NONE),
+ r"blocks\.([0-9]+)\.scale_shift_table": (r"blocks.\1.scale_shift_table", Transform.NONE),
+ # Output projection
+ r"scale_shift_table": ("final_layer.scale_shift_table", Transform.NONE),
+ r"proj_out\.weight": ("final_layer.linear.kernel", Transform.TRANSPOSE),
+ r"proj_out\.bias": ("final_layer.linear.bias", Transform.NONE),
+ r"norm_out\.weight": ("final_layer.norm.scale", Transform.NONE),
+ }
+
+ return mapping
+
+
+def _get_vae_key_mapping():
+ """Define mapping from PyTorch VAE keys to JAX VAE keys."""
+
+ class Transform(Enum):
+ """Transformations for VAE parameters"""
+
+ NONE = None
+ TRANSPOSE_2D_CONV = ((2, 3, 1, 0), None) # For 2D conv: (out, in, h, w) -> (h, w, in, out)
+ TRANSPOSE_3D = ((2, 3, 4, 1, 0), None) # For 3D conv: (out, in, t, h, w) -> (t, h, w, in, out)
+ SQUEEZE = (None, (-1,)) # Squeeze to 1D: (C, 1, 1, 1) -> (C,)
+
+ # PyTorch format: (out_channels, in_channels, kernel_size...)
+ # JAX format: (kernel_size..., in_channels, out_channels)
+ mapping = {
+ # Post-quantization conv: 1x1x1 conv
+ r"post_quant_conv\.weight": ("conv2.conv.kernel", Transform.TRANSPOSE_3D),
+ r"post_quant_conv\.bias": ("conv2.conv.bias", Transform.NONE),
+ # Decoder input conv
+ r"decoder\.conv_in\.weight": ("decoder.conv_in.conv.kernel", Transform.TRANSPOSE_3D),
+ r"decoder\.conv_in\.bias": ("decoder.conv_in.conv.bias", Transform.NONE),
+ # Mid block resnets
+ r"decoder\.mid_block\.resnets\.0\.norm1\.gamma": ("decoder.mid_block1.norm1.scale", Transform.SQUEEZE),
+ r"decoder\.mid_block\.resnets\.0\.conv1\.weight": (
+ "decoder.mid_block1.conv1.conv.kernel",
+ Transform.TRANSPOSE_3D,
+ ),
+ r"decoder\.mid_block\.resnets\.0\.conv1\.bias": ("decoder.mid_block1.conv1.conv.bias", Transform.NONE),
+ r"decoder\.mid_block\.resnets\.0\.norm2\.gamma": ("decoder.mid_block1.norm2.scale", Transform.SQUEEZE),
+ r"decoder\.mid_block\.resnets\.0\.norm2\.bias": ("decoder.mid_block1.norm2.scale", Transform.NONE),
+ r"decoder\.mid_block\.resnets\.0\.conv2\.weight": (
+ "decoder.mid_block1.conv2.conv.kernel",
+ Transform.TRANSPOSE_3D,
+ ),
+ r"decoder\.mid_block\.resnets\.0\.conv2\.bias": ("decoder.mid_block1.conv2.conv.bias", Transform.NONE),
+ r"decoder\.mid_block\.resnets\.1\.norm1\.gamma": ("decoder.mid_block2.norm1.scale", Transform.SQUEEZE),
+ r"decoder\.mid_block\.resnets\.1\.conv1\.weight": (
+ "decoder.mid_block2.conv1.conv.kernel",
+ Transform.TRANSPOSE_3D,
+ ),
+ r"decoder\.mid_block\.resnets\.1\.conv1\.bias": ("decoder.mid_block2.conv1.conv.bias", Transform.NONE),
+ r"decoder\.mid_block\.resnets\.1\.norm2\.gamma": ("decoder.mid_block2.norm2.scale", Transform.SQUEEZE),
+ r"decoder\.mid_block\.resnets\.1\.conv2\.weight": (
+ "decoder.mid_block2.conv2.conv.kernel",
+ Transform.TRANSPOSE_3D,
+ ),
+ r"decoder\.mid_block\.resnets\.1\.conv2\.bias": ("decoder.mid_block2.conv2.conv.bias", Transform.NONE),
+ # Mid attention block
+ r"decoder\.mid_block\.attentions\.0\.norm\.gamma": ("decoder.mid_attn.norm.scale", Transform.SQUEEZE),
+ r"decoder\.mid_block\.attentions\.0\.to_qkv\.weight": (
+ "decoder.mid_attn.qkv.kernel",
+ Transform.TRANSPOSE_2D_CONV,
+ ),
+ r"decoder\.mid_block\.attentions\.0\.to_qkv\.bias": ("decoder.mid_attn.qkv.bias", Transform.NONE),
+ r"decoder\.mid_block\.attentions\.0\.proj\.weight": (
+ "decoder.mid_attn.proj.kernel",
+ Transform.TRANSPOSE_2D_CONV,
+ ),
+ r"decoder\.mid_block\.attentions\.0\.proj\.bias": ("decoder.mid_attn.proj.bias", Transform.NONE),
+ # Up blocks - resnets (pattern for all 4 stages, 3 resnets each)
+ r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.norm1\.gamma": (
+ r"decoder.up_blocks_\1.\2.norm1.scale",
+ Transform.SQUEEZE,
+ ),
+ r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.conv1\.weight": (
+ r"decoder.up_blocks_\1.\2.conv1.conv.kernel",
+ Transform.TRANSPOSE_3D,
+ ),
+ r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.conv1\.bias": (
+ r"decoder.up_blocks_\1.\2.conv1.conv.bias",
+ Transform.NONE,
+ ),
+ r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.norm2\.gamma": (
+ r"decoder.up_blocks_\1.\2.norm2.scale",
+ Transform.SQUEEZE,
+ ),
+ r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.conv2\.weight": (
+ r"decoder.up_blocks_\1.\2.conv2.conv.kernel",
+ Transform.TRANSPOSE_3D,
+ ),
+ r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.conv2\.bias": (
+ r"decoder.up_blocks_\1.\2.conv2.conv.bias",
+ Transform.NONE,
+ ),
+ # Skip connections (only in block 1, resnet 0)
+ r"decoder\.up_blocks\.1\.resnets\.0\.conv_shortcut\.weight": (
+ "decoder.up_blocks_1.0.skip_conv.conv.kernel",
+ Transform.TRANSPOSE_3D,
+ ),
+ r"decoder\.up_blocks\.1\.resnets\.0\.conv_shortcut\.bias": (
+ "decoder.up_blocks_1.0.skip_conv.conv.bias",
+ Transform.NONE,
+ ),
+ # Upsamplers for blocks 0, 1, 2 (block 3 has no upsampler)
+ # Block 0: Upsample3D (time_conv + spatial_conv)
+ r"decoder\.up_blocks\.0\.upsamplers\.0\.time_conv\.weight": (
+ "decoder.up_sample_0.time_conv.conv.kernel",
+ Transform.TRANSPOSE_3D,
+ ),
+ r"decoder\.up_blocks\.0\.upsamplers\.0\.time_conv\.bias": (
+ "decoder.up_sample_0.time_conv.conv.bias",
+ Transform.NONE,
+ ),
+ r"decoder\.up_blocks\.0\.upsamplers\.0\.resample\.1\.weight": (
+ "decoder.up_sample_0.spatial_conv.kernel",
+ Transform.TRANSPOSE_2D_CONV,
+ ),
+ r"decoder\.up_blocks\.0\.upsamplers\.0\.resample\.1\.bias": (
+ "decoder.up_sample_0.spatial_conv.bias",
+ Transform.NONE,
+ ),
+ # Block 1: Upsample3D (time_conv + spatial_conv)
+ r"decoder\.up_blocks\.1\.upsamplers\.0\.time_conv\.weight": (
+ "decoder.up_sample_1.time_conv.conv.kernel",
+ Transform.TRANSPOSE_3D,
+ ),
+ r"decoder\.up_blocks\.1\.upsamplers\.0\.time_conv\.bias": (
+ "decoder.up_sample_1.time_conv.conv.bias",
+ Transform.NONE,
+ ),
+ r"decoder\.up_blocks\.1\.upsamplers\.0\.resample\.1\.weight": (
+ "decoder.up_sample_1.spatial_conv.kernel",
+ Transform.TRANSPOSE_2D_CONV,
+ ),
+ r"decoder\.up_blocks\.1\.upsamplers\.0\.resample\.1\.bias": (
+ "decoder.up_sample_1.spatial_conv.bias",
+ Transform.NONE,
+ ),
+ # Block 2: Upsample2D (conv only, no time_conv)
+ r"decoder\.up_blocks\.2\.upsamplers\.0\.resample\.1\.weight": (
+ "decoder.up_sample_2.conv.kernel",
+ Transform.TRANSPOSE_2D_CONV,
+ ),
+ r"decoder\.up_blocks\.2\.upsamplers\.0\.resample\.1\.bias": ("decoder.up_sample_2.conv.bias", Transform.NONE),
+ # Output layers
+ r"decoder\.norm_out\.gamma": ("decoder.norm_out.scale", Transform.SQUEEZE),
+ r"decoder\.conv_out\.weight": ("decoder.conv_out.conv.kernel", Transform.TRANSPOSE_3D),
+ r"decoder\.conv_out\.bias": ("decoder.conv_out.conv.bias", Transform.NONE),
+ }
+
+ return mapping
+
+
+def _torch_key_to_jax_key(mapping, source_key):
+ """Convert a PyTorch/Diffusers key to JAX key with transform info."""
+ subs = [
+ (re.sub(pat, repl, source_key), transform)
+ for pat, (repl, transform) in mapping.items()
+ if re.match(pat, source_key)
+ ]
+ if len(subs) == 0:
+ # Key not found in mapping, might be OK (e.g., VAE weights)
+ return None, None
+ if len(subs) > 1:
+ raise ValueError(f"Multiple patterns matched for key {source_key}: {subs}")
+ return subs[0]
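+
+# Example (illustrative): with the DiT mapping above,
+#   _torch_key_to_jax_key(mapping, "proj_out.weight")
+# returns ("final_layer.linear.kernel", Transform.TRANSPOSE); keys that match no
+# pattern resolve to (None, None) and are skipped by the callers.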
+
+
+def _assign_weights(keys, tensor, state_dict, st_key, transform, sharding_dict=None):
+ """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor."""
+ key, *rest = keys
+ if not rest:
+ if transform is not None and transform.value is not None:
+ permute, reshape = transform.value
+ if reshape is not None:
+ tensor = tensor.reshape(reshape)
+ if permute:
+ tensor = tensor.transpose(permute)
+
+ if key not in state_dict:
+ raise KeyError(f"Key {key} not found in state_dict. Available keys: {list(state_dict.keys())[:10]}...")
+
+ if tensor.shape != state_dict[key].shape:
+ raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}")
+
+ # Assign with or without sharding
+ if sharding_dict is not None and key in sharding_dict:
+ state_dict[key] = jax.device_put(tensor, sharding_dict[key])
+ else:
+ state_dict[key] = jax.device_put(tensor)
+ else:
+ next_sharding = sharding_dict[key] if sharding_dict is not None and key in sharding_dict else None
+ _assign_weights(rest, tensor, state_dict[key], st_key, transform, next_sharding)
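+
+# Example (illustrative): for the mapped key "blocks.0.mlp.layers.2.bias" the callers
+# pass keys=["blocks", 0, "mlp", "layers", 2, "bias"] (after `_stoi`); `_assign_weights`
+# walks the nested state_dict with those keys and applies the transform (if any)
+# before placing the tensor at the leaf.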
+
+
+def _stoi(s):
+ """Convert string to int if possible, otherwise return string."""
+ try:
+ return int(s)
+ except ValueError:
+ return s
+
+
+def create_model_from_safe_tensors(
+ file_dir: str,
+ cfg: model_lib.TransformerWanModelConfig,
+ mesh: jax.sharding.Mesh | None = None,
+) -> model_lib.Wan2DiT:
+ """
+ Load Wan2.1-T2V-1.3B DiT model from safetensors checkpoint.
+
+ Args:
+ file_dir: Directory containing .safetensors files or path to transformer directory
+ cfg: Model configuration
+ mesh: Optional JAX mesh for sharding
+
+ Returns:
+ Wan2DiT model with loaded weights
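+
+    Example (illustrative; mirrors the usage in pipeline_wan.py, where
+    `snapshot_download` comes from `huggingface_hub`):
+
+        ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+        model = create_model_from_safe_tensors(ckpt_path, model_lib.TransformerWanModelConfig(), mesh=None)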
+ """
+ # Check if file_dir is the model root or transformer subdirectory
+ file_path = epath.Path(file_dir).expanduser()
+ transformer_path = file_path / "transformer"
+
+ if transformer_path.exists():
+ # Look in transformer subdirectory
+ files = sorted(list(transformer_path.glob("diffusion_pytorch_model-*.safetensors")))
+ else:
+ # Look in provided directory
+ files = sorted(list(file_path.glob("diffusion_pytorch_model-*.safetensors")))
+ if not files:
+ files = sorted(list(file_path.glob("*.safetensors")))
+
+ if not files:
+ raise ValueError(f"No safetensors found in {file_dir} or {file_dir}/transformer")
+
+ print(f"Found {len(files)} DiT transformer safetensors file(s)")
+
+ # Create model structure
+ wan2_dit = nnx.eval_shape(lambda: model_lib.Wan2DiT(cfg, rngs=nnx.Rngs(params=0)))
+ graph_def, abs_state = nnx.split(wan2_dit)
+ state_dict = abs_state.to_pure_dict()
+
+ # Setup sharding if mesh provided
+ sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None
+
+ key_mapping = _get_dit_mapping(cfg)
+ conversion_errors = []
+ loaded_keys = []
+ skipped_keys = []
+
+ # Collect K/V weights for fusion into kv_proj
+ kv_weights = {} # {block_idx: {'k_weight': ..., 'k_bias': ..., 'v_weight': ..., 'v_bias': ...}}
+
+ for f in files:
+ print(f"Loading weights from {f.name}...")
+ with safetensors.safe_open(f, framework="numpy") as sf:
+ for torch_key in sf.keys():
+ tensor = sf.get_tensor(torch_key)
+
+ # Special handling for cross-attention K/V fusion
+ kv_match = re.match(r"blocks\.([0-9]+)\.attn2\.to_([kv])\.(weight|bias)", torch_key)
+ if kv_match:
+ block_idx = int(kv_match.group(1))
+ kv_type = kv_match.group(2) # 'k' or 'v'
+ param_type = kv_match.group(3) # 'weight' or 'bias'
+
+ if block_idx not in kv_weights:
+ kv_weights[block_idx] = {}
+ kv_weights[block_idx][f"{kv_type}_{param_type}"] = tensor
+ loaded_keys.append(torch_key)
+ continue
+
+ jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key)
+
+ if jax_key is None:
+ # Skip keys not in our mapping (e.g., VAE, text encoder, attn2.norm_k)
+ skipped_keys.append(torch_key)
+ continue
+
+ keys = [_stoi(k) for k in jax_key.split(".")]
+ try:
+ _assign_weights(keys, tensor, state_dict, torch_key, transform, sharding)
+ loaded_keys.append(torch_key)
+ except Exception as e:
+ full_jax_key = ".".join([str(k) for k in keys])
+ conversion_errors.append(
+ f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}"
+ )
+ gc.collect()
+
+ # Fuse collected K/V weights into kv_proj
+ import jax.numpy as jnp
+
+ for block_idx, weights in kv_weights.items():
+ if all(k in weights for k in ["k_weight", "k_bias", "v_weight", "v_bias"]):
+ # Transpose and concatenate: (out, in) -> (in, out) then concat -> (in, 2*out)
+ k_weight = weights["k_weight"].T # (in, out)
+ v_weight = weights["v_weight"].T # (in, out)
+ kv_kernel = jnp.concatenate([k_weight, v_weight], axis=1) # (in, 2*out)
+
+ kv_bias = jnp.concatenate([weights["k_bias"], weights["v_bias"]]) # (2*out,)
+
+ # Assign to state dict
+ state_dict["blocks"][block_idx]["cross_attn"]["kv_proj"]["kernel"] = jax.device_put(kv_kernel)
+ state_dict["blocks"][block_idx]["cross_attn"]["kv_proj"]["bias"] = jax.device_put(kv_bias)
+
+ print(f"Loaded {len(loaded_keys)} weight tensors")
+ print(f"Skipped {len(skipped_keys)} weight tensors (VAE/text encoder/attn2.norm_k)")
+
+ state_dict = jax.tree_util.tree_map_with_path(
+ lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=cfg.weights_dtype), state_dict
+ )
+
+ if conversion_errors:
+ print(f"\n Warning: {len(conversion_errors)} conversion errors occurred:")
+ for err in conversion_errors: # Show first 5 errors
+ print(f" {err}")
+ # if len(conversion_errors) > 5:
+ # print(f" ... and {len(conversion_errors) - 5} more")
+
+ gc.collect()
+ return nnx.merge(graph_def, state_dict)
+
+
+def create_vae_decoder_from_safe_tensors(
+ file_dir: str,
+ mesh: jax.sharding.Mesh | None = None,
+) -> vae_lib.WanVAEDecoder:
+ """
+ Load Wan-VAE decoder from safetensors checkpoint.
+
+ Args:
+ file_dir: Directory containing .safetensors files or path to VAE directory
+ mesh: Optional JAX mesh for sharding
+
+ Returns:
+ WanVAEDecoder with loaded weights
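+
+    Example (illustrative; mirrors pipeline_wan.py):
+
+        ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+        vae_decoder = create_vae_decoder_from_safe_tensors(ckpt_path, mesh=None)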
+ """
+ # Check if file_dir is the model root or VAE subdirectory
+ file_path = epath.Path(file_dir).expanduser()
+ vae_path = file_path / "vae"
+
+ if vae_path.exists():
+ # Look in vae subdirectory
+ files = list(vae_path.glob("*.safetensors"))
+ else:
+ # Look in provided directory
+ files = list(file_path.glob("*.safetensors"))
+
+ if not files:
+ raise ValueError(f"No safetensors found in {file_dir} or {file_dir}/vae")
+
+ print(f"Found {len(files)} VAE safetensors file(s)")
+
+ # Create VAE decoder structure
+ vae_decoder = nnx.eval_shape(lambda: vae_lib.WanVAEDecoder(rngs=nnx.Rngs(params=0)))
+ graph_def, abs_state = nnx.split(vae_decoder)
+ state_dict = abs_state.to_pure_dict()
+
+ # Setup sharding if mesh provided
+ sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None
+
+ key_mapping = _get_vae_key_mapping()
+ conversion_errors = []
+ loaded_keys = []
+ skipped_keys = []
+
+ for f in files:
+ print(f"Loading VAE weights from {f.name}...")
+ with safetensors.safe_open(f, framework="numpy") as sf:
+ for torch_key in sf.keys():
+ tensor = sf.get_tensor(torch_key)
+
+ jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key)
+
+ if jax_key is None:
+ skipped_keys.append(torch_key)
+ # print(f"{torch_key} is not mapped")
+ continue
+
+ keys = [_stoi(k) for k in jax_key.split(".")]
+ try:
+ _assign_weights(keys, tensor, state_dict, torch_key, transform, sharding)
+ loaded_keys.append(torch_key)
+ except Exception as e:
+ full_jax_key = ".".join([str(k) for k in keys])
+ conversion_errors.append(
+ f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}"
+ )
+ gc.collect()
+
+ print(f"Loaded {len(loaded_keys)} VAE weight tensors")
+ print(f"Skipped {len(skipped_keys)} weight tensors")
+
+ if conversion_errors:
+ print(f"\nWarning: {len(conversion_errors)} conversion errors occurred:")
+        for err in conversion_errors:
+            print(f"  {err}")
+
+ if len(loaded_keys) == 0:
+ raise ValueError("No VAE weights were loaded! Check the checkpoint structure and key mapping.")
+
+ gc.collect()
+ return nnx.merge(graph_def, state_dict)
+
+
+def _get_t5_key_mapping():
+ """Define mapping from HuggingFace UMT5 keys to JAX UMT5 keys."""
+
+ class Transform(Enum):
+ """Transformations for UMT5 parameters"""
+
+ NONE = None
+ TRANSPOSE = ((1, 0), None) # For linear layers: (out, in) -> (in, out)
+
+    # UMT5 checkpoints use the standard HuggingFace T5 key naming
+ mapping = {
+ # Shared token embeddings
+ r"shared\.weight": ("encoder.token_embedding.embedding", Transform.NONE),
+ r"encoder\.embed_tokens\.weight": ("encoder.token_embedding.embedding", Transform.NONE),
+ # Encoder blocks - Self attention
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.q\.weight": (
+ r"encoder.blocks.\1.attn.q.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.k\.weight": (
+ r"encoder.blocks.\1.attn.k.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.v\.weight": (
+ r"encoder.blocks.\1.attn.v.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.o\.weight": (
+ r"encoder.blocks.\1.attn.o.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.relative_attention_bias\.weight": (
+ r"encoder.blocks.\1.pos_embedding.embedding.embedding",
+ Transform.NONE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.0\.layer_norm\.weight": (r"encoder.blocks.\1.norm1.weight", Transform.NONE),
+ # Encoder blocks - Feed forward
+ r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wi_0\.weight": (
+ r"encoder.blocks.\1.ffn.gate.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wi_1\.weight": (
+ r"encoder.blocks.\1.ffn.fc1.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wo\.weight": (
+ r"encoder.blocks.\1.ffn.fc2.kernel",
+ Transform.TRANSPOSE,
+ ),
+ r"encoder\.block\.([0-9]+)\.layer\.1\.layer_norm\.weight": (r"encoder.blocks.\1.norm2.weight", Transform.NONE),
+ # Final layer norm
+ r"encoder\.final_layer_norm\.weight": ("encoder.norm.weight", Transform.NONE),
+ }
+
+ return mapping
+
+
+def create_t5_encoder_from_safe_tensors(
+ file_dir: str,
+ mesh: jax.sharding.Mesh | None = None,
+ is_sf: bool = True,
+ config: t5_lib.T5Config | None = None,
+) -> t5_lib.T5EncoderModel:
+ """
+ Load UMT5 encoder from safetensors checkpoint.
+
+ Args:
+ file_dir: Directory containing .safetensors files or path to text_encoder directory
+ mesh: Optional JAX mesh for sharding
+ is_sf: Whether to load from safetensors (True) or PyTorch checkpoint (False)
+ config: T5Config to use. If None, defaults to UMT5-XXL
+
+ Returns:
+ T5EncoderModel with loaded weights
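+
+    Example (illustrative; mirrors pipeline_wan.py):
+
+        ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+        text_encoder = create_t5_encoder_from_safe_tensors(ckpt_path, mesh=None)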
+ """
+ from bonsai.models.wan2 import umt5
+
+ # Use provided config or default to UMT5-XXL
+ if config is None:
+ config = umt5.T5Config.umt5_xxl()
+
+ t5_encoder = nnx.eval_shape(lambda: umt5.T5EncoderModel(config, rngs=nnx.Rngs(params=0, dropout=0)))
+ graph_def, abs_state = nnx.split(t5_encoder)
+ state_dict = abs_state.to_pure_dict()
+
+ sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None
+
+ key_mapping = _get_t5_key_mapping()
+ conversion_errors = []
+ loaded_keys = []
+ skipped_keys = []
+
+ # Check if file_dir is the model root or text_encoder subdirectory
+ file_path = epath.Path(file_dir).expanduser()
+ text_encoder_path = file_path / "text_encoder"
+
+ def load_pytorch_weights(file_dir):
+ from transformers import UMT5ForConditionalGeneration
+
+ model = UMT5ForConditionalGeneration.from_pretrained(file_dir)
+ encoder_state = {k: v for k, v in model.state_dict().items() if k.startswith("encoder.")}
+ return encoder_state
+
+ if is_sf:
+ if text_encoder_path.exists():
+ files = sorted(list(text_encoder_path.glob("model-*.safetensors")))
+ else:
+ files = sorted(list(file_path.glob("*.safetensors")))
+ if not files:
+ raise ValueError(f"No safetensors found in {file_dir} or {file_dir}/text_encoder")
+ print(f"Found {len(files)} UMT5 encoder safetensors file(s)")
+
+ for f in files:
+ print(f"Loading UMT5 weights from {f.name}...")
+ with safetensors.safe_open(f, framework="numpy") as sf:
+ for torch_key in sf.keys():
+ tensor = sf.get_tensor(torch_key)
+
+ jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key)
+
+ if jax_key is None:
+ # Skip keys not in our mapping
+ skipped_keys.append(torch_key)
+ # print(f"{torch_key} is not mapped")
+ continue
+
+ keys = [_stoi(k) for k in jax_key.split(".")]
+ try:
+ _assign_weights(keys, tensor, state_dict, torch_key, transform, sharding)
+ loaded_keys.append(torch_key)
+ except Exception as e:
+ full_jax_key = ".".join([str(k) for k in keys])
+ conversion_errors.append(
+ f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}"
+ )
+ gc.collect()
+ else:
+ print(f"Loading UMT5 weights from PyTorch checkpoint in {file_dir}...")
+ pt_state = load_pytorch_weights(file_dir)
+ for torch_key, tensor in pt_state.items():
+ jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key)
+
+ if jax_key is None:
+ # Skip keys not in our mapping
+ skipped_keys.append(torch_key)
+ # print(f"{torch_key} is not mapped")
+ continue
+
+ keys = [_stoi(k) for k in jax_key.split(".")]
+ try:
+ _assign_weights(keys, tensor.numpy(), state_dict, torch_key, transform, sharding)
+ loaded_keys.append(torch_key)
+ except Exception as e:
+ full_jax_key = ".".join([str(k) for k in keys])
+ conversion_errors.append(f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}")
+ gc.collect()
+
+ print(f"Loaded {len(loaded_keys)} UMT5 weight tensors")
+ print(f"Skipped {len(skipped_keys)} weight tensors")
+
+ if conversion_errors:
+ print(f"\nWarning: {len(conversion_errors)} conversion errors occurred:")
+        for err in conversion_errors:
+            print(f"  {err}")
+
+ if len(loaded_keys) == 0:
+ raise ValueError("No UMT5 weights were loaded! Check the checkpoint structure and key mapping.")
+
+ gc.collect()
+ return nnx.merge(graph_def, state_dict)
+
+
+__all__ = [
+ "create_model_from_safe_tensors",
+ "create_t5_encoder_from_safe_tensors",
+ "create_vae_decoder_from_safe_tensors",
+]
diff --git a/bonsai/models/wan2/pipeline_wan.py b/bonsai/models/wan2/pipeline_wan.py
new file mode 100644
index 00000000..78bab7f5
--- /dev/null
+++ b/bonsai/models/wan2/pipeline_wan.py
@@ -0,0 +1,245 @@
+"""Example script for running Wan2.1-T2V-1.3B text-to-video generation."""
+
+import argparse
+import time
+import traceback
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+from huggingface_hub import snapshot_download
+from jaxtyping import Array
+from transformers import AutoTokenizer
+
+from bonsai.models.wan2 import params, umt5, vae_wan
+from bonsai.models.wan2.transformer_wan import TransformerWanModelConfig, Wan2DiT
+from bonsai.models.wan2.unipc_multistep_scheduler import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
+
+jax.config.update("jax_debug_nans", True)
+
+
+def get_t5_text_embeddings(
+ prompt: str,
+ tokenizer: AutoTokenizer = None,
+ text_encoder: umt5.T5EncoderModel = None,
+ max_length: int = 512,
+ dtype: jnp.dtype = jnp.float32,
+):
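+    """Tokenize `prompt`, run it through the UMT5 encoder, and zero out the padding.
+
+    Embeddings are truncated to each prompt's true token length (taken from the
+    attention mask) and zero-padded back up to `max_length`, yielding a
+    [batch, max_length, hidden_dim] array. On failure the exception is printed and
+    None is returned implicitly.
+    """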
+ try:
+ inputs = tokenizer(
+ prompt,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="np", # Return numpy arrays
+ )
+ input_ids = jnp.array(inputs["input_ids"])
+ attention_mask = jnp.array(inputs["attention_mask"])
+ seq_lens = jnp.sum(attention_mask, axis=1).astype(jnp.int32)
+ print(f"seq_lens: {seq_lens}")
+ embeddings = text_encoder(input_ids, attention_mask, deterministic=True)
+ prompt_embeds = jnp.asarray(embeddings, dtype=dtype)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = jnp.stack(
+ [
+ jnp.concatenate(
+ [u, jnp.zeros((max_length - u.shape[0], u.shape[1]), dtype=u.dtype)],
+ axis=0,
+ )
+ for u in prompt_embeds
+ ],
+ axis=0,
+ )
+ return prompt_embeds
+ except Exception as e:
+ print(f"Error in text encoding: {e}")
+ traceback.print_exc()
+
+
+def decode_video_latents(latents: jax.Array, vae_decoder: vae_wan.WanVAEDecoder):
+ """
+ Decode video latents to RGB frames using Wan-VAE.
+
+ Args:
+ latents: [B, T, H, W, C] video latents
+        vae_decoder: WanVAEDecoder instance used to decode the latents.
+
+ Returns:
+ video: [B, T, H_out, W_out, 3] RGB video (uint8)
+ """
+ # Decode using VAE
+ video = vae_wan.decode_latents_to_video(vae_decoder, latents, normalize=True)
+ return video
+
+
+def generate_video(
+ model: Wan2DiT,
+ latents: Array,
+ text_embeds: Array,
+ negative_embeds: Array,
+ num_steps: int = 50,
+ guidance_scale: float = 5.5,
+ scheduler: Optional[FlaxUniPCMultistepScheduler] = None,
+ scheduler_state: Optional[UniPCMultistepSchedulerState] = None,
+) -> Array:
+ """
+ Generate video from text embeddings using the diffusion model.
+
+ Args:
+        model: Wan2DiT model
+        latents: [B, T, H, W, C] initial noise latents
+        text_embeds: [B, seq_len, text_dim] text embeddings from UMT5
+        negative_embeds: [B, seq_len, text_dim] text embeddings for the negative prompt
+        num_steps: Number of denoising steps
+        guidance_scale: Classifier-free guidance scale (5-6 recommended)
+        scheduler: FlaxUniPCMultistepScheduler instance
+        scheduler_state: Scheduler state created via `scheduler.create_state()`
+
+ Returns:
+ latents: [B, T, H, W, C] generated video latents
+ """
+ b = text_embeds.shape[0]
+
+    # Configure the scheduler for the requested number of denoising steps
+ scheduler_state = scheduler.set_timesteps(
+ scheduler_state, num_inference_steps=num_steps, shape=latents.transpose(0, 4, 1, 2, 3).shape
+ )
+
+ for t_idx in range(num_steps):
+ # Scheduler needs scalar timestep, model needs batched timestep
+ t_scalar = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[t_idx]
+ t_batch = jnp.full((b,), t_scalar, dtype=jnp.int32)
+
+ # Classifier-free guidance
+ if guidance_scale != 1.0:
+ noise_pred_cond = model.forward(latents, text_embeds, t_batch, deterministic=True)
+ noise_pred_uncond = model.forward(latents, negative_embeds, t_batch, deterministic=True)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = model.forward(latents, text_embeds, t_batch, deterministic=True)
+
+ latents, scheduler_state = scheduler.step(
+ scheduler_state, noise_pred.transpose(0, 4, 1, 2, 3), t_scalar, latents.transpose(0, 4, 1, 2, 3)
+ )
+ latents = latents.transpose(0, 2, 3, 4, 1) # back to channel-last
+
+ return latents
+
+
+def run_model(prompt: Optional[str] = None, neg_prompt: Optional[str] = None) -> jax.Array:
+ print("=" * 60)
+ print("Wan2.1-T2V-1.3B Text-to-Video Generation Demo")
+ print("=" * 60)
+
+ model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+ config = TransformerWanModelConfig()
+
+ # For sharding (multi-GPU), uncomment:
+ # from jax.sharding import AxisType
+ # mesh = jax.make_mesh((2, 2), ("fsdp", "tp"), axis_types=(AxisType.Explicit, AxisType.Explicit))
+ # jax.set_mesh(mesh)
+
+ scheduler = FlaxUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ beta_start=0.0001,
+ beta_end=0.02,
+ beta_schedule="linear",
+ solver_order=2, # Order 2 for guided sampling
+ prediction_type="flow_prediction",
+ use_flow_sigmas=True, # Enable flow-based sigma schedule
+ flow_shift=3.0, # 5.0 for 720P, 3.0 for 480P
+ timestep_spacing="linspace",
+ predict_x0=True,
+ solver_type="bh2",
+ lower_order_final=True,
+ dtype=jnp.float32,
+ )
+ scheduler_state = scheduler.create_state()
+
+ if prompt is not None:
+ prompts = [prompt]
+ else:
+ prompts = [
+ "A curious racoon",
+ ]
+
+ print(f"\nPrompt: {prompts[0]}")
+ print(f"Model: Wan2.1-T2V-1.3B ({config.num_layers} layers, {config.hidden_dim} dim)")
+ print(f"Video: {config.num_frames} frames @ 480p")
+
+ tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+ umt5_encoder = params.create_t5_encoder_from_safe_tensors(model_ckpt_path, mesh=None)
+
+ print("\n[1/4] Encoding text with UMT5...")
+ text_embeds = get_t5_text_embeddings(prompts[0], tokenizer, umt5_encoder, max_length=config.max_text_len)
+ if neg_prompt is not None:
+ negative_prompts = [neg_prompt]
+ else:
+ negative_prompts = ["blurry"]
+ negative_embeds = get_t5_text_embeddings(
+ negative_prompts[0], tokenizer, umt5_encoder, max_length=config.max_text_len
+ )
+
+ print("\n[2/5] Loading Diffusion Transformer weights...")
+ model = params.create_model_from_safe_tensors(model_ckpt_path, config, mesh=None)
+ print("\n[2.5/5] Loading VAE decoder...")
+ vae_decoder = params.create_vae_decoder_from_safe_tensors(model_ckpt_path, mesh=None)
+ print("Model loaded successfully")
+
+ print("\n[3/4] Generating video latents...")
+ print(f"Using {config.num_inference_steps} diffusion steps")
+ print(f"Guidance scale: {config.guidance_scale}")
+
+ key = jax.random.PRNGKey(42)
+ start_time = time.time()
+
+ latents = jax.random.normal(
+ key, (1, config.num_frames, config.latent_size[0], config.latent_size[1], config.latent_input_dim)
+ )
+
+ latents = generate_video(
+ model=model,
+ latents=latents,
+ text_embeds=text_embeds,
+ negative_embeds=negative_embeds,
+ num_steps=config.num_inference_steps,
+ guidance_scale=config.guidance_scale,
+ scheduler=scheduler,
+ scheduler_state=scheduler_state,
+ )
+
+ generation_time = time.time() - start_time
+ print(f"β Generated latents in {generation_time:.2f}s")
+ print(f"Latents shape: {latents.shape}")
+    print(f"Latent sample mean: {latents[0, 1:, :, 25:, :].mean()}")
+ print(f"Has NaN: {jnp.isnan(latents).any()}")
+ print(f"Has Inf: {jnp.isinf(latents).any()}")
+
+ print("\n[4/5] Decoding latents to video...")
+ video = decode_video_latents(latents, vae_decoder)
+ generation_time = time.time() - start_time
+ print(f"Video shape: {video.shape}")
+    print(f"Video sample mean: {video[0, 1:, :, 235:, :].mean()}")
+
+ print("\n" + "=" * 60)
+ print("β Generation Complete!")
+ print("=" * 60)
+ print(f"Total time: {generation_time:.2f}s")
+ print(f"FPS: {config.num_frames / generation_time:.2f}")
+ vae_wan.save_video(video, "generated_video.mp4")
+ print("Video saved to generated_video.mp4")
+
+ return video
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Wan2.1-T2V-1.3B Text-to-Video Generation Demo")
+ parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation")
+ parser.add_argument("--neg_prompt", type=str, default=None, help="Negative text prompt for video generation")
+ args = parser.parse_args()
+ run_model(args.prompt, args.neg_prompt)
+
+
+if __name__ == "__main__":
+ main()
+
+__all__ = ["run_model"]
diff --git a/bonsai/models/wan2/scheduling_utils.py b/bonsai/models/wan2/scheduling_utils.py
new file mode 100644
index 00000000..da722600
--- /dev/null
+++ b/bonsai/models/wan2/scheduling_utils.py
@@ -0,0 +1,318 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import math
+import os
+from dataclasses import dataclass
+from enum import Enum
+from typing import ClassVar, Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+# NOTE: We make this type an enum because it simplifies usage in docs and prevents
+# circular imports when used for `_compatibles` within the schedulers module.
+# When it's used as a type in pipelines, it really is a Union because the actual
+# scheduler instance is passed in.
+class FlaxKarrasDiffusionSchedulers(Enum):
+ FlaxDDIMScheduler = 1
+ FlaxDDPMScheduler = 2
+ FlaxPNDMScheduler = 3
+ FlaxLMSDiscreteScheduler = 4
+ FlaxDPMSolverMultistepScheduler = 5
+
+
+@dataclass
+class FlaxSchedulerOutput:
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: jnp.ndarray
+
+
+class FlaxSchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+
+ Class attributes:
+ - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+ `from_config` can be used from a class different than the one used to save the config (should be overridden
+ by parent class).
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
+ ignore_for_config: ClassVar = ["dtype"]
+ _compatibles: ClassVar = []
+ has_compatibles = True
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ subfolder: Optional[str] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a Scheduler class from a pre-defined JSON-file.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`],
+ e.g., `./my_model_directory/`.
+ subfolder (`str`, *optional*):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+ """
+ config, kwargs = cls.load_config(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ subfolder=subfolder,
+ return_unused_kwargs=True,
+ **kwargs,
+ )
+ scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)
+
+ if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
+ state = scheduler.create_state()
+
+ if return_unused_kwargs:
+ return scheduler, state, unused_kwargs
+
+ return scheduler, state
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~FlaxSchedulerMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+ @property
+ def compatibles(self):
+ """
+ Returns all schedulers that are compatible with this scheduler
+
+ Returns:
+ `List[SchedulerMixin]`: List of compatible schedulers
+ """
+ return self._get_compatibles()
+
+ @classmethod
+ def _get_compatibles(cls):
+ compatible_classes_str = list(set([cls.__name__, *cls._compatibles]))
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+ compatible_classes = [
+ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+ ]
+ return compatible_classes
+
+
+def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
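+    """Broadcast `x` against `shape` by right-padding `x` with singleton axes,
+    e.g. x of shape (B,) with shape (B, C, T, H, W) -> an array of shape (B, C, T, H, W)."""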
+ assert len(shape) >= x.ndim
+ return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray:
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return jnp.array(betas, dtype=dtype)
+
+
+def rescale_betas_zero_snr(betas):
+ """
+ Rescales betas to have a zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+ """
+
+ alphas = 1.0 - betas
+ alphas_cumprod = jnp.cumprod(alphas, axis=0)
+ alphas_bar_sqrt = jnp.sqrt(alphas_cumprod)
+
+ # Store old values.
+ alphas_bar_sqrt_0 = jnp.copy(alphas_bar_sqrt[0])
+ alphas_bar_sqrt_T = jnp.copy(alphas_bar_sqrt[-1])
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = jnp.concatenate([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+@flax.struct.dataclass
+class CommonSchedulerState:
+ alphas: jnp.ndarray
+ betas: jnp.ndarray
+ alphas_cumprod: jnp.ndarray
+
+ @classmethod
+ def create(cls, scheduler):
+ config = scheduler.config
+
+ if config.trained_betas is not None:
+ betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype)
+ elif config.beta_schedule == "linear":
+ betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype)
+ elif config.beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ betas = (
+ jnp.linspace(
+ config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype
+ )
+ ** 2
+ )
+ elif config.beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype)
+ else:
+ raise NotImplementedError(
+ f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
+ )
+
+        if config.rescale_zero_terminal_snr:
+ betas = rescale_betas_zero_snr(betas)
+
+ alphas = 1.0 - betas
+
+ alphas_cumprod = jnp.cumprod(alphas, axis=0)
+
+ return cls(
+ alphas=alphas,
+ betas=betas,
+ alphas_cumprod=alphas_cumprod,
+ )
+
+
+def get_sqrt_alpha_prod(
+ state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
+):
+ alphas_cumprod = state.alphas_cumprod
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
+
+ return sqrt_alpha_prod, sqrt_one_minus_alpha_prod
+
+
+def add_noise_common(
+ state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
+):
+ sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+
+def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray):
+ sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps)
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
diff --git a/bonsai/models/wan2/tests/__init__.py b/bonsai/models/wan2/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/bonsai/models/wan2/tests/test_transformer_output.py b/bonsai/models/wan2/tests/test_transformer_output.py
new file mode 100644
index 00000000..590e5a96
--- /dev/null
+++ b/bonsai/models/wan2/tests/test_transformer_output.py
@@ -0,0 +1,792 @@
+"""Test output correctness by comparing JAX implementation with HuggingFace reference."""
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from huggingface_hub import snapshot_download
+from bonsai.models.wan2 import params, transformer_wan
+import torch
+from diffusers import AutoModel
+from jax.lax import Precision
+from collections import OrderedDict
+
+def check_weight_loading(jax_model, torch_model):
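+    """Spot-check a few converted weights (text projection, patch embedding, and the
+    fused cross-attention K/V kernel of block 0) against the PyTorch reference model."""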
+
+ text_proj_torch = torch_model.condition_embedder.text_embedder.linear_1.weight.detach().cpu().float().numpy().T
+ text_proj_jax = np.array(jax_model.text_proj.layers[0].kernel.value)
+ print("Text projection weights:")
+ print(f" Shapes: torch={text_proj_torch.shape}, jax={text_proj_jax.shape}")
+ print(f" Max diff: {np.abs(text_proj_torch - text_proj_jax).max():.2e}")
+ print(f" Mean diff: {np.abs(text_proj_torch - text_proj_jax).mean():.2e}")
+ # torch :(out, in, t, h, w)
+ torch_emb = torch_model.patch_embedding.weight.detach().cpu().float().numpy()
+ # jax: (t, h, w, in, out)
+ jax_emb = np.array(jax_model.patch_embed.kernel.value).transpose(4,3,0,1,2)
+
+ print("Embedding weights:")
+ print(f" Shapes: torch={torch_emb.shape}, jax={jax_emb.shape}")
+ print(f" Max diff: {np.abs(torch_emb - jax_emb).max():.2e}")
+ print(f" Mean diff: {np.abs(torch_emb - jax_emb).mean():.2e}")
+
+ # check fused kv projection weights in first block
+ torch_k_weight = torch_model.blocks[0].attn2.to_k.weight.detach().cpu().float().numpy().T
+ torch_v_weight = torch_model.blocks[0].attn2.to_v.weight.detach().cpu().float().numpy().T
+ torch_kv = np.concatenate([torch_k_weight, torch_v_weight], axis=1)
+
+ jax_kv_weight = np.array(jax_model.blocks[0].cross_attn.kv_proj.kernel.value)
+ print("First block cross-attention KV projection weights:")
+ print(f" Shapes: torch={torch_kv.shape}, jax={jax_kv_weight.shape}")
+ print(f" Max diff: {np.abs(torch_kv - jax_kv_weight).max():.2e}")
+ print(f" Mean diff: {np.abs(torch_kv - jax_kv_weight).mean():.2e}")
+
+def compare_outputs(jax_output: jax.Array, torch_output, name: str, rtol: float = 1e-2, atol: float = 1e-4):
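+    """Print shape/dtype and absolute/relative error statistics for a JAX output and a
+    PyTorch (or numpy) reference, and return True if they agree within rtol/atol."""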
+
+ print(f"\n{'=' * 80}")
+ print(f"Comparing: {name}")
+ print(f"{'=' * 80}")
+
+ print(f"before convert torch: {torch_output.dtype}, jax: {jax_output.dtype}")
+ if torch_output.dtype == torch.bfloat16:
+ torch_output = torch_output.float()
+ if jax_output.dtype == jnp.bfloat16:
+ jax_output = jax_output.astype(jnp.float32)
+
+ if isinstance(torch_output, torch.Tensor):
+ torch_np = torch_output.detach().cpu().numpy()
+ else:
+ torch_np = np.array(torch_output)
+
+ jax_np = np.array(jax_output)
+
+ print(f"JAX shape: {jax_np.shape}")
+ print(f"Torch shape: {torch_np.shape}")
+ print(f"JAX dtype: {jax_np.dtype}")
+ print(f"Torch dtype: {torch_np.dtype}")
+
+ if jax_np.shape != torch_np.shape:
+ print("Shape mismatch!")
+ return False
+
+ abs_diff = np.abs(jax_np - torch_np)
+ rel_diff = abs_diff / (np.abs(torch_np) + 1e-10)
+
+ max_abs_diff = np.max(abs_diff)
+ max_rel_diff = np.max(rel_diff)
+ mean_abs_diff = np.mean(abs_diff)
+ mean_rel_diff = np.mean(rel_diff)
+
+ print("\nStatistics:")
+ print(f" Max absolute difference: {max_abs_diff:.2e}")
+ print(f" Max relative difference: {max_rel_diff:.2e}")
+ print(f" Mean absolute difference: {mean_abs_diff:.2e}")
+ print(f" Mean relative difference: {mean_rel_diff:.2e}")
+
+ print(f"\nJAX output range: [{np.min(jax_np):.4f}, {np.max(jax_np):.4f}]")
+ print(f"Torch output range: [{np.min(torch_np):.4f}, {np.max(torch_np):.4f}]")
+
+ close = np.allclose(jax_np, torch_np, rtol=rtol, atol=atol)
+
+ if close:
+ print(f"\nβ
Outputs match within tolerance (rtol={rtol}, atol={atol})")
+ else:
+ print(f"\nβ Outputs do NOT match (rtol={rtol}, atol={atol})")
+ # Show some mismatched locations
+ mismatch_mask = ~np.isclose(jax_np, torch_np, rtol=rtol, atol=atol)
+ n_mismatches = np.sum(mismatch_mask)
+ print(f" Number of mismatches: {n_mismatches} / {jax_np.size} ({100 * n_mismatches / jax_np.size:.2f}%)")
+
+ return close
+
+
+
+# e2e test
+def test_dit_output():
+ print("\n" + "=" * 80)
+ print("TEST 2: DiT")
+ print("=" * 80)
+
+ model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+    config = transformer_wan.TransformerWanModelConfig()
+
+ print("\n[1/2] Loading transformer")
+ transformer = AutoModel.from_pretrained(model_ckpt_path, subfolder="transformer", torch_dtype=torch.bfloat16)
+
+    jax_dit = params.create_model_from_safe_tensors(model_ckpt_path, config, mesh=None)
+
+ batch_size = 1
+ num_channels = 16 # in_channels
+ num_frames = 9
+ height = 30
+ width = 30
+ text_seq_len = 128
+ text_dim = 4096 # UMT5 hidden dimension
+
+ # Create dummy inputs
+ hidden_states = torch.randn(
+ batch_size, num_channels, num_frames, height, width,
+ dtype=torch.float32
+ )
+ hidden_states_jax = jnp.array(np.transpose(hidden_states.numpy(), (0, 2, 3, 4, 1)))
+
+ timestep = torch.randint(
+ 0, 1000,
+ (batch_size,),
+ dtype=torch.long
+ )
+ timestep_jax = jnp.array(timestep.numpy())
+
+ encoder_hidden_states = torch.randn(
+ batch_size, text_seq_len, text_dim,
+ dtype=torch.float32
+ )
+ encoder_hidden_states_jax = jnp.array(encoder_hidden_states.numpy())
+
+ print("\n[2/2] Running forward pass")
+ with torch.no_grad():
+ output = transformer(
+            hidden_states=hidden_states.to(dtype=torch.bfloat16),
+            timestep=timestep,
+            encoder_hidden_states=encoder_hidden_states.to(dtype=torch.bfloat16),
+ encoder_hidden_states_image=None, # Only for I2V models
+ return_dict=True,
+ attention_kwargs=None,
+ )
+ pred_noise = jax_dit.forward(hidden_states_jax, encoder_hidden_states_jax, timestep_jax, deterministic=True)
+    expected_output = np.transpose(output.sample.float().numpy(), (0, 2, 3, 4, 1))
+
+ # Compare final output
+ return compare_outputs(pred_noise, expected_output, "Final DiT Output", rtol=1e-3, atol=1e-4)
+
+def test_dit():
+ print("\n" + "=" * 80)
+ print("TEST 2: DiT")
+ print("=" * 80)
+
+ model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+    config = transformer_wan.TransformerWanModelConfig()
+
+ print("\n[1/2] Loading transformer")
+ transformer = AutoModel.from_pretrained(model_ckpt_path, subfolder="transformer", torch_dtype=torch.bfloat16)
+
+    jax_dit = params.create_model_from_safe_tensors(model_ckpt_path, config, mesh=None)
+ print("transformer loaded:", transformer, transformer.config)
+
+ check_weight_loading(jax_dit, transformer)
+
+ batch_size = 1
+ num_channels = 16 # in_channels
+ num_frames = 9
+ height = 30
+ width = 30
+ text_seq_len = 128
+ text_dim = 4096 # UMT5 hidden dimension
+
+ debugger = WanTransformerDebugger(transformer)
+ debugger.register_hooks()
+ debugger_attn = WanAttentionDebugger(transformer)
+ debugger_attn.register_attention_hooks(block_indices=[0])
+
+ # Create dummy inputs
+ hidden_states = torch.randn(
+ batch_size, num_channels, num_frames, height, width,
+ dtype=torch.float32
+ )
+ # jax channels last
+ hidden_states_jax = jnp.array(np.transpose(hidden_states.numpy(), (0, 2, 3, 4, 1))).astype(jnp.bfloat16)
+ timestep = torch.randint(
+ 0, 1000,
+ (batch_size,),
+ dtype=torch.long
+ )
+ timestep_jax = jnp.array(timestep.numpy())
+ encoder_hidden_states = torch.randn(
+ batch_size, text_seq_len, text_dim,
+ dtype=torch.float32
+ )
+ encoder_hidden_states_jax = jnp.array(encoder_hidden_states.numpy()).astype(jnp.bfloat16)
+
+ print("\n[2/2] Running forward pass")
+ with torch.no_grad():
+ output = transformer(
+ hidden_states=hidden_states.to(dtype=torch.bfloat16),
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states.to(dtype=torch.bfloat16),
+ encoder_hidden_states_image=None, # Only for I2V models
+ return_dict=True,
+ attention_kwargs=None,
+ )
+
+ # 5. Get intermediate outputs
+ intermediate_outputs = debugger.get_outputs()
+ states = debugger_attn.get_attention_states()
+
+ print("=" * 80)
+ print("INTERMEDIATE OUTPUTS")
+ print("=" * 80)
+
+ for name, tensor in states.items():
+ print(f"{name:50s}: {tuple(tensor.shape)}")
+
+ # Restore original processors
+ debugger_attn.restore_processors()
+
+ # Manual forward pass with intermediate comparisons
+ print("\n" + "=" * 80)
+ print("STEP-BY-STEP FORWARD PASS WITH COMPARISONS")
+ print("=" * 80)
+
+ # 1. Text projection
+ text_embeds_jax = jax_dit.text_proj(encoder_hidden_states_jax)
+ text_embeds_torch = intermediate_outputs['condition_encoder_hidden_states']
+ compare_outputs(text_embeds_jax, text_embeds_torch, "Text Projection", rtol=1e-3, atol=1e-4)
+
+ # 2. Patch embedding
+ x_jax = jax_dit.patch_embed(hidden_states_jax)
+ # PyTorch is BCTHW, need to convert to BTHWC for comparison
+ patch_embed_torch = intermediate_outputs['patch_embed_output']
+ patch_embed_torch_channels_last = np.transpose(patch_embed_torch.float().numpy(), (0, 2, 3, 4, 1))
+ compare_outputs(x_jax, patch_embed_torch_channels_last, "Patch Embedding", rtol=1e-3, atol=1e-4)
+
+ # Reshape to sequence
+ b, t_out, h_out, w_out, d = x_jax.shape
+ x_jax = x_jax.reshape(b, t_out * h_out * w_out, d)
+ grid_sizes = (t_out, h_out, w_out)
+
+ # 3. RoPE frequencies
+ max_seq = max(grid_sizes)
+ rope_freqs = tuple(
+ jax.lax.stop_gradient(arr)
+ for arr in transformer_wan.precompute_freqs_cis_3d(
+ dim=jax_dit.cfg.head_dim,
+ theta=jax_dit.rope_theta,
+ max_seq_len=max_seq
+ )
+ )
+
+ # Build full RoPE frequency grid for comparison
+ freqs_t, freqs_h, freqs_w = rope_freqs
+ f, h, w = grid_sizes
+ head_dim = jax_dit.cfg.head_dim
+ dim_base = head_dim // 6
+ dim_t, dim_h, dim_w = head_dim - 4 * dim_base, 2 * dim_base, 2 * dim_base
+
+ freqs_grid = jnp.concatenate([
+ jnp.broadcast_to(freqs_t[:f, None, None, :, :], (f, h, w, dim_t // 2, 2)),
+ jnp.broadcast_to(freqs_h[None, :h, None, :, :], (f, h, w, dim_h // 2, 2)),
+ jnp.broadcast_to(freqs_w[None, None, :w, :, :], (f, h, w, dim_w // 2, 2)),
+ ], axis=3).reshape(t_out * h_out * w_out, head_dim // 2, 2)
+
+ rope_freqs_cos_jax = freqs_grid[..., 0]
+ rope_freqs_cos_jax = jnp.stack([rope_freqs_cos_jax, rope_freqs_cos_jax], axis=-1).reshape(1, -1, 1, head_dim)
+
+ # PyTorch RoPE freqs are in BCHW format, convert to sequence format
+ rope_freqs_cos_torch = intermediate_outputs['rope_freqs_cos']
+ compare_outputs(rope_freqs_cos_jax, rope_freqs_cos_torch, "RoPE Freqs Cos", rtol=1e-5, atol=1e-6)
+
+ # 4. Time embeddings
+ time_emb_jax, time_proj_jax = jax_dit.time_embed(timestep_jax)
+ time_emb_torch = intermediate_outputs['condition_temb']
+ time_proj_torch = intermediate_outputs['condition_timestep_proj']
+ compare_outputs(time_emb_jax, time_emb_torch, "Time Embedding", rtol=1e-3, atol=1e-4)
+ compare_outputs(time_proj_jax, time_proj_torch, "Time Projection", rtol=1e-3, atol=1e-4)
+
+ # 5. Process through transformer blocks with detailed attention comparison
+ for i, block in enumerate(jax_dit.blocks):
+ if i==0:
+ print(f"\n{'='*80}")
+ print(f"BLOCK {i} - DETAILED COMPARISON")
+ print(f"{'='*80}")
+
+ # Get modulation parameters
+ b_size = time_proj_jax.shape[0]
+ d = jax_dit.cfg.hidden_dim
+ reshaped_time_emb = time_proj_jax.reshape(b_size, 6, d)
+ modulation = reshaped_time_emb + block.scale_shift_table.value
+ modulation = modulation.reshape(b_size, -1)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(modulation, 6, axis=-1)
+
+ # Self-attention with detailed steps
+ print(f"\n--- Self-Attention ---")
+ norm_x = block.norm1(x_jax)
+ norm_x_modulated = norm_x * (1 + scale_msa[:, None, :]) + shift_msa[:, None, :]
+
+ attn_out = block.self_attn(norm_x_modulated, deterministic=True, rope_state=(rope_freqs, grid_sizes))
+
+ # # Q, K, V projections
+ num_heads = block.self_attn.num_heads
+ head_dim = block.self_attn.head_dim
+ b_size, n = norm_x_modulated.shape[:2]
+
+ # Compare with PyTorch
+ attn1_output_torch = intermediate_outputs[f'block_{i}_attn1_output']
+ compare_outputs(attn_out, attn1_output_torch, f"Block {i} Attn1 Output", rtol=1e-2, atol=1e-3)
+
+ # Apply gate and residual
+ x_jax = x_jax + gate_msa[:, None, :] * attn_out
+
+ # Cross-attention
+ print(f"\n--- Cross-Attention ---")
+ norm_x = block.norm2(x_jax)
+ compare_outputs(norm_x, intermediate_outputs[f'block_{i}_norm2_output'], f"Block {i} Norm2 Output", rtol=1e-3, atol=1e-4)
+ b, n, m = norm_x.shape[0], norm_x.shape[1], text_embeds_jax.shape[1]
+ q_norm = block.cross_attn.q_norm(block.cross_attn.q_proj(norm_x))
+ compare_outputs(q_norm, intermediate_outputs[f'block_{i}_attn2_query_normed'], f"Block {i} Attn2 Q after Norm", rtol=1e-5, atol=1e-6)
+ kv = block.cross_attn.kv_proj(text_embeds_jax)
+ k, v = jnp.split(kv, 2, axis=-1)
+ k_norm = block.cross_attn.k_norm(k)
+ compare_outputs(k_norm, intermediate_outputs[f'block_{i}_attn2_key_normed'], f"Block {i} Attn2 K after Norm", rtol=1e-5, atol=1e-6)
+
+ cross_out = block.cross_attn(norm_x, text_embeds_jax, deterministic=True)
+
+
+ attn2_output_torch = intermediate_outputs[f'block_{i}_attn2_output']
+ compare_outputs(cross_out, attn2_output_torch, f"Block {i} Attn2 Output", rtol=1e-2, atol=1e-3)
+
+ x_jax = x_jax + cross_out
+
+ # MLP
+ print(f"\n--- MLP ---")
+ norm_x = block.norm3(x_jax)
+ norm_x_modulated = norm_x * (1 + scale_mlp[:, None, :]) + shift_mlp[:, None, :]
+ mlp_out = block.mlp(norm_x_modulated)
+
+ ffn_output_torch = intermediate_outputs[f'block_{i}_ffn_output']
+ compare_outputs(mlp_out, ffn_output_torch, f"Block {i} FFN Output", rtol=1e-2, atol=1e-3)
+
+ x_jax = x_jax + gate_mlp[:, None, :] * mlp_out
+
+ # Compare final block output
+ block_output_torch = intermediate_outputs[f'block_{i}_output']
+ compare_outputs(x_jax, block_output_torch, f"Block {i} Final Output", rtol=1e-2, atol=1e-3)
+
+ if i > 0: # Only compare first block in detail
+ x_jax = block(x_jax, text_embeds_jax, time_proj_jax, deterministic=True, rope_state=(rope_freqs, grid_sizes))
+ compare_outputs(x_jax, intermediate_outputs[f'block_{i}_output'], f"Block {i} Output", rtol=1e-2, atol=1e-3)
+
+ # 6. Final layer
+ jax_dit_output = jax_dit.final_layer(x_jax, time_emb_jax)
+
+ compare_outputs(jax_dit_output, intermediate_outputs['proj_out_output'], "Final Projection Output", rtol=1e-3, atol=1e-4)
+
+ # Reshape to video format
+ jax_dit_output = jax_dit.unpatchify(jax_dit_output, grid_sizes)
+
+ # 4. Verify output shape
+ expected_shape = (batch_size, num_channels, num_frames, height, width)
+ assert output.sample.shape == expected_shape
+
+ # change to channels last for comparison
+ expected_output = np.transpose(output.sample.float().numpy(), (0, 2, 3, 4, 1))
+
+ debugger.remove_hooks()
+ # Compare final output
+ return compare_outputs(jax_dit_output, expected_output, "Final DiT Output", rtol=1e-3, atol=1e-4)
+
+
+class WanTransformerDebugger:
+ """Helper class to extract intermediate outputs from WanTransformer3DModel"""
+
+ def __init__(self, model):
+ self.model = model
+ self.intermediate_outputs = OrderedDict()
+ self.hooks = []
+
+ def register_hooks(self):
+ """Register forward hooks to capture intermediate outputs"""
+
+ # Hook for patch embedding
+ def patch_embed_hook(module, input, output):
+ self.intermediate_outputs['patch_embed_output'] = output.detach().cpu()
+
+ # Hook for rotary embeddings
+ def rope_hook(module, input, output):
+ freqs_cos, freqs_sin = output
+ self.intermediate_outputs['rope_freqs_cos'] = freqs_cos.detach().cpu()
+ self.intermediate_outputs['rope_freqs_sin'] = freqs_sin.detach().cpu()
+
+ # Hook for condition embedder
+ def condition_embedder_hook(module, input, output):
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = output
+ self.intermediate_outputs['condition_temb'] = temb.detach().cpu()
+ self.intermediate_outputs['condition_timestep_proj'] = timestep_proj.detach().cpu()
+ self.intermediate_outputs['condition_encoder_hidden_states'] = encoder_hidden_states.detach().cpu()
+ if encoder_hidden_states_image is not None:
+ self.intermediate_outputs['condition_encoder_hidden_states_image'] = encoder_hidden_states_image.detach().cpu()
+
+ # Hook for each transformer block
+ for i, block in enumerate(self.model.blocks):
+ def make_block_hook(block_idx):
+ def block_hook(module, input, output):
+ self.intermediate_outputs[f'block_{block_idx}_output'] = output.detach().cpu()
+ return block_hook
+
+ # Hook for block output
+ handle = block.register_forward_hook(make_block_hook(i))
+ self.hooks.append(handle)
+
+ # Hook for self-attention in each block
+ def make_attn1_hook(block_idx):
+ def attn1_hook(module, input, output):
+ self.intermediate_outputs[f'block_{block_idx}_attn1_output'] = output.detach().cpu()
+ return attn1_hook
+
+ handle = block.attn1.register_forward_hook(make_attn1_hook(i))
+ self.hooks.append(handle)
+
+ # Hook for cross-attention in each block
+ def make_attn2_hook(block_idx):
+ def attn2_hook(module, input, output):
+ self.intermediate_outputs[f'block_{block_idx}_attn2_output'] = output.detach().cpu()
+ return attn2_hook
+
+ def make_hook(name):
+ def hook(module, input, output):
+ if isinstance(output, tuple):
+ self.intermediate_outputs[name] = output[0].detach().cpu()
+ else:
+ self.intermediate_outputs[name] = output.detach().cpu()
+ return hook
+
+ handle = block.attn2.register_forward_hook(make_attn2_hook(i))
+ self.hooks.append(handle)
+ handle = block.norm2.register_forward_hook(make_hook(f'block_{i}_norm2_output'))
+ self.hooks.append(handle)
+
+ attn = block.attn2
+ # 1. Hook Q projection
+ h = attn.to_q.register_forward_hook(
+ make_hook(f'block_{i}_attn2_query')
+ )
+ self.hooks.append(h)
+
+ # 2. Hook K projection
+ h = attn.to_k.register_forward_hook(
+ make_hook(f'block_{i}_attn2_key')
+ )
+ self.hooks.append(h)
+
+ # 3. Hook V projection
+ h = attn.to_v.register_forward_hook(
+ make_hook(f'block_{i}_attn2_value')
+ )
+ self.hooks.append(h)
+
+ # 4. Hook Q norm
+ h = attn.norm_q.register_forward_hook(
+ make_hook(f'block_{i}_attn2_query_normed')
+ )
+ self.hooks.append(h)
+
+ # 5. Hook K norm
+ h = attn.norm_k.register_forward_hook(
+ make_hook(f'block_{i}_attn2_key_normed')
+ )
+ self.hooks.append(h)
+
+ # 6. Hook output projection
+ h = attn.to_out[0].register_forward_hook(
+ make_hook(f'block_{i}_attn2_output')
+ )
+ self.hooks.append(h)
+ h = attn.register_forward_hook(make_hook(f'block_{i}_attn2_attention'))
+ self.hooks.append(h)
+
+
+ # Hook for FFN in each block
+ def make_ffn_hook(block_idx):
+ def ffn_hook(module, input, output):
+ self.intermediate_outputs[f'block_{block_idx}_ffn_output'] = output.detach().cpu()
+ return ffn_hook
+
+ handle = block.ffn.register_forward_hook(make_ffn_hook(i))
+ self.hooks.append(handle)
+
+ # Hook for patch embedding
+ handle = self.model.patch_embedding.register_forward_hook(patch_embed_hook)
+ self.hooks.append(handle)
+
+ # Hook for rope
+ handle = self.model.rope.register_forward_hook(rope_hook)
+ self.hooks.append(handle)
+
+ # Hook for condition embedder
+ handle = self.model.condition_embedder.register_forward_hook(condition_embedder_hook)
+ self.hooks.append(handle)
+
+ # Hook for final norm
+ def norm_out_hook(module, input, output):
+ self.intermediate_outputs['norm_out_output'] = output.detach().cpu()
+
+ handle = self.model.norm_out.register_forward_hook(norm_out_hook)
+ self.hooks.append(handle)
+
+ # Hook for final projection
+ def proj_out_hook(module, input, output):
+ self.intermediate_outputs['proj_out_output'] = output.detach().cpu()
+
+ handle = self.model.proj_out.register_forward_hook(proj_out_hook)
+ self.hooks.append(handle)
+
+ def remove_hooks(self):
+ """Remove all registered hooks"""
+ for hook in self.hooks:
+ hook.remove()
+ self.hooks = []
+
+ def get_outputs(self):
+ """Get all captured intermediate outputs"""
+ return self.intermediate_outputs
+
+ def clear_outputs(self):
+ """Clear stored outputs"""
+ self.intermediate_outputs = OrderedDict()
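+
+
+# Example usage (a minimal sketch; `transformer` stands for a loaded diffusers
+# WanTransformer3DModel and the argument names in the forward call are illustrative,
+# not an exact signature):
+#
+#   debugger = WanTransformerDebugger(transformer)
+#   debugger.register_hooks()
+#   with torch.no_grad():
+#       _ = transformer(hidden_states=latents, timestep=t,
+#                       encoder_hidden_states=text_embeds, return_dict=False)
+#   intermediates = debugger.get_outputs()
+#   debugger.remove_hooks()
+
+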
+class WanAttentionDebugger:
+ """Capture internal attention states (Q, K, V, attention scores, etc.)"""
+
+ def __init__(self, model):
+ self.model = model
+ self.attention_states = OrderedDict()
+ self.hooks = []
+ self.original_processors = {}
+
+ def register_attention_hooks(self, block_indices=None):
+ """
+ Register hooks to capture attention internal states.
+
+ Args:
+ block_indices: List of block indices to hook, or None for all blocks
+ """
+ if block_indices is None:
+ block_indices = range(len(self.model.blocks))
+
+ for i in block_indices:
+ block = self.model.blocks[i]
+
+ # Hook self-attention (attn1)
+ self._hook_attention_module(block.attn1, f'block_{i}_attn1')
+
+ # Hook cross-attention (attn2)
+ self._hook_attention_module(block.attn2, f'block_{i}_attn2')
+
+ def _hook_attention_module(self, attn_module, prefix):
+ """Hook a single attention module to capture Q, K, V, and attention outputs"""
+
+ # Save original processor
+ self.original_processors[prefix] = attn_module.processor
+
+ # Create custom processor that captures intermediates
+ original_processor = attn_module.processor
+ attention_states = self.attention_states
+
+ class InstrumentedProcessor:
+ """Wrapper processor that captures intermediate values"""
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ rotary_emb=None,
+ **kwargs
+ ):
+ # Get encoder hidden states
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None and encoder_hidden_states is not None:
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+ # 1. Capture Q, K, V projections (before normalization)
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ attention_states[f'{prefix}_query_raw'] = query.detach().cpu()
+ attention_states[f'{prefix}_key_raw'] = key.detach().cpu()
+ attention_states[f'{prefix}_value_raw'] = value.detach().cpu()
+
+ # 2. Capture after normalization
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ attention_states[f'{prefix}_query_normed'] = query.detach().cpu()
+ attention_states[f'{prefix}_key_normed'] = key.detach().cpu()
+
+ # 3. Reshape to heads
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ attention_states[f'{prefix}_query_heads'] = query.detach().cpu()
+ attention_states[f'{prefix}_key_heads'] = key.detach().cpu()
+ attention_states[f'{prefix}_value_heads'] = value.detach().cpu()
+
+ # 4. Capture after RoPE (if applied)
+ if rotary_emb is not None:
+ def apply_rotary_emb(hidden_states, freqs_cos, freqs_sin):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
+
+ attention_states[f'{prefix}_query_rope'] = query.detach().cpu()
+ attention_states[f'{prefix}_key_rope'] = key.detach().cpu()
+
+ # 5. Handle I2V additional K, V
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+ key_img = attn.norm_added_k(key_img)
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
+
+ attention_states[f'{prefix}_key_img'] = key_img.detach().cpu()
+ attention_states[f'{prefix}_value_img'] = value_img.detach().cpu()
+
+ # Compute image attention (for I2V)
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+ hidden_states_img = dispatch_attention_fn(
+ query, key_img, value_img,
+ attn_mask=None, dropout_p=0.0, is_causal=False,
+ backend=original_processor._attention_backend,
+ )
+ hidden_states_img = hidden_states_img.flatten(2, 3)
+
+ attention_states[f'{prefix}_img_attention_output'] = hidden_states_img.detach().cpu()
+
+ # 6. Compute main attention
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+ # Note: We can't easily capture attention weights with dispatch_attention_fn
+ # because it uses optimized kernels (flash attention, etc.)
+ # For debugging attention weights, we'd need to use manual computation
+
+ hidden_states = dispatch_attention_fn(
+ query, key, value,
+ attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
+ backend=original_processor._attention_backend,
+ )
+
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ attention_states[f'{prefix}_attention_output'] = hidden_states.detach().cpu()
+
+ # 7. Combine with image attention if present
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+ attention_states[f'{prefix}_combined_output'] = hidden_states.detach().cpu()
+
+ # 8. Output projection
+ hidden_states = attn.to_out[0](hidden_states)
+ attention_states[f'{prefix}_output_proj'] = hidden_states.detach().cpu()
+
+ hidden_states = attn.to_out[1](hidden_states) # Dropout
+ attention_states[f'{prefix}_final_output'] = hidden_states.detach().cpu()
+
+ return hidden_states
+
+ # Replace processor
+ attn_module.set_processor(InstrumentedProcessor())
+
+ def register_attention_weight_hooks(self, block_indices=None):
+ """
+ Capture actual attention weights (scores).
+ Warning: This uses manual attention computation, not optimized kernels.
+ """
+ if block_indices is None:
+ block_indices = range(len(self.model.blocks))
+
+ for i in block_indices:
+ block = self.model.blocks[i]
+ self._hook_attention_with_weights(block.attn1, f'block_{i}_attn1')
+ self._hook_attention_with_weights(block.attn2, f'block_{i}_attn2')
+
+ def _hook_attention_with_weights(self, attn_module, prefix):
+ """Hook that computes attention manually to capture weights"""
+
+ attention_states = self.attention_states
+
+ class WeightCapturingProcessor:
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None,
+ attention_mask=None, rotary_emb=None, **kwargs):
+
+ # [Same Q, K, V projection code as above...]
+ # ... (omitted for brevity)
+
+ # Manual attention computation
+ import math
+
+ # Scaled dot-product attention
+ scale = 1.0 / math.sqrt(query.shape[-1])
+
+ # (B, seq_q, heads, head_dim) @ (B, seq_k, heads, head_dim).T
+ # -> (B, heads, seq_q, seq_k)
+ attn_weights = torch.einsum('bqhd,bkhd->bhqk', query, key) * scale
+
+ attention_states[f'{prefix}_attention_scores'] = attn_weights.detach().cpu()
+
+ # Apply mask if present
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ # Softmax
+ attn_weights = torch.softmax(attn_weights, dim=-1)
+
+ attention_states[f'{prefix}_attention_weights'] = attn_weights.detach().cpu()
+
+ # Apply attention to values
+ # (B, heads, seq_q, seq_k) @ (B, seq_k, heads, head_dim)
+ hidden_states = torch.einsum('bhqk,bkhd->bqhd', attn_weights, value)
+
+ # Continue with output projection...
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+ attn_module.set_processor(WeightCapturingProcessor())
+
+ def restore_processors(self):
+ """Restore original attention processors"""
+ for prefix, original_processor in self.original_processors.items():
+ # Parse prefix to get module
+ parts = prefix.split('_')
+ block_idx = int(parts[1])
+ attn_name = parts[2]
+
+ if attn_name == 'attn1':
+ self.model.blocks[block_idx].attn1.set_processor(original_processor)
+ elif attn_name == 'attn2':
+ self.model.blocks[block_idx].attn2.set_processor(original_processor)
+
+ def get_attention_states(self):
+ """Get all captured attention states"""
+ return self.attention_states
+
+ def clear_states(self):
+ """Clear captured states"""
+ self.attention_states = OrderedDict()
+
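+# Example usage (a minimal sketch; the forward call and its argument names are
+# illustrative placeholders):
+#
+#   attn_debugger = WanAttentionDebugger(transformer)
+#   attn_debugger.register_attention_hooks(block_indices=[0])   # instrument block 0 only
+#   with torch.no_grad():
+#       _ = transformer(hidden_states=latents, timestep=t,
+#                       encoder_hidden_states=text_embeds, return_dict=False)
+#   states = attn_debugger.get_attention_states()
+#   attn_debugger.restore_processors()
+
+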
+if __name__ == "__main__":
+ # test_dit_output()
+ test_dit()
diff --git a/bonsai/models/wan2/tests/test_umt5_output.py b/bonsai/models/wan2/tests/test_umt5_output.py
new file mode 100644
index 00000000..bcf5991f
--- /dev/null
+++ b/bonsai/models/wan2/tests/test_umt5_output.py
@@ -0,0 +1,394 @@
+"""Test output correctness by comparing JAX implementation with HuggingFace reference."""
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from huggingface_hub import snapshot_download
+from bonsai.models.wan2 import params, umt5
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel, UMT5ForConditionalGeneration
+import os
+
+def check_weight_loading(jax_model, torch_model):
+ torch_emb = torch_model.shared.weight.detach().cpu().numpy()
+ jax_emb = np.array(jax_model.encoder.token_embedding.embedding.value)
+
+ print("Embedding weights:")
+ print(f" Shapes: torch={torch_emb.shape}, jax={jax_emb.shape}")
+ print(f" Max diff: {np.abs(torch_emb - jax_emb).max():.2e}")
+ print(f" Mean diff: {np.abs(torch_emb - jax_emb).mean():.2e}")
+
+ torch_q = torch_model.encoder.block[0].layer[0].SelfAttention.q.weight.detach().cpu().numpy()
+ jax_q = np.array(jax_model.encoder.blocks[0].attn.q.kernel.value)
+
+ print("\nFirst block query weight:")
+ print(f" Shapes: torch={torch_q.shape}, jax={jax_q.shape}")
+ print(f" Max diff: {np.abs(torch_q.T - jax_q).max():.2e}")
+ print(f" Mean diff: {np.abs(torch_q.T - jax_q).mean():.2e}")
+
+ torch_ln_weight = torch_model.encoder.final_layer_norm.weight.detach().cpu().numpy()
+ jax_ln_weight = np.array(jax_model.encoder.norm.weight.value)
+
+ print("\nFinal LayerNorm weight:")
+ print(f" Shapes: torch={torch_ln_weight.shape}, jax={jax_ln_weight.shape}")
+ print(f" Max diff: {np.abs(torch_ln_weight - jax_ln_weight).max():.2e}")
+ print(f" Mean diff: {np.abs(torch_ln_weight - jax_ln_weight).mean():.2e}")
+
+def compare_outputs(jax_output: jax.Array, torch_output, name: str, rtol: float = 1e-2, atol: float = 1e-4):
+ if torch_output.dtype == torch.bfloat16:
+ torch_output = torch_output.float()
+
+ # Convert PyTorch to numpy
+ if isinstance(torch_output, torch.Tensor):
+ torch_np = torch_output.detach().cpu().numpy()
+ else:
+ torch_np = np.array(torch_output)
+
+ # Convert JAX to numpy
+ jax_np = np.array(jax_output)
+
+ print(f"\n{'=' * 80}")
+ print(f"Comparing: {name}")
+ print(f"{'=' * 80}")
+ print(f"JAX shape: {jax_np.shape}")
+ print(f"Torch shape: {torch_np.shape}")
+ print(f"JAX dtype: {jax_np.dtype}")
+ print(f"Torch dtype: {torch_np.dtype}")
+
+ # Check shapes match
+ if jax_np.shape != torch_np.shape:
+ print("Shape mismatch!")
+ return False
+
+ # Compute differences
+ abs_diff = np.abs(jax_np - torch_np)
+ rel_diff = abs_diff / (np.abs(torch_np) + 1e-10)
+
+ max_abs_diff = np.max(abs_diff)
+ max_rel_diff = np.max(rel_diff)
+ mean_abs_diff = np.mean(abs_diff)
+ mean_rel_diff = np.mean(rel_diff)
+
+ print("\nStatistics:")
+ print(f" Max absolute difference: {max_abs_diff:.2e}")
+ print(f" Max relative difference: {max_rel_diff:.2e}")
+ print(f" Mean absolute difference: {mean_abs_diff:.2e}")
+ print(f" Mean relative difference: {mean_rel_diff:.2e}")
+
+ print(f"\nJAX output range: [{np.min(jax_np):.4f}, {np.max(jax_np):.4f}]")
+ print(f"Torch output range: [{np.min(torch_np):.4f}, {np.max(torch_np):.4f}]")
+
+ # Check if within tolerance
+ close = np.allclose(jax_np, torch_np, rtol=rtol, atol=atol)
+
+ if close:
+        print(f"\n✅ Outputs match within tolerance (rtol={rtol}, atol={atol})")
+    else:
+        print(f"\n❌ Outputs do NOT match (rtol={rtol}, atol={atol})")
+ # Show some mismatched locations
+ mismatch_mask = ~np.isclose(jax_np, torch_np, rtol=rtol, atol=atol)
+ n_mismatches = np.sum(mismatch_mask)
+ print(f" Number of mismatches: {n_mismatches} / {jax_np.size} ({100 * n_mismatches / jax_np.size:.2f}%)")
+
+ return close
+
+def test_t5_encoder():
+ """Test UMT5 encoder output against Wan UMT5 reference implementation."""
+ print("\n" + "=" * 80)
+ print("TEST 1: UMT5 Encoder (UMT5-XXL)")
+ print("=" * 80)
+ # Download checkpoint
+ model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+
+ # Test prompt
+ prompt = "A beautiful sunset over the ocean with waves crashing on the shore"
+ max_length = 512
+
+ print(f"\nTest prompt: {prompt}")
+ print(f"Max length: {max_length}")
+
+ print("\nTokenizing...")
+ tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+ inputs_j = tokenizer(prompt, return_tensors="np")
+ inputs_p = tokenizer(prompt, return_tensors="pt")
+
+ print("\n[1/2] Loading UMT5 encoder...")
+ jax_t5 = params.create_t5_encoder_from_safe_tensors(model_ckpt_path, mesh=None)
+ hf_t5 = UMT5EncoderModel.from_pretrained(model_ckpt_path, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+
+ check_weight_loading(jax_t5, hf_t5)
+
+ print("\nRunning model...")
+ input_ids_jax = jnp.array(inputs_j.input_ids)
+ jax_output = jax_t5(input_ids_jax, deterministic=True)
+
+ pytorch_output = hf_t5(inputs_p.input_ids)
+ torch_embeddings = pytorch_output.last_hidden_state
+
+    # Compare the full encoder output (no padding was applied above)
+ return compare_outputs(jax_output, torch_embeddings, "UMT5 Encoder Output", rtol=1e-3, atol=1e-4)
+
+
+def test_t5_intermediate():
+ """Compare intermediate layer outputs between JAX and PyTorch UMT5 encoder."""
+ print("\n" + "=" * 80)
+ print("TEST 2: UMT5 Encoder Intermediate Outputs")
+ print("=" * 80)
+
+ # Download checkpoint
+ model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+
+ # Test prompt
+ prompt = "A beautiful sunset over the ocean with waves crashing on the shore"
+
+ print(f"\nTest prompt: {prompt}")
+
+ print("\nTokenizing...")
+ tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+ inputs_j = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=512, truncation=True)
+ inputs_p = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=512, truncation=True)
+
+ print("\n[1/3] Loading models...")
+ jax_t5 = params.create_t5_encoder_from_safe_tensors(model_ckpt_path, mesh=None)
+ pytorch_t5 = UMT5EncoderModel.from_pretrained(
+ model_ckpt_path,
+ subfolder="text_encoder",
+ torch_dtype=torch.float32
+ )
+ pytorch_t5.eval()
+
+ # Register hooks to capture PyTorch intermediate outputs
+ print("\n[2/3] Capturing PyTorch intermediate outputs...")
+ pytorch_intermediates = {}
+
+ def make_hook(name):
+ def hook(module, input, output):
+ if isinstance(output, tuple):
+ pytorch_intermediates[name] = output[0].detach().cpu()
+ else:
+ pytorch_intermediates[name] = output.detach().cpu()
+ return hook
+
+ # Hook embedding
+ pytorch_t5.encoder.embed_tokens.register_forward_hook(make_hook("embeddings"))
+
+ # Hook each block with detailed attention captures
+ for i, block in enumerate(pytorch_t5.encoder.block):
+ block.register_forward_hook(make_hook(f"block_{i}_output"))
+ block.layer[0].register_forward_hook(make_hook(f"block_{i}_attn_output"))
+
+ # Hook attention Q, K, V projections
+ block.layer[0].layer_norm.register_forward_hook(make_hook(f"block_{i}_attn_norm"))
+ attn = block.layer[0].SelfAttention
+ attn.q.register_forward_hook(make_hook(f"block_{i}_q_proj"))
+ attn.k.register_forward_hook(make_hook(f"block_{i}_k_proj"))
+ attn.v.register_forward_hook(make_hook(f"block_{i}_v_proj"))
+
+ if len(block.layer) > 1:
+ block.layer[1].register_forward_hook(make_hook(f"block_{i}_ffn_output"))
+
+ # Hook final norm
+ pytorch_t5.encoder.final_layer_norm.register_forward_hook(make_hook("final_norm"))
+
+ # Run PyTorch forward
+ with torch.no_grad():
+ pytorch_output = pytorch_t5(inputs_p.input_ids)
+
+ print("\n[3/3] Comparing intermediate outputs...")
+
+ # Convert inputs to JAX
+ input_ids_jax = jnp.array(inputs_j.input_ids)
+
+ # Manual forward pass for JAX to get intermediates
+ # 1. Embeddings
+ x_jax = jax_t5.encoder.token_embedding(input_ids_jax)
+ embeddings_torch = pytorch_intermediates["embeddings"]
+ compare_outputs(x_jax, embeddings_torch, "Token Embeddings", rtol=1e-4, atol=1e-6)
+
+ # 2. Dropout (not comparing, just applying)
+ x_jax = jax_t5.encoder.dropout(x_jax, deterministic=True)
+
+ # 3. Position bias setup (only once)
+ batch_size, seq_len = input_ids_jax.shape
+ position_bias_jax = None
+
+ # 4. Process through blocks
+ num_layers = len(jax_t5.encoder.blocks)
+ print(f"\nComparing {num_layers} transformer blocks...")
+
+ for i in range(min(3, num_layers)): # Compare first 3 blocks
+ print(f"\n{'='*80}")
+ print(f"BLOCK {i}")
+ print(f"{'='*80}")
+
+ block = jax_t5.encoder.blocks[i]
+
+ # Manual attention computation to capture Q, K, V
+ print(f"\n--- Attention Details ---")
+
+ # Norm
+ normed_x_jax = block.norm1(x_jax)
+
+ compare_outputs(normed_x_jax, pytorch_intermediates[f"block_{i}_attn_norm"], f"Block {i} Attention Norm", rtol=1e-5, atol=1e-6)
+
+ # Q, K, V projections
+ q_jax = block.attn.q(normed_x_jax)
+ k_jax = block.attn.k(normed_x_jax)
+ v_jax = block.attn.v(normed_x_jax)
+
+ # Compare with PyTorch
+ q_torch = pytorch_intermediates[f"block_{i}_q_proj"]
+ k_torch = pytorch_intermediates[f"block_{i}_k_proj"]
+ v_torch = pytorch_intermediates[f"block_{i}_v_proj"]
+
+ compare_outputs(q_jax, q_torch, f"Block {i} Q after Linear", rtol=1e-5, atol=1e-6)
+ compare_outputs(k_jax, k_torch, f"Block {i} K after Linear", rtol=1e-5, atol=1e-6)
+ compare_outputs(v_jax, v_torch, f"Block {i} V after Linear", rtol=1e-5, atol=1e-6)
+
+ # Now run full attention (for comparison)
+ if position_bias_jax is None:
+ attn_output = block.attn(
+ x_jax,
+ mask=None,
+ pos_bias=None,
+ deterministic=True
+ )
+ # Compare position bias (only computed in first block)
+ if f"block_{i}_position_bias" in pytorch_intermediates:
+ pos_bias_torch = pytorch_intermediates[f"block_{i}_position_bias"]
+ compare_outputs(position_bias_jax, pos_bias_torch, f"Block {i} Position Bias", rtol=1e-5, atol=1e-6)
+ else:
+ attn_output = block.attn(
+ x_jax,
+ mask=None,
+ pos_bias=position_bias_jax,
+ deterministic=True
+ )
+
+ # PyTorch attention output (after layer[0] which includes norm + attn + dropout)
+ attn_torch = pytorch_intermediates[f"block_{i}_attn_output"]
+ compare_outputs(attn_output, attn_torch, f"Block {i} Attention Output", rtol=1e-3, atol=1e-5)
+
+ x_jax = attn_output
+
+ # FFN
+ if hasattr(block, 'ffn'):
+ ffn_output = block.ffn(x_jax, deterministic=True)
+ x_jax = ffn_output
+
+ if f"block_{i}_ffn_output" in pytorch_intermediates:
+ ffn_torch = pytorch_intermediates[f"block_{i}_ffn_output"]
+ compare_outputs(ffn_output, ffn_torch, f"Block {i} FFN Output", rtol=1e-3, atol=1e-5)
+
+ # Final block output
+ block_torch = pytorch_intermediates[f"block_{i}_output"]
+ compare_outputs(x_jax, block_torch, f"Block {i} Final Output", rtol=1e-3, atol=1e-5)
+
+ # 5. Final layer norm
+ x_jax = jax_t5.encoder.norm(x_jax)
+ final_norm_torch = pytorch_intermediates["final_norm"]
+ compare_outputs(x_jax, final_norm_torch, "Final Layer Norm", rtol=1e-4, atol=1e-6)
+
+ # 6. Final output
+ torch_final = pytorch_output.last_hidden_state
+ compare_outputs(x_jax, torch_final, "Final Output", rtol=1e-3, atol=1e-4)
+
+ print("\n" + "="*80)
+ print("Intermediate comparison complete!")
+ print("="*80)
+
+def test_t5_e2e():
+ """Test JAX UMT5 encoder with PyTorch decoder on end-to-end generation task."""
+ print("\n" + "=" * 80)
+    print("TEST 3: UMT5 E2E (JAX Encoder + PyTorch Decoder)")
+ print("=" * 80)
+ # Test prompts
+ test_prompts = [
+ "Studies have shown that good for you",
+ ]
+
+ print("\n[1/3] Loading models...")
+ tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+ model_ckpt_path = snapshot_download("google/umt5-xxl")
+
+ # Load JAX encoder
+    jax_t5 = params.create_t5_encoder_from_safe_tensors(model_ckpt_path, mesh=None, is_sf=False, config=umt5.T5Config.umt5_xxl())
+
+ # Load full PyTorch model (encoder + decoder)
+ pytorch_full_model = UMT5ForConditionalGeneration.from_pretrained(
+ model_ckpt_path,
+ torch_dtype=torch.float32
+ )
+ pytorch_full_model.eval()
+
+ print("\n[2/3] Running generation tests...")
+
+ for i, prompt in enumerate(test_prompts):
+ print(f"\n{'='*80}")
+ print(f"Test Case {i+1}: {prompt}")
+ print(f"{'='*80}")
+
+ # Tokenize
+ inputs_jax = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=512, truncation=True)
+ inputs_torch = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=512, truncation=True)
+
+ # ============================================================
+ # Baseline: Full PyTorch model
+ # ============================================================
+ print("\n[Baseline] Full PyTorch model:")
+ with torch.no_grad():
+ pytorch_outputs = pytorch_full_model.generate(
+ input_ids=inputs_torch.input_ids,
+ attention_mask=inputs_torch.attention_mask,
+ max_length=50,
+ num_beams=1, # Greedy decoding
+ do_sample=False,
+ )
+
+ pytorch_text = tokenizer.decode(pytorch_outputs[0])
+ print(f" Output: {pytorch_text}")
+
+ # ============================================================
+ # Hybrid: JAX encoder + PyTorch decoder
+ # ============================================================
+ print("\n[Hybrid] JAX encoder + PyTorch decoder:")
+
+ # Get encoder hidden states from JAX
+ input_ids_jax = jnp.array(inputs_jax.input_ids)
+ jax_encoder_output = jax_t5(input_ids_jax, deterministic=True)
+
+ # Convert to PyTorch
+ encoder_hidden_states = torch.from_numpy(np.array(jax_encoder_output))
+
+ print(f" JAX encoder output shape: {encoder_hidden_states.shape}")
+ print(f" JAX encoder output range: [{encoder_hidden_states.min():.4f}, {encoder_hidden_states.max():.4f}]")
+
+ # Create encoder outputs object for decoder
+ from transformers.modeling_outputs import BaseModelOutput
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_hidden_states
+ )
+
+ # Generate using decoder with JAX encoder outputs
+ with torch.no_grad():
+ hybrid_outputs = pytorch_full_model.generate(
+ encoder_outputs=encoder_outputs,
+ attention_mask=inputs_torch.attention_mask,
+ max_length=50,
+ num_beams=1,
+ do_sample=False,
+ )
+
+ hybrid_text = tokenizer.decode(hybrid_outputs[0])
+ print(f" Output: {hybrid_text}")
+
+ print("\n[3/3] Summary")
+ print("=" * 80)
+ print("E2E test complete. Check outputs above to verify encoder correctness.")
+
+
+if __name__ == "__main__":
+    # Comment/uncomment the tests you want to run:
+ test_t5_encoder() # Test final outputs only
+ # test_t5_intermediate() # Test intermediate layer outputs (detailed)
+ test_t5_e2e() # End-to-end generation test
diff --git a/bonsai/models/wan2/tests/test_vae_output.py b/bonsai/models/wan2/tests/test_vae_output.py
new file mode 100644
index 00000000..2ef0333f
--- /dev/null
+++ b/bonsai/models/wan2/tests/test_vae_output.py
@@ -0,0 +1,426 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+from modelscope import snapshot_download as ms_snapshot_download
+from huggingface_hub import snapshot_download as hf_snapshot_download
+from bonsai.models.wan2 import params
+from bonsai.models.wan2 import vae_wan as vae_lib
+import torch
+from diffusers import AutoencoderKLWan
+from jax.lax import Precision
+from collections import OrderedDict
+from flax import nnx
+import sys
+
+class WanVAEDecoderHooks:
+ """Extract intermediate outputs from Wan VAE Decoder using forward hooks"""
+
+ def __init__(self, vae):
+ self.vae = vae
+ self.decoder = vae.decoder
+ self.outputs = OrderedDict()
+ self.hooks = []
+ self.capture_enabled = True # Control when to capture
+
+ def register_decoder_hooks(self):
+ """Register hooks on all decoder layers"""
+
+ # 1. Hook post_quant_conv (before decoder)
+ h = self.vae.post_quant_conv.register_forward_hook(
+ self._make_conditional_hook('post_quant_conv')
+ )
+ self.hooks.append(h)
+
+ # 2. Hook conv_in
+ h = self.decoder.conv_in.register_forward_hook(
+ self._make_conditional_hook('conv_in')
+ )
+ self.hooks.append(h)
+
+ # 3. Hook mid_block
+ h = self.decoder.mid_block.register_forward_hook(
+ self._make_conditional_hook('mid_block')
+ )
+ self.hooks.append(h)
+
+ # 4. Hook each mid_block residual block
+ if hasattr(self.decoder.mid_block, 'resnets'):
+ for i, res_block in enumerate(self.decoder.mid_block.resnets):
+ h = res_block.register_forward_hook(
+ self._make_conditional_hook(f'mid_block_res_{i}')
+ )
+ self.hooks.append(h)
+ if hasattr(self.decoder.mid_block, 'attentions'):
+ for i, attn_block in enumerate(self.decoder.mid_block.attentions):
+ h = attn_block.register_forward_hook(
+ self._make_conditional_hook(f'mid_block_attn_{i}')
+ )
+ self.hooks.append(h)
+
+ # 5. Hook each up_block
+ for i, up_block in enumerate(self.decoder.up_blocks):
+ h = up_block.register_forward_hook(
+ self._make_conditional_hook(f'up_block_{i}')
+ )
+ self.hooks.append(h)
+
+ # Hook residual blocks within up_block
+ if hasattr(up_block, 'resnets'):
+ for j, res_block in enumerate(up_block.resnets):
+ h = res_block.register_forward_hook(
+ self._make_conditional_hook(f'up_block_{i}_res_{j}')
+ )
+ self.hooks.append(h)
+
+ # Hook upsample layers
+ if hasattr(up_block, 'upsamplers') and up_block.upsamplers is not None:
+ h = up_block.upsamplers[0].register_forward_hook(
+ self._make_conditional_hook(f'up_block_{i}_upsample')
+ )
+ self.hooks.append(h)
+
+ # 6. Hook norm_out
+ h = self.decoder.norm_out.register_forward_hook(
+ self._make_conditional_hook('norm_out')
+ )
+ self.hooks.append(h)
+
+ # 7. Hook nonlinearity (after norm_out)
+ # We need to hook this differently since it's a function, not a module
+ # We'll hook conv_out input instead
+
+ # 8. Hook conv_out (final output)
+ h = self.decoder.conv_out.register_forward_hook(
+ self._make_conditional_hook('conv_out')
+ )
+ self.hooks.append(h)
+
+ def _make_conditional_hook(self, name):
+ """Create hook that only captures when enabled"""
+ def hook(module, input, output):
+ if self.capture_enabled:
+ self.outputs[name] = output.detach().cpu()
+ return hook
+
+
+ def decode_first_frame_only(self, latents):
+ """Decode only first frame with hooks enabled"""
+
+ # Extract first frame
+ z_first = latents[:, :, 0:1, :, :] # (B, C, 1, H, W)
+
+ # Enable capture
+ self.capture_enabled = True
+
+ # Run through decoder directly (bypass frame loop)
+ with torch.no_grad():
+ x = self.vae.post_quant_conv(z_first)
+ out = self.decoder(
+ x,
+ feat_cache=None, # No cache for single frame
+ feat_idx=[0],
+ first_chunk=True
+ )
+
+ # Disable capture
+ self.capture_enabled = False
+
+ return out
+
+
+ def _make_hook(self, name):
+ """Create a hook function with closure over name"""
+ def hook(module, input, output):
+ self.outputs[name] = output.detach().cpu()
+ return hook
+
+ def remove_hooks(self):
+ """Remove all registered hooks"""
+ for h in self.hooks:
+ h.remove()
+ self.hooks = []
+
+ def clear_outputs(self):
+ """Clear captured outputs"""
+ self.outputs = OrderedDict()
+
+ def get_outputs(self):
+ """Get all captured outputs"""
+ return self.outputs
+
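+# Example usage (a minimal sketch; `vae` is a loaded diffusers AutoencoderKLWan and
+# `latents` a [B, C, T, H, W] latent tensor):
+#
+#   hooks = WanVAEDecoderHooks(vae)
+#   hooks.register_decoder_hooks()
+#   first_frame = hooks.decode_first_frame_only(latents)
+#   intermediates = hooks.get_outputs()
+#   hooks.remove_hooks()
+
+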
+def test_vae_decoder_output(src: str = "hf"):
+ # Load VAE model
+ print("Loading AutoencoderKLWan...")
+ if src == "ms":
+ model_ckpt_path = ms_snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers",allow_patterns='vae/*')
+ elif src == "hf":
+ model_ckpt_path = hf_snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers",allow_patterns='vae/*')
+
+ vae_jax = params.create_vae_decoder_from_safe_tensors(model_ckpt_path, mesh=None)
+ vae = AutoencoderKLWan.from_pretrained(
+ model_ckpt_path,
+ subfolder="vae",
+ torch_dtype=torch.float32
+ )
+ vae.eval()
+
+ # Register hooks
+ hook_manager = WanVAEDecoderHooks(vae)
+ hook_manager.register_decoder_hooks()
+
+ # Create dummy latent input
+ torch.manual_seed(42)
+ batch_size = 1
+ z_dim = 16
+ num_frames = 9
+ height = 30 # After spatial compression (8x)
+ width = 52
+
+ latents_mean = (
+ torch.tensor(vae_lib.VAEConfig.latent_mean)
+ .view(1, 16, 1, 1, 1)
+ .to(dtype=torch.float32)
+ )
+ latents_std = 1.0 / torch.tensor(vae_lib.VAEConfig.latent_std).view(
+ 1, 16, 1, 1, 1
+ ).to(dtype=torch.float32)
+
+ latents_original = torch.randn(batch_size, z_dim, num_frames, height, width,dtype=torch.float32)
+ latents = latents_original / latents_std + latents_mean
+ latents_jax = jnp.array(latents_original.numpy().transpose(0,2,3,4,1))
+
+ print(f"\nInput latents shape: {latents.shape}")
+ print("Running decoder forward pass...\n")
+
+ decoded_jax = vae_jax.decode(latents_jax)
+ print(decoded_jax[0,1:,:,235:,:].mean())
+ with torch.no_grad():
+ decoded = vae.decode(latents).sample
+ compare_outputs(decoded_jax, decoded, "final_output", rtol=1e-2, atol=1e-4)
+ hook_manager.remove_hooks()
+
+def test_vae_decoder(src:str="hf"):
+ # Load VAE model
+ print("Loading AutoencoderKLWan...")
+ if src == "ms":
+ model_ckpt_path = ms_snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers",allow_patterns='vae/*')
+ elif src == "hf":
+ model_ckpt_path = hf_snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers",allow_patterns='vae/*')
+
+ vae_jax = params.create_vae_decoder_from_safe_tensors(model_ckpt_path, mesh=None)
+ vae = AutoencoderKLWan.from_pretrained(
+ model_ckpt_path,
+ subfolder="vae",
+ torch_dtype=torch.float32
+ )
+ vae.eval()
+
+ # Register hooks
+ hook_manager = WanVAEDecoderHooks(vae)
+ hook_manager.register_decoder_hooks()
+
+ # Create dummy latent input
+ torch.manual_seed(42)
+ batch_size = 1
+ z_dim = 16
+ num_frames = 9
+ height = 30 # After spatial compression (8x)
+ width = 52
+
+ latents_mean = (
+ torch.tensor(vae_lib.VAEConfig.latent_mean)
+ .view(1, 16, 1, 1, 1)
+ .to(dtype=torch.float32)
+ )
+ latents_std = 1.0 / torch.tensor(vae_lib.VAEConfig.latent_std).view(
+ 1, 16, 1, 1, 1
+ ).to(dtype=torch.float32)
+
+ latents_original = torch.randn(batch_size, z_dim, num_frames, height, width,dtype=torch.float32)
+ latents = latents_original / latents_std + latents_mean
+ latents_jax = jnp.array(latents_original.numpy().transpose(0,2,3,4,1))
+
+ print(f"\nInput latents shape: {latents.shape}")
+ print("Running decoder forward pass...\n")
+
+ print("=" * 80)
+ print("CAPTURED VAE DECODER INTERMEDIATE OUTPUTS")
+ print("=" * 80)
+
+    # Capture the PyTorch reference intermediates for the first latent frame (this is
+    # what the frame-0 comparisons below expect), then run the full decode with capture
+    # disabled so `decoded` holds the complete reference output.
+    _ = hook_manager.decode_first_frame_only(latents)
+    outputs = hook_manager.get_outputs()
+    with torch.no_grad():
+        decoded = vae.decode(latents).sample
+
+    output_jax = {}
+ z, _ = vae_jax.conv2(latents_jax, None)
+ compare_outputs(z, outputs['post_quant_conv'], 'post_quant_conv', rtol=1e-2, atol=1e-4)
+ output_jax['post_quant_conv'] = z
+
+ t = z.shape[1]
+ frames = []
+ decoder = vae_jax.decoder
+
+ # Initialize cache list for feature caching
+ cache_list = [None] * 50
+
+ for i in range(t):
+ print(f"\n{'='*80}")
+ print(f"Processing frame {i+1}/{t}")
+ print(f"{'='*80}")
+
+ # Reset cache index for each frame
+ cache_idx = [0]
+ frame_latent = z[:, i : i + 1, :, :, :]
+
+ if i == 0:
+ idx = cache_idx[0]
+ x, cache_list[idx] = decoder.conv_in(frame_latent, cache_list[idx])
+ cache_idx[0] += 1
+ compare_outputs(x, outputs['conv_in'], 'conv_in', rtol=1e-2, atol=1e-4)
+ output_jax['conv_in'] = x
+
+ # Mid block 1
+ x, cache_list = decoder.mid_block1(x, cache_list, cache_idx)
+ compare_outputs(x, outputs['mid_block_res_0'], 'mid_block_res_0', rtol=1e-2, atol=1e-4)
+ output_jax['mid_block_res_0'] = x
+
+ # Mid attention
+ x = decoder.mid_attn(x)
+ compare_outputs(x, outputs['mid_block_attn_0'], 'mid_block_attn_0', rtol=1e-2, atol=1e-4)
+ output_jax['mid_block_attn_0'] = x
+
+ # Mid block 2
+ x, cache_list = decoder.mid_block2(x, cache_list, cache_idx)
+ compare_outputs(x, outputs['mid_block_res_1'], 'mid_block_res_1', rtol=1e-2, atol=1e-4)
+ output_jax['mid_block_res_1'] = x
+
+ # Upsample stage 0
+ for j, block in enumerate(decoder.up_blocks_0):
+ x, cache_list = block(x, cache_list, cache_idx)
+ compare_outputs(x, outputs[f'up_block_0_res_{j}'], f'up_block_0_res_{j}', rtol=1e-2, atol=1e-4)
+ output_jax[f'up_block_0_res_{j}'] = x
+ x, cache_list = decoder.up_sample_0(x, cache_list, cache_idx)
+ compare_outputs(x, outputs['up_block_0_upsample'], 'up_block_0_upsample', rtol=1e-2, atol=1e-4)
+ output_jax['up_block_0_upsample'] = x
+
+ # Upsample stage 1
+ for j, block in enumerate(decoder.up_blocks_1):
+ x, cache_list = block(x, cache_list, cache_idx)
+ compare_outputs(x, outputs[f'up_block_1_res_{j}'], f'up_block_1_res_{j}', rtol=1e-2, atol=1e-4)
+ output_jax[f'up_block_1_res_{j}'] = x
+ x, cache_list = decoder.up_sample_1(x, cache_list, cache_idx)
+ compare_outputs(x, outputs['up_block_1_upsample'], 'up_block_1_upsample', rtol=1e-2, atol=1e-4)
+ output_jax['up_block_1_upsample'] = x
+
+ # Upsample stage 2 (spatial only, no cache)
+ for j, block in enumerate(decoder.up_blocks_2):
+ x, cache_list = block(x, cache_list, cache_idx)
+ compare_outputs(x, outputs[f'up_block_2_res_{j}'], f'up_block_2_res_{j}', rtol=1e-2, atol=1e-4)
+ output_jax[f'up_block_2_res_{j}'] = x
+ x = decoder.up_sample_2(x) # Spatial-only, no cache
+ compare_outputs(x, outputs['up_block_2_upsample'], 'up_block_2_upsample', rtol=1e-2, atol=1e-4)
+ output_jax['up_block_2_upsample'] = x
+
+ # Upsample stage 3 (no spatial upsample)
+ for j, block in enumerate(decoder.up_blocks_3):
+ x, cache_list = block(x, cache_list, cache_idx)
+ compare_outputs(x, outputs[f'up_block_3_res_{j}'], f'up_block_3_res_{j}', rtol=1e-2, atol=1e-4)
+ output_jax[f'up_block_3_res_{j}'] = x
+
+ x = decoder.norm_out(x)
+ compare_outputs(x, outputs['norm_out'], 'norm_out', rtol=1e-2, atol=1e-4)
+ output_jax['norm_out'] = x
+ x = nnx.silu(x)
+
+ idx = cache_idx[0]
+ x, cache_list[idx] = decoder.conv_out(x, cache_list[idx])
+ cache_idx[0] += 1
+ compare_outputs(x, outputs['conv_out'], 'conv_out', rtol=1e-2, atol=1e-4)
+ output_jax['conv_out'] = x
+ frames.append(x)
+ else:
+ # Subsequent frames: use cached features
+ print(f"Subsequent frame: using feature cache")
+ frame_out, cache_list = decoder(frame_latent, cache_list, cache_idx)
+ frames.append(frame_out)
+
+ print("\n" + "=" * 80)
+ print(f"Final decoded output shape: {decoded.shape}")
+ print("=" * 80)
+
+    # Save outputs for comparison
+    outputs_dict = {
+        'inputs': {
+            'latents': latents.cpu(),
+        },
+        'intermediate': outputs,
+        'intermediate_jax': {k: np.array(v) for k, v in output_jax.items()},
+        'output': decoded.cpu(),
+    }
+
+    torch.save(outputs_dict, 'wan_vae_decoder_outputs.pt')
+    print("\n✅ Saved outputs to 'wan_vae_decoder_outputs.pt'")
+
+ # Clean up
+ hook_manager.remove_hooks()
+
+def compare_outputs(jax_output: jax.Array, torch_output, name: str, rtol: float = 1e-2, atol: float = 1e-4):
+ if torch_output.dtype == torch.bfloat16:
+ torch_output = torch_output.float()
+
+ if isinstance(torch_output, torch.Tensor):
+ torch_np = torch_output.detach().cpu().numpy()
+ else:
+ torch_np = np.array(torch_output)
+
+ jax_np = np.array(jax_output).transpose(0,4,1,2,3) # Convert JAX [B,T,H,W,C] to [B,C,T,H,W]
+
+ print(f"\n{'=' * 80}")
+ print(f"Comparing: {name}")
+ print(f"{'=' * 80}")
+ print(f"JAX shape: {jax_np.shape}")
+ print(f"Torch shape: {torch_np.shape}")
+ print(f"JAX dtype: {jax_np.dtype}")
+ print(f"Torch dtype: {torch_np.dtype}")
+
+ if jax_np.shape != torch_np.shape:
+ print("Shape mismatch!")
+ return False
+
+ abs_diff = np.abs(jax_np - torch_np)
+ rel_diff = abs_diff / (np.abs(torch_np) + 1e-10)
+
+ max_abs_diff = np.max(abs_diff)
+ max_rel_diff = np.max(rel_diff)
+ mean_abs_diff = np.mean(abs_diff)
+ mean_rel_diff = np.mean(rel_diff)
+
+ print("\nStatistics:")
+ print(f" Max absolute difference: {max_abs_diff:.2e}")
+ print(f" Max relative difference: {max_rel_diff:.2e}")
+ print(f" Mean absolute difference: {mean_abs_diff:.2e}")
+ print(f" Mean relative difference: {mean_rel_diff:.2e}")
+
+ print(f"\nJAX output range: [{np.min(jax_np):.4f}, {np.max(jax_np):.4f}]")
+ print(f"Torch output range: [{np.min(torch_np):.4f}, {np.max(torch_np):.4f}]")
+
+ close = np.allclose(jax_np, torch_np, rtol=rtol, atol=atol)
+
+ if close:
+        print(f"\n✅ Outputs match within tolerance (rtol={rtol}, atol={atol})")
+    else:
+        print(f"\n❌ Outputs do NOT match (rtol={rtol}, atol={atol})")
+ # Show some mismatched locations
+ mismatch_mask = ~np.isclose(jax_np, torch_np, rtol=rtol, atol=atol)
+ n_mismatches = np.sum(mismatch_mask)
+ print(f" Number of mismatches: {n_mismatches} / {jax_np.size} ({100 * n_mismatches / jax_np.size:.2f}%)")
+
+ return close
+
+if __name__ == "__main__":
+ # add args for modelscope/huggingface model download
+ src = sys.argv[1] if len(sys.argv) > 1 else "hf"
+    assert src in ["ms", "hf"], "Invalid source specified. Use 'ms' or 'hf'."
+ test_vae_decoder_output(src)
\ No newline at end of file
diff --git a/bonsai/models/wan2/transformer_wan.py b/bonsai/models/wan2/transformer_wan.py
new file mode 100644
index 00000000..8c9e204a
--- /dev/null
+++ b/bonsai/models/wan2/transformer_wan.py
@@ -0,0 +1,461 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Wan2.1-T2V-1.3B: Text-to-Video Diffusion Transformer Model.
+
+This implements the Wan2.1-T2V-1.3B model, a 1.3B parameter diffusion transformer
+for text-to-video generation using Flow Matching framework.
+
+Architecture:
+- 30-layer Diffusion Transformer with 1536 hidden dim
+- 12 attention heads (128 dim each)
+- Vision self-attention + text cross-attention
+- AdaLN modulation conditioned on timestep
+- UMT5 text encoder for multilingual prompts
+- Wan-VAE for video encoding/decoding
+"""
+
+import dataclasses
+import math
+from typing import Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.lax import Precision
+from jaxtyping import Array
+
+from .unipc_multistep_scheduler import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
+
+
+@dataclasses.dataclass(frozen=True)
+class TransformerWanModelConfig:
+ """Configuration for Wan2.1-T2V-1.3B Diffusion Transformer."""
+
+ weights_dtype: jnp.dtype = jnp.bfloat16
+ num_layers: int = 30
+ hidden_dim: int = 1536
+ latent_input_dim: int = 16
+ latent_output_dim: int = 16
+ ffn_dim: int = 8960
+ freq_dim: int = 256
+ num_heads: int = 12
+ head_dim: int = 128
+ text_embed_dim: int = 4096
+ max_text_len: int = 512
+ num_frames: int = 21
+ latent_size: Tuple[int, int] = (30, 30)
+ num_inference_steps: int = 50
+ guidance_scale: float = 5.0
+ patch_size: Tuple[int, int, int] = (1, 2, 2)
+ cross_attn_norm: bool = True
+ qk_norm: Optional[str] = "rms_norm_across_heads"
+ eps: float = 1e-6
+ added_kv_proj_dim: Optional[int] = None # None for T2V, set for I2V
+ rope_max_seq_len: int = 1024
+
+ def __post_init__(self):
+ assert self.hidden_dim == self.num_heads * self.head_dim, "hidden_dim must equal num_heads * head_dim"
+
+
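+# Example (a sketch): the defaults above correspond to the 1.3B T2V variant, e.g.
+#
+#   cfg = TransformerWanModelConfig()                        # 30 layers, hidden_dim 1536
+#   assert cfg.hidden_dim == cfg.num_heads * cfg.head_dim    # 1536 == 12 * 128
+
+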
+class TimestepEmbedding(nnx.Module):
+ """Timestep embedding: sinusoidal -> MLP -> projection for AdaLN."""
+
+ def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs):
+ self.cfg = cfg
+ self.time_embedding = nnx.Sequential(
+ nnx.Linear(cfg.freq_dim, cfg.hidden_dim, rngs=rngs),
+ nnx.silu,
+ nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs),
+ )
+ self.time_projection = nnx.Sequential(
+ nnx.silu,
+ nnx.Linear(cfg.hidden_dim, 6 * cfg.hidden_dim, rngs=rngs),
+ )
+
+ def __call__(self, t: Array) -> tuple[Array, Array]:
+ t_freq = sinusoidal_embedding_1d(t, self.cfg.freq_dim)
+ time_emb = self.time_embedding(t_freq)
+ time_proj = self.time_projection(time_emb)
+ return time_emb, time_proj
+
+
+def sinusoidal_embedding_1d(timesteps: Array, embedding_dim: int, max_period: int = 10000) -> Array:
+ half_dim = embedding_dim // 2
+ freqs = jnp.exp(-math.log(max_period) * jnp.arange(0, half_dim) / half_dim)
+ args = timesteps[:, None] * freqs[None, :]
+ embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
+ if embedding_dim % 2:
+ embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
+ return embedding
+
+
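+# Shape note for the embedding above: timesteps of shape [B] map to embeddings of shape
+# [B, embedding_dim], with cosine terms in the first half and sine terms in the second.
+
+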
+def precompute_freqs_cis_3d(dim: int, theta: float = 10000.0, max_seq_len: int = 1024) -> tuple[Array, Array, Array]:
+ """Precompute 3D RoPE frequencies split as T: dim-4*(dim//6), H: 2*(dim//6), W: 2*(dim//6)."""
+ dim_base = dim // 6
+ dim_t, dim_h, dim_w = dim - 4 * dim_base, 2 * dim_base, 2 * dim_base
+ assert dim_t + dim_h + dim_w == dim
+ return (
+ rope_params(max_seq_len, dim_t, theta),
+ rope_params(max_seq_len, dim_h, theta),
+ rope_params(max_seq_len, dim_w, theta),
+ )
+
+
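+# Worked example of the split above (head_dim = 128, the default config value):
+#   dim_base = 128 // 6 = 21, so dim_t = 128 - 4 * 21 = 44 and dim_h = dim_w = 42;
+#   each returned table has shape (max_seq_len, dim_x // 2, 2) holding (cos, sin) pairs.
+
+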
+def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> Array:
+ freqs = 1.0 / jnp.power(theta, jnp.arange(0, dim, 2, dtype=jnp.float32) / dim)
+ positions = jnp.arange(max_seq_len, dtype=jnp.float32)
+ freqs = jnp.outer(positions, freqs)
+ return jnp.stack([jnp.cos(freqs), jnp.sin(freqs)], axis=-1)
+
+
+def rope_apply(x: Array, grid_sizes: tuple[int, int, int], freqs: tuple[Array, Array, Array]) -> Array:
+ """Apply 3D RoPE to input tensor."""
+ b, seq_len, num_heads, head_dim = x.shape
+ f, h, w = grid_sizes
+ freqs_t, freqs_h, freqs_w = freqs
+
+ dim_base = head_dim // 6
+ dim_t, dim_h, dim_w = head_dim - 4 * dim_base, 2 * dim_base, 2 * dim_base
+
+ freqs_grid = jnp.concatenate(
+ [
+ jnp.broadcast_to(freqs_t[:f, None, None, :, :], (f, h, w, dim_t // 2, 2)),
+ jnp.broadcast_to(freqs_h[None, :h, None, :, :], (f, h, w, dim_h // 2, 2)),
+ jnp.broadcast_to(freqs_w[None, None, :w, :, :], (f, h, w, dim_w // 2, 2)),
+ ],
+ axis=3,
+ ).reshape(seq_len, head_dim // 2, 2)[None, :, None, :, :]
+
+ x_complex = x.reshape(b, seq_len, num_heads, head_dim // 2, 2)
+ x_out = jnp.stack(
+ [
+ x_complex[..., 0] * freqs_grid[..., 0] - x_complex[..., 1] * freqs_grid[..., 1],
+ x_complex[..., 0] * freqs_grid[..., 1] + x_complex[..., 1] * freqs_grid[..., 0],
+ ],
+ axis=-1,
+ )
+
+ return x_out.reshape(b, seq_len, num_heads, head_dim)
+
+
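+# Usage sketch (mirrors how Wan2DiT.forward wires this up below): for a patch grid of
+# (f, h, w) tokens,
+#
+#   freqs = precompute_freqs_cis_3d(dim=cfg.head_dim, max_seq_len=max(f, h, w))
+#   q = rope_apply(q, (f, h, w), freqs)   # q: [B, f*h*w, num_heads, head_dim]
+
+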
+class WanLayerNorm(nnx.LayerNorm):
+ """LayerNorm with float32 conversion for numerical stability."""
+
+ def __init__(self, dim: int, eps: float = 1e-6, use_scale: bool = False, use_bias: bool = False, *, rngs: nnx.Rngs):
+ super().__init__(dim, epsilon=eps, use_scale=use_scale, use_bias=use_bias, rngs=rngs)
+
+ def __call__(self, x: Array) -> Array:
+ dtype = x.dtype
+ return super().__call__(x.astype(jnp.float32)).astype(dtype)
+
+
+class MultiHeadAttention(nnx.Module):
+ def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs):
+ self.num_heads, self.head_dim = cfg.num_heads, cfg.head_dim
+ self.q_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST)
+ self.k_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST)
+ self.v_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST)
+ self.out_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST)
+ self.q_norm = nnx.RMSNorm(cfg.hidden_dim, rngs=rngs)
+ self.k_norm = nnx.RMSNorm(cfg.hidden_dim, rngs=rngs)
+
+ def __call__(self, x: Array, rope_state: tuple | None = None, deterministic: bool = True) -> Array:
+ b, n = x.shape[:2]
+ q = self.q_norm(self.q_proj(x)).reshape(b, n, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+ k = self.k_norm(self.k_proj(x)).reshape(b, n, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+ v = self.v_proj(x).reshape(b, n, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+ if rope_state is not None:
+ freqs, grid_sizes = rope_state
+ q, k = jnp.transpose(q, (0, 2, 1, 3)), jnp.transpose(k, (0, 2, 1, 3))
+ q, k = rope_apply(q, grid_sizes, freqs), rope_apply(k, grid_sizes, freqs)
+ q, k = jnp.transpose(q, (0, 2, 1, 3)), jnp.transpose(k, (0, 2, 1, 3))
+
+ attn = jax.nn.softmax(
+ jnp.einsum("bhid,bhjd->bhij", q, k, precision=Precision.HIGHEST) / math.sqrt(self.head_dim), axis=-1
+ )
+ out = (
+ jnp.einsum("bhij,bhjd->bhid", attn, v, precision=Precision.HIGHEST).transpose(0, 2, 1, 3).reshape(b, n, -1)
+ )
+ return self.out_proj(out)
+
+
+class CrossAttention(nnx.Module):
+ def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs):
+ self.num_heads, self.head_dim = cfg.num_heads, cfg.head_dim
+ self.q_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST)
+ self.kv_proj = nnx.Linear(cfg.hidden_dim, 2 * cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST)
+ self.out_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST)
+ self.q_norm = nnx.RMSNorm(cfg.hidden_dim, rngs=rngs)
+ self.k_norm = nnx.RMSNorm(cfg.hidden_dim, rngs=rngs)
+
+ def __call__(self, x: Array, context: Array, deterministic: bool = True) -> Array:
+ b, n, m = x.shape[0], x.shape[1], context.shape[1]
+ q = self.q_norm(self.q_proj(x)).reshape(b, n, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+ kv = self.kv_proj(context)
+ k, v = jnp.split(kv, 2, axis=-1)
+ k = self.k_norm(k).reshape(b, m, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+ v = v.reshape(b, m, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+ attn = jax.nn.softmax(
+ jnp.einsum("bhid,bhjd->bhij", q, k, precision=Precision.HIGHEST) / math.sqrt(self.head_dim), axis=-1
+ )
+ out = (
+ jnp.einsum("bhij,bhjd->bhid", attn, v, precision=Precision.HIGHEST).transpose(0, 2, 1, 3).reshape(b, n, -1)
+ )
+ return self.out_proj(out)
+
+
+def modulate(x: Array, shift: Array, scale: Array) -> Array:
+ """Apply adaptive layer norm modulation."""
+ original_dtype = x.dtype
+
+ return (x.astype(jnp.float32) * (1 + scale) + shift).astype(original_dtype)
+
+
+class WanAttentionBlock(nnx.Module):
+ """
+ Wan Diffusion Transformer Block.
+
+ Includes:
+ - Vision self-attention with AdaLN modulation
+ - Text-to-vision cross-attention
+ - Feed-forward network with AdaLN modulation
+ """
+
+ def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs):
+ self.cfg = cfg
+
+ self.norm1 = WanLayerNorm(cfg.hidden_dim, rngs=rngs)
+ self.norm2 = WanLayerNorm(cfg.hidden_dim, rngs=rngs, use_scale=True, use_bias=True)
+ self.norm3 = WanLayerNorm(cfg.hidden_dim, rngs=rngs)
+
+ self.self_attn = MultiHeadAttention(cfg, rngs=rngs)
+ self.cross_attn = CrossAttention(cfg, rngs=rngs)
+
+ self.mlp = nnx.Sequential(
+ nnx.Linear(cfg.hidden_dim, cfg.ffn_dim, rngs=rngs, precision=Precision.HIGHEST),
+ nnx.gelu,
+ nnx.Linear(cfg.ffn_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST),
+ )
+
+ self.scale_shift_table = nnx.Param(
+ jax.random.normal(rngs.params(), (1, 6, cfg.hidden_dim)) / (cfg.hidden_dim**0.5)
+ )
+
+ @jax.named_scope("wan_attention_block")
+ def __call__(
+ self,
+ x: Array,
+ text_embeds: Array,
+ time_proj: Array,
+ rope_state: tuple | None = None,
+ deterministic: bool = True,
+ ) -> Array:
+ """
+ Args:
+ x: [B, N, D] video tokens
+ text_embeds: [B, M, text_dim] text embeddings
+ time_proj: [B, 6*D] time embedding
+ rope_state: Optional tuple of (freqs, grid_sizes) for 3D RoPE
+ deterministic: Whether to apply dropout
+ Returns:
+ [B, N, D] transformed tokens
+ """
+ # Get modulation from time embedding
+ b = time_proj.shape[0]
+ d = self.cfg.hidden_dim
+ reshaped_time_proj = time_proj.reshape(b, 6, d)
+ modulation = reshaped_time_proj.astype(jnp.float32) + self.scale_shift_table.value
+ modulation = modulation.reshape(b, -1) # [B, 6*D]
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(modulation, 6, axis=-1)
+
+ # Self-attention with AdaLN modulation and RoPE
+ norm_x = self.norm1(x)
+ norm_x = modulate(norm_x, shift_msa[:, None, :], scale_msa[:, None, :])
+ attn_out = self.self_attn(norm_x, rope_state=rope_state, deterministic=deterministic)
+ x = (x.astype(jnp.float32) + gate_msa[:, None, :] * attn_out).astype(x.dtype)
+
+ # Cross-attention
+ norm_x = self.norm2(x)
+ cross_out = self.cross_attn(norm_x, text_embeds, deterministic=deterministic)
+ x = x + cross_out
+
+ # MLP with AdaLN modulation
+ norm_x = self.norm3(x)
+ norm_x = modulate(norm_x, shift_mlp[:, None, :], scale_mlp[:, None, :])
+ mlp_out = self.mlp(norm_x)
+ x = (x.astype(jnp.float32) + gate_mlp[:, None, :] * mlp_out).astype(jnp.float32)
+
+ return x
+
+
+class FinalLayer(nnx.Module):
+ """Final layer that predicts noise from DiT output."""
+
+ def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs):
+ self.cfg = cfg
+ self.norm = WanLayerNorm(cfg.hidden_dim, rngs=rngs)
+ out_dim = math.prod(cfg.patch_size) * cfg.latent_output_dim # expand out_dim here for unpatchify
+ self.linear = nnx.Linear(cfg.hidden_dim, out_dim, rngs=rngs, precision=Precision.HIGHEST)
+
+ self.scale_shift_table = nnx.Param(
+ jax.random.normal(rngs.params(), (1, 2, cfg.hidden_dim)) / (cfg.hidden_dim**0.5)
+ )
+
+ @jax.named_scope("final_layer")
+ def __call__(self, x: Array, time_emb: Array) -> Array:
+ """
+ Args:
+ x: [B, N, D] DiT output
+ time_emb: [B, D] time embedding from TimestepEmbedding
+ Returns:
+ [B, N, latent_output_dim] predicted noise
+ """
+        # [B, D] → [B, 1, D] + [1, 2, D] → [B, 2, D]
+ e = self.scale_shift_table.value + time_emb[:, None, :]
+ shift, scale = e[:, 0, :], e[:, 1, :]
+
+ x = modulate(x, shift[:, None, :], scale[:, None, :])
+ x = self.linear(x)
+ return x
+
+
+class Wan2DiT(nnx.Module):
+ """
+ Wan2.1-T2V-1.3B Diffusion Transformer.
+ """
+
+ def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs):
+ self.cfg = cfg
+
+ # 3D Conv to patchify video latents
+        # (T, H, W) → (T, H/2, W/2)
+ self.patch_embed = nnx.Conv(
+ in_features=cfg.latent_input_dim,
+ out_features=cfg.hidden_dim,
+ kernel_size=(1, 2, 2),
+ strides=(1, 2, 2),
+ padding="VALID",
+ use_bias=True,
+ rngs=rngs,
+ precision=Precision.HIGHEST,
+ )
+
+        # Text embedding projection: UMT5 (4096) → DiT (1536)
+ self.text_proj = nnx.Sequential(
+ nnx.Linear(cfg.text_embed_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST),
+ nnx.gelu,
+ nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST),
+ )
+
+ self.time_embed = TimestepEmbedding(cfg, rngs=rngs)
+
+ self.blocks = nnx.List([WanAttentionBlock(cfg, rngs=rngs) for _ in range(cfg.num_layers)])
+
+ self.final_layer = FinalLayer(cfg, rngs=rngs)
+
+ @jax.named_scope("wan2_dit")
+ @jax.jit
+ def forward(self, latents: Array, text_embeds: Array, timestep: Array, deterministic: bool = True) -> Array:
+ """
+ Forward pass of the Diffusion Transformer.
+
+ Args:
+ latents: [B, T, H, W, C] noisy video latents from VAE
+ text_embeds: [B, seq_len, 4096] from UMT5-XXL encoder (before projection)
+ timestep: [B] diffusion timestep (0 to num_steps)
+ deterministic: Whether to apply dropout
+
+ Returns:
+ predicted_noise: [B, T, H, W, C] predicted noise
+ """
+ text_embeds = self.text_proj(text_embeds)
+
+ # Get time embeddings
+ # time_emb: [B, D] for FinalLayer
+ # time_proj: [B, 6*D] for AdaLN in blocks
+ time_emb, time_proj = self.time_embed(timestep)
+
+ x = self.patch_embed(latents)
+ b, t_out, h_out, w_out, d = x.shape
+ x = x.reshape(b, t_out * h_out * w_out, d)
+
+ grid_sizes = (t_out, h_out, w_out)
+
+ max_seq = max(grid_sizes)
+ rope_freqs = tuple(
+ jax.lax.stop_gradient(arr) for arr in precompute_freqs_cis_3d(dim=self.cfg.head_dim, max_seq_len=max_seq)
+ )
+
+ for block in self.blocks:
+ x = block(x, text_embeds, time_proj, rope_state=(rope_freqs, grid_sizes), deterministic=deterministic)
+
+ # Final projection to noise space
+ x = self.final_layer(x, time_emb) # [B, T*H*W, latent_output_dim]
+
+ # Reshape back to video format
+ predicted_noise = self.unpatchify(x, grid_sizes)
+
+ return predicted_noise
+
+ def unpatchify(self, x: Array, grid_sizes: tuple[int, int, int]) -> Array:
+ """
+ Reconstruct video tensors from patch embeddings.
+
+ Args:
+ x: [B, T*H*W, patch_t*patch_h*patch_w*C] flattened patch embeddings
+ grid_sizes: (T_patches, H_patches, W_patches) grid dimensions
+
+ Returns:
+ [B, T, H, W, C] reconstructed video tensor (channel-last)
+ """
+ b, seq_len, feature_dim = x.shape
+ t_patches, h_patches, w_patches = grid_sizes
+ c = self.cfg.latent_output_dim
+ patch_size = self.cfg.patch_size
+
+ assert seq_len == t_patches * h_patches * w_patches, (
+ f"expected: seq_len={seq_len} should be {t_patches * h_patches * w_patches}"
+ )
+ assert feature_dim == patch_size[0] * patch_size[1] * patch_size[2] * c, (
+ f"expected: feature_dim={feature_dim} should be {patch_size[0] * patch_size[1] * patch_size[2] * c}"
+ )
+
+ x = x.reshape(
+ b,
+ t_patches,
+ h_patches,
+ w_patches,
+ patch_size[0],
+ patch_size[1],
+ patch_size[2],
+ c,
+ )
+ x = jnp.einsum("bthwpqrc->btphqwrc", x)
+ x = x.reshape(
+ b,
+ t_patches * patch_size[0],
+ h_patches * patch_size[1],
+ w_patches * patch_size[2],
+ c,
+ )
+
+ return x
+
+
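+# Shape sketch for unpatchify (illustrative numbers): with patch_size = (1, 2, 2) and
+# latent_output_dim = 16, tokens of shape [B, t*h*w, 1*2*2*16] are rearranged back into
+# a channel-last latent video of shape [B, t, 2*h, 2*w, 16], inverting patch_embed.
+
+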
+__all__ = [
+ "TransformerWanModelConfig",
+ "Wan2DiT",
+]
diff --git a/bonsai/models/wan2/umt5.py b/bonsai/models/wan2/umt5.py
new file mode 100644
index 00000000..97eaf55a
--- /dev/null
+++ b/bonsai/models/wan2/umt5.py
@@ -0,0 +1,380 @@
+import dataclasses
+import math
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.lax import Precision
+from jaxtyping import Array
+
+
+def gelu(x: Array) -> Array:
+ return 0.5 * x * (1.0 + jnp.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * jnp.pow(x, 3.0))))
+
+
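+# Note: this is the tanh approximation of GELU (the "gelu_new" variant used by the
+# Hugging Face T5/UMT5 gated feed-forward), rather than the exact erf-based GELU.
+
+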
+class RMSNorm(nnx.Module):
+ def __init__(self, dim: int, eps: float = 1e-6, *, rngs: nnx.Rngs):
+ self.dim = dim
+ self.eps = eps
+ self.weight = nnx.Param(jnp.ones(dim))
+
+ def __call__(self, x: Array) -> Array:
+ # RMS normalization: x / sqrt(mean(x^2))
+ variance = jnp.mean(x.astype(jnp.float32) ** 2, axis=-1, keepdims=True)
+ x = x * jax.lax.rsqrt(variance + self.eps)
+ return self.weight.value * x
+
+
+class T5Attention(nnx.Module):
+ """UMT5 Multi-head attention."""
+
+ def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.1, *, rngs: nnx.Rngs):
+ assert dim_attn % num_heads == 0
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.num_heads = num_heads
+ self.head_dim = dim_attn // num_heads
+
+ # Linear projections
+ self.q = nnx.Linear(dim, dim_attn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs)
+ self.k = nnx.Linear(dim, dim_attn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs)
+ self.v = nnx.Linear(dim, dim_attn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs)
+ self.o = nnx.Linear(dim_attn, dim, use_bias=False, precision=Precision.HIGHEST, rngs=rngs)
+ self.dropout = nnx.Dropout(dropout, rngs=rngs)
+
+ def __call__(
+ self,
+ x: Array,
+ context: Optional[Array] = None,
+ mask: Optional[Array] = None,
+ pos_bias: Optional[Array] = None,
+ deterministic: bool = True,
+ ) -> Array:
+ """
+ Args:
+ x: [B, L1, C] query
+ context: [B, L2, C] key/value context, defaults to x for self-attention
+ mask: [B, L2] or [B, L1, L2] attention mask
+ pos_bias: [B, num_heads, L1, L2] position bias
+ """
+ context = x if context is None else context
+ b = x.shape[0]
+ n = self.num_heads
+ c = self.head_dim
+
+ q = self.q(x).reshape(b, -1, n, c)
+ k = self.k(context).reshape(b, -1, n, c)
+ v = self.v(context).reshape(b, -1, n, c)
+
+ # Attention bias
+ attn_bias = jnp.zeros((b, n, q.shape[1], k.shape[1]))
+ if pos_bias is not None:
+ attn_bias = attn_bias + pos_bias
+ if mask is not None:
+ # Expand mask to attention shape
+ if mask.ndim == 2:
+ mask = mask[:, None, None, :] # [B, 1, 1, L2]
+ else:
+ mask = mask[:, None, :, :] # [B, 1, L1, L2]
+ attn_bias = jnp.where(mask == 0, jnp.finfo(x.dtype).min, attn_bias)
+
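+        # T5-style attention: logits are not scaled by 1/sqrt(head_dim); only the
+        # additive position/mask bias is applied before the softmax.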
+ attn = jnp.einsum("binc,bjnc->bnij", q, k) + attn_bias
+ attn = jax.nn.softmax(attn, axis=-1)
+ x = jnp.einsum("bnij,bjnc->binc", attn, v)
+
+ x = x.reshape(b, -1, n * c)
+ x = self.o(x)
+ x = self.dropout(x, deterministic=deterministic)
+ return x
+
+
+class T5FeedForward(nnx.Module):
+ """UMT5 Feed-forward network with gated activation."""
+
+ def __init__(self, dim: int, dim_ffn: int, dropout: float = 0.1, *, rngs: nnx.Rngs):
+ self.dim = dim
+ self.dim_ffn = dim_ffn
+
+ self.gate = nnx.Linear(dim, dim_ffn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs)
+ self.fc1 = nnx.Linear(dim, dim_ffn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs)
+ self.fc2 = nnx.Linear(dim_ffn, dim, use_bias=False, precision=Precision.HIGHEST, rngs=rngs)
+ self.dropout = nnx.Dropout(dropout, rngs=rngs)
+
+ def __call__(self, x: Array, deterministic: bool = True) -> Array:
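+        # Gated activation (GEGLU-style): the GELU-activated gate branch multiplies the
+        # linear branch elementwise before the down projection.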
+ x = self.fc1(x) * gelu(self.gate(x))
+ x = self.dropout(x, deterministic=deterministic)
+ x = self.fc2(x)
+ x = self.dropout(x, deterministic=deterministic)
+ return x
+
+
+class T5RelativeEmbedding(nnx.Module):
+ def __init__(self, num_buckets: int, num_heads: int, bidirectional: bool, max_dist: int = 128, *, rngs: nnx.Rngs):
+ self.num_buckets = num_buckets
+ self.num_heads = num_heads
+ self.bidirectional = bidirectional
+ self.max_dist = max_dist
+ self.embedding = nnx.Embed(num_buckets, num_heads, rngs=rngs)
+
+ def __call__(self, lq: int, lk: int) -> Array:
+ """Compute relative position bias.
+
+ Args:
+ lq: Query sequence length
+ lk: Key sequence length
+
+ Returns:
+ [1, num_heads, lq, lk] relative position bias
+ """
+ q_pos = jnp.arange(lq)[:, None]
+ k_pos = jnp.arange(lk)[None, :]
+ rel_pos = k_pos - q_pos
+
+ rel_pos_buckets = self._relative_position_bucket(rel_pos)
+
+ # Get embeddings
+ rel_pos_embeds = self.embedding(rel_pos_buckets) # [lq, lk, num_heads]
+ rel_pos_embeds = rel_pos_embeds.transpose(2, 0, 1)[None, :, :, :] # [1, num_heads, lq, lk]
+ return rel_pos_embeds
+
+ def _relative_position_bucket(self, rel_pos: Array) -> Array:
+ """Convert relative positions to bucket indices."""
+ if self.bidirectional:
+ num_buckets = self.num_buckets // 2
+ rel_buckets = (rel_pos > 0).astype(jnp.int32) * num_buckets
+ rel_pos = jnp.abs(rel_pos)
+ else:
+ num_buckets = self.num_buckets
+ rel_buckets = 0
+ rel_pos = -jnp.minimum(rel_pos, jnp.zeros_like(rel_pos))
+
+ # Small vs large positions
+ max_exact = num_buckets // 2
+ is_small = rel_pos < max_exact
+
+ # Logarithmic bucketing for large positions
+ rel_pos_large = max_exact + (
+ jnp.log(rel_pos.astype(jnp.float32) / max_exact)
+ / math.log(self.max_dist / max_exact)
+ * (num_buckets - max_exact)
+ ).astype(jnp.int32)
+ rel_pos_large = jnp.minimum(rel_pos_large, num_buckets - 1)
+
+ rel_buckets = rel_buckets + jnp.where(is_small, rel_pos, rel_pos_large)
+ return rel_buckets
+
+
+class T5SelfAttention(nnx.Module):
+ """UMT5 Self-attention block with feed-forward."""
+
+ def __init__(
+ self,
+ dim: int,
+ dim_attn: int,
+ dim_ffn: int,
+ num_heads: int,
+ num_buckets: int,
+ shared_pos: bool = True,
+ dropout: float = 0.1,
+ *,
+ rngs: nnx.Rngs,
+ ):
+ self.shared_pos = shared_pos
+
+ self.norm1 = RMSNorm(dim, rngs=rngs)
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout, rngs=rngs)
+ self.norm2 = RMSNorm(dim, rngs=rngs)
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout, rngs=rngs)
+
+ if not shared_pos:
+ self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, rngs=rngs)
+ else:
+ self.pos_embedding = None
+
+ def __call__(
+ self, x: Array, mask: Optional[Array] = None, pos_bias: Optional[Array] = None, deterministic: bool = True
+ ) -> Array:
+ # Get position bias
+ if self.shared_pos:
+ e = pos_bias
+ else:
+ e = self.pos_embedding(x.shape[1], x.shape[1])
+
+ x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e, deterministic=deterministic)
+ x = x + self.ffn(self.norm2(x), deterministic=deterministic)
+ return x
+
+
+class T5Encoder(nnx.Module):
+ def __init__(
+ self,
+ vocab_size: int,
+ dim: int,
+ dim_attn: int,
+ dim_ffn: int,
+ num_heads: int,
+ num_layers: int,
+ num_buckets: int,
+ shared_pos: bool = True,
+ dropout: float = 0.1,
+ *,
+ rngs: nnx.Rngs,
+ ):
+ self.dim = dim
+ self.shared_pos = shared_pos
+
+ self.token_embedding = nnx.Embed(vocab_size, dim, rngs=rngs)
+ if shared_pos:
+ self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, rngs=rngs)
+ else:
+ self.pos_embedding = None
+ self.dropout = nnx.Dropout(dropout, rngs=rngs)
+ self.blocks = nnx.List(
+ [
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, rngs=rngs)
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm = RMSNorm(dim, rngs=rngs)
+
+ def __call__(self, ids: Array, mask: Optional[Array] = None, deterministic: bool = True) -> Array:
+ x = self.token_embedding(ids)
+ x = self.dropout(x, deterministic=deterministic)
+
+ # Compute shared position bias if needed
+ e = self.pos_embedding(x.shape[1], x.shape[1]) if self.shared_pos else None
+
+ # Apply transformer blocks
+ for block in self.blocks:
+ x = block(x, mask, pos_bias=e, deterministic=deterministic)
+
+ x = self.norm(x)
+ x = self.dropout(x, deterministic=deterministic)
+ return x
+
+
+@dataclasses.dataclass(frozen=True)
+class T5Config:
+ """Configuration for UMT5 Encoder."""
+
+ vocab_size: int = 256384
+ dim: int = 4096
+ dim_attn: int = 4096
+ dim_ffn: int = 10240
+ num_heads: int = 64
+ num_layers: int = 24
+ num_buckets: int = 32
+ shared_pos: bool = False # UMT5 uses per-layer position embeddings
+ dropout: float = 0.1
+
+ @classmethod
+ def umt5_xxl(cls) -> "T5Config":
+ """UMT5-XXL configuration (~5B parameters)."""
+ return cls(
+ vocab_size=256384,
+ dim=4096,
+ dim_attn=4096,
+ dim_ffn=10240,
+ num_heads=64,
+ num_layers=24,
+ num_buckets=32,
+ shared_pos=False,
+ dropout=0.1,
+ )
+
+ @classmethod
+ def umt5_base(cls) -> "T5Config":
+ """UMT5-Base configuration (~580M parameters)."""
+ return cls(
+ vocab_size=256384,
+ dim=768,
+ dim_attn=768,
+ dim_ffn=2048,
+ num_heads=12,
+ num_layers=12,
+ num_buckets=32,
+ shared_pos=False,
+ dropout=0.1,
+ )
+
+
+class T5EncoderModel(nnx.Module):
+ """UMT5 Encoder-only model for text encoding.
+
+ Supports multiple UMT5 configurations (UMT5-XXL, UMT5-Base).
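+
+    Example (illustrative sketch; weights here are randomly initialized, so the
+    embeddings only become meaningful once a real checkpoint is loaded):
+
+        import jax.numpy as jnp
+        from flax import nnx
+
+        model = T5EncoderModel.umt5_base(rngs=nnx.Rngs(0))
+        ids = jnp.zeros((1, 16), dtype=jnp.int32)
+        mask = jnp.ones((1, 16), dtype=jnp.int32)
+        embeds = model(ids, attention_mask=mask)  # [1, 16, 768]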
+ """
+
+ def __init__(self, config: T5Config, *, rngs: nnx.Rngs):
+ """Initialize UMT5 encoder from config.
+
+ Args:
+ config: T5Config specifying model architecture
+ rngs: Random number generators for initialization
+ """
+ self.config = config
+ self.encoder = T5Encoder(
+ vocab_size=config.vocab_size,
+ dim=config.dim,
+ dim_attn=config.dim_attn,
+ dim_ffn=config.dim_ffn,
+ num_heads=config.num_heads,
+ num_layers=config.num_layers,
+ num_buckets=config.num_buckets,
+ shared_pos=config.shared_pos,
+ dropout=config.dropout,
+ rngs=rngs,
+ )
+
+ @classmethod
+ def from_config(cls, config: T5Config, *, rngs: nnx.Rngs) -> "T5EncoderModel":
+ """Create UMT5 encoder from configuration.
+
+ Args:
+ config: T5Config instance
+ rngs: Random number generators
+
+ Returns:
+ T5EncoderModel instance
+ """
+ return cls(config, rngs=rngs)
+
+ @classmethod
+ def umt5_xxl(cls, *, rngs: nnx.Rngs) -> "T5EncoderModel":
+ """Create UMT5-XXL encoder (~5B parameters).
+
+ Args:
+ rngs: Random number generators
+
+ Returns:
+ T5EncoderModel configured as UMT5-XXL
+ """
+ return cls(T5Config.umt5_xxl(), rngs=rngs)
+
+ @classmethod
+ def umt5_base(cls, *, rngs: nnx.Rngs) -> "T5EncoderModel":
+ """Create UMT5-Base encoder (~580M parameters).
+
+ Args:
+ rngs: Random number generators
+
+ Returns:
+ T5EncoderModel configured as UMT5-Base
+ """
+ return cls(T5Config.umt5_base(), rngs=rngs)
+
+ def __call__(self, input_ids: Array, attention_mask: Optional[Array] = None, deterministic: bool = True) -> Array:
+ """Encode text.
+
+ Args:
+ input_ids: [B, L] token IDs
+ attention_mask: [B, L] attention mask (1 for valid tokens, 0 for padding)
+ deterministic: whether to disable dropout (True for inference)
+
+ Returns:
+ [B, L, dim] encoded text embeddings (dim depends on config)
+ """
+ return self.encoder(input_ids, mask=attention_mask, deterministic=deterministic)
+
+
+__all__ = ["T5Config", "T5Encoder", "T5EncoderModel"]
diff --git a/bonsai/models/wan2/unipc_multistep_scheduler.py b/bonsai/models/wan2/unipc_multistep_scheduler.py
new file mode 100644
index 00000000..0a42076a
--- /dev/null
+++ b/bonsai/models/wan2/unipc_multistep_scheduler.py
@@ -0,0 +1,811 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: reference pytorch implementation: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_unipc_multistep.py
+
+from functools import partial
+from types import SimpleNamespace
+from typing import List, Optional, Tuple, Union
+
+import flax
+import jax
+import jax.numpy as jnp
+
+from .scheduling_utils import (
+ CommonSchedulerState,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ add_noise_common,
+)
+
+
+@flax.struct.dataclass
+class UniPCMultistepSchedulerState:
+ """
+ Data class to hold the mutable state of the FlaxUniPCMultistepScheduler.
+ """
+
+ common: CommonSchedulerState
+
+ # Core schedule parameters (derived from CommonSchedulerState in create_state)
+ sigmas: jnp.ndarray
+ alpha_t: jnp.ndarray
+ sigma_t: jnp.ndarray
+ lambda_t: jnp.ndarray
+ init_noise_sigma: float
+
+ # History buffers for multi-step solver
+ # `model_outputs` stores previous converted model outputs (e.g., predicted x0 or epsilon)
+    timesteps: Optional[jnp.ndarray] = None
+    model_outputs: Optional[jnp.ndarray] = None
+    timestep_list: Optional[jnp.ndarray] = None  # Stores corresponding timesteps for `model_outputs`
+
+ # State variables for tracking progress and solver order
+ lower_order_nums: int = 0
+ last_sample: Optional[jnp.ndarray] = None # Sample from the previous predictor step
+ step_index: Optional[int] = None
+    begin_index: Optional[int] = None  # Used for img2img / inpainting
+ this_order: int = 0 # Current effective order of the UniPC solver for this step
+
+ @classmethod
+ def create(
+ cls,
+ common_state: CommonSchedulerState,
+ alpha_t: jnp.ndarray,
+ sigma_t: jnp.ndarray,
+ lambda_t: jnp.ndarray,
+ sigmas: jnp.ndarray,
+ init_noise_sigma: jnp.ndarray,
+ ):
+ return cls(
+ common=common_state,
+ alpha_t=alpha_t,
+ sigma_t=sigma_t,
+ lambda_t=lambda_t,
+ sigmas=sigmas,
+ init_noise_sigma=init_noise_sigma,
+ lower_order_nums=0,
+ last_sample=None,
+ step_index=None,
+ begin_index=None,
+ this_order=0,
+ )
+
+
+class FlaxUniPCMultistepScheduler(FlaxSchedulerMixin):
+ """
+    `FlaxUniPCMultistepScheduler` is a training-free JAX/Flax scheduler for fast sampling of diffusion
+    models, implementing the UniPC (Unified Predictor-Corrector) multistep algorithm.
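+
+    Intended call sequence (untested sketch; shapes, step count and kwargs are illustrative):
+
+        import jax.numpy as jnp
+
+        scheduler = FlaxUniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True)
+        state = scheduler.create_state()
+        latents = jnp.zeros((1, 16, 8, 8))
+        state = scheduler.set_timesteps(state, num_inference_steps=10, shape=latents.shape)
+        for t in state.timesteps:
+            model_output = jnp.zeros_like(latents)  # stand-in for the denoiser prediction
+            latents, state = scheduler.step(state, model_output, t, latents)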
+ """
+
+ dtype: jnp.dtype
+
+ @property
+ def has_state(self) -> bool:
+ return True
+
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ predict_x0: bool = True,
+ solver_type: str = "bh2",
+ lower_order_final: bool = True,
+ disable_corrector: List[int] = [],
+ solver_p: Optional[FlaxSchedulerMixin] = None,
+ use_karras_sigmas: Optional[bool] = False,
+ use_exponential_sigmas: Optional[bool] = False,
+ use_beta_sigmas: Optional[bool] = False,
+ use_flow_sigmas: Optional[bool] = False,
+ flow_shift: Optional[float] = 1.0,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ final_sigmas_type: Optional[str] = "zero",
+ rescale_zero_terminal_snr: bool = False,
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ self.dtype = dtype
+ params = locals().copy()
+ params.pop("self") # η§»ι€ self
+ self.dtype = params.pop("dtype") # dtype δΈζΎε¨ config ι
+
+ self.config = SimpleNamespace(**params)
+
+ # # Validation checks from original __init__
+ # if self.config.use_beta_sigmas and not is_scipy_available():
+ # raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+ if (
+ sum(
+ [
+ self.config.use_beta_sigmas,
+ self.config.use_exponential_sigmas,
+ self.config.use_karras_sigmas,
+ ]
+ )
+ > 1
+ ):
+ raise ValueError(
+ "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+ )
+ if self.config.solver_type not in ["bh1", "bh2"]:
+ raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")
+
+ def create_state(self, common: Optional[CommonSchedulerState] = None) -> UniPCMultistepSchedulerState:
+ if common is None:
+ common = CommonSchedulerState.create(self)
+
+ if not self.config.rescale_zero_terminal_snr:
+ # Close to 0 without being 0 so first sigma is not inf
+ # FP16 smallest positive subnormal works well here
+ alphas_cumprod = common.alphas_cumprod
+ alphas_cumprod = alphas_cumprod.at[-1].set(2**-24)
+ common = common.replace(alphas_cumprod=alphas_cumprod)
+
+ # Currently we only support VP-type noise schedule
+ alpha_t = jnp.sqrt(common.alphas_cumprod)
+ sigma_t = jnp.sqrt(1 - common.alphas_cumprod)
+ lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t)
+ sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
+
+ # standard deviation of the initial noise distribution
+ init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
+
+ if self.config.solver_type not in ["bh1", "bh2"]:
+ if self.config.solver_type in ["midpoint", "heun", "logrho"]:
+ self.config.solver_type = "bh2"
+ else:
+ raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")
+
+ return UniPCMultistepSchedulerState.create(
+ common_state=common,
+ alpha_t=alpha_t,
+ sigma_t=sigma_t,
+ lambda_t=lambda_t,
+ sigmas=sigmas,
+ init_noise_sigma=init_noise_sigma,
+ )
+
+ def set_begin_index(
+ self, state: UniPCMultistepSchedulerState, begin_index: int = 0
+ ) -> UniPCMultistepSchedulerState:
+ """
+        Sets the begin index for the scheduler. This should be called from the pipeline before inference.
+ """
+ return state.replace(begin_index=begin_index)
+
+ def set_timesteps(
+ self,
+ state: UniPCMultistepSchedulerState,
+ num_inference_steps: int,
+ shape: Tuple,
+ ) -> UniPCMultistepSchedulerState:
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ """
+ #### Copied from scheduling_dmpsolver_multistep_flax
+ last_timestep = self.config.num_train_timesteps
+ if self.config.timestep_spacing == "linspace":
+ timesteps = jnp.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32)
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = last_timestep // (num_inference_steps + 1)
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (
+ (jnp.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(jnp.int32)
+ )
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ # initial running values
+ sigmas = state.sigmas
+
+        # TODO: implement Karras/Exponential/Beta sigma schedules (flow sigmas are handled below)
+ if self.config.use_karras_sigmas:
+ # sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64)
+ raise NotImplementedError("`use_karras_sigmas` is not implemented in JAX version yet.")
+ elif self.config.use_exponential_sigmas:
+ # sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64)
+ raise NotImplementedError("`use_exponential_sigmas` is not implemented in JAX version yet.")
+ elif self.config.use_beta_sigmas:
+ # sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64)
+ raise NotImplementedError("`use_beta_sigmas` is not implemented in JAX version yet.")
+ if self.config.use_flow_sigmas:
+ alphas = jnp.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+ sigmas = 1.0 - alphas
+ sigmas = jnp.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+ timesteps = (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int64)
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = sigmas[-1]
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+ sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype(jnp.float32)
+ else: # Default case if none of the specialized sigmas are used
+ sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = ((1 - state.common.alphas_cumprod[0]) / state.common.alphas_cumprod[0]) ** 0.5
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+ sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype(jnp.float32)
+
+ model_outputs = jnp.zeros((self.config.solver_order, *shape), dtype=self.dtype)
+ timestep_list = jnp.zeros((self.config.solver_order,), dtype=jnp.int32) # Timesteps are integers
+ # Update the state with the new schedule and re-initialized history
+ return state.replace(
+ timesteps=timesteps,
+ sigmas=sigmas,
+ model_outputs=model_outputs,
+ timestep_list=timestep_list,
+ lower_order_nums=0, # Reset counters for a new inference run
+ step_index=None,
+ begin_index=None,
+ last_sample=None,
+ this_order=0,
+ )
+
+ def convert_model_output(
+ self,
+ state: UniPCMultistepSchedulerState,
+ model_output: jnp.ndarray,
+ sample: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """
+ Converts the model output based on the prediction type and current state.
+ """
+ sigma = state.sigmas[state.step_index] # Current sigma
+
+ # Ensure sigma is a JAX array for _sigma_to_alpha_sigma_t
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+
+ if self.config.predict_x0:
+ if self.config.prediction_type == "epsilon":
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ elif self.config.prediction_type == "flow_prediction":
+ # Original code has `sigma_t = self.sigmas[self.step_index]`.
+ # This implies current sigma `sigma` is used as sigma_t for flow.
+ x0_pred = sample - sigma * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ raise NotImplementedError("Dynamic thresholding isn't implemented.")
+ # x0_pred = self._threshold_sample(x0_pred)
+ return x0_pred
+ else: # self.config.predict_x0 is False
+ if self.config.prediction_type == "epsilon":
+ return model_output
+ elif self.config.prediction_type == "sample":
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ return epsilon
+ elif self.config.prediction_type == "v_prediction":
+ epsilon = alpha_t * model_output + sigma_t * sample
+ return epsilon
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the UniPCMultistepScheduler."
+ )
+
+ def multistep_uni_p_bh_update(
+ self,
+ state: UniPCMultistepSchedulerState,
+ model_output: jnp.ndarray,
+ sample: jnp.ndarray,
+ order: int,
+ ) -> jnp.ndarray:
+ """
+ One step for the UniP (B(h) version) - the Predictor.
+ """
+ if self.config.solver_p:
+ raise NotImplementedError("Nested `solver_p` is not implemented in JAX version yet.")
+
+ m0 = state.model_outputs[self.config.solver_order - 1] # Most recent stored converted model output
+ x = sample
+
+ sigma_t_val, sigma_s0_val = (
+ state.sigmas[state.step_index + 1],
+ state.sigmas[state.step_index],
+ )
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val)
+
+ lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10)
+ lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10)
+
+ h = lambda_t - lambda_s0
+
+ def rk_d1_loop_body(i, carry):
+ # Loop from i = 0 to order-2
+ rks, D1s = carry
+ history_idx = self.config.solver_order - 2 - i
+ mi = state.model_outputs[history_idx]
+ si_val = state.timestep_list[history_idx]
+
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)])
+ lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10)
+
+ rk = (lambda_si - lambda_s0) / h
+ Di = (mi - m0) / rk
+
+ rks = rks.at[i].set(rk)
+ D1s = D1s.at[i].set(Di)
+ return rks, D1s
+
+ rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype)
+ D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype)
+ if self.config.solver_order == 1:
+ # Dummy D1s array. It will not be used if order == 1
+ D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype)
+ rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init))
+ rks = rks.at[order - 1].set(1.0)
+
+ hh = -h if self.config.predict_x0 else h
+ h_phi_1 = jnp.expm1(hh)
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = jnp.expm1(hh)
+ else:
+ raise NotImplementedError()
+
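+        # Build the linear system R @ rhos = b, where row i of R holds rks**i
+        # (a Vandermonde-style matrix); its solution gives the UniPC combination
+        # weights for the stored difference terms D1s.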
+ def rb_loop_body(i, carry):
+ R, b, current_h_phi_k, factorial_val = carry
+ R = R.at[i].set(jnp.power(rks, i))
+ b = b.at[i].set(current_h_phi_k * factorial_val / B_h)
+
+ def update_fn(vals):
+ _h_phi_k, _fac = vals
+ next_fac = _fac * (i + 2)
+ next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac
+ return next_h_phi_k, next_fac
+
+ current_h_phi_k, factorial_val = jax.lax.cond(
+ i < order - 1,
+ update_fn,
+ lambda vals: vals,
+ (current_h_phi_k, factorial_val),
+ )
+ return R, b, current_h_phi_k, factorial_val
+
+ R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype)
+ b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype)
+ init_h_phi_k = h_phi_1 / hh - 1.0
+ init_factorial = 1.0
+ R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial))
+
+ if len(D1s) > 0:
+ D1s = jnp.stack(D1s, axis=1) # Resulting shape (B, K, C, H, W)
+
+ def solve_for_rhos_p(R_mat, b_vec, current_order):
+ # Create a mask for the top-left (current_order - 1) x (current_order - 1) sub-matrix
+ mask_size = self.config.solver_order - 1
+ mask = jnp.arange(mask_size) < (current_order - 1)
+ mask_2d = mask[:, None] & mask[None, :]
+
+ # Pad R with identity and b with zeros for a safe solve
+ R_safe = jnp.where(
+ mask_2d,
+ R_mat[:mask_size, :mask_size],
+ jnp.eye(mask_size, dtype=R_mat.dtype),
+ )
+ b_safe = jnp.where(mask, b_vec[:mask_size], 0.0)
+
+ # Solve the system and mask the result
+ solved_rhos = jnp.linalg.solve(R_safe, b_safe)
+ return jnp.where(mask, solved_rhos, 0.0)
+
+ # Handle the special case for order == 2
+ if self.config.solver_order == 1:
+ # Dummy rhos_p_padded for tracing.
+ rhos_p_order2 = jnp.zeros(1, dtype=x.dtype)
+ else:
+ rhos_p_order2 = jnp.zeros(self.config.solver_order - 1, dtype=x.dtype).at[0].set(0.5)
+
+ # Get the result for the general case
+ rhos_p_general = solve_for_rhos_p(R, b, order)
+
+ # Select the appropriate result based on the order
+ rhos_p = jnp.where(order == 2, rhos_p_order2, rhos_p_general)
+
+ pred_res = jax.lax.cond(
+ order > 1,
+ lambda _: jnp.einsum("k,bkc...->bc...", rhos_p, D1s).astype(x.dtype),
+ # False branch: return a zero tensor with the correct shape.
+ lambda _: jnp.zeros_like(x),
+ operand=None,
+ )
+
+ if self.config.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ x_t = x_t_ - alpha_t * B_h * pred_res
+ else: # Predict epsilon
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ x_t = x_t_ - sigma_t * B_h * pred_res
+
+ return x_t.astype(x.dtype)
+
+ def multistep_uni_c_bh_update(
+ self,
+ state: UniPCMultistepSchedulerState,
+ this_model_output: jnp.ndarray,
+ last_sample: jnp.ndarray, # Sample after predictor `x_{t-1}`
+ this_sample: jnp.ndarray, # Sample before corrector `x_t` (after predictor step)
+ order: int,
+ ) -> jnp.ndarray:
+ """
+ One step for the UniC (B(h) version) - the Corrector.
+ """
+ model_output_list = state.model_outputs
+ m0 = model_output_list[self.config.solver_order - 1] # Most recent model output from history
+
+ if last_sample is not None:
+ x = last_sample
+ else:
+ # If it's None, create dummy data. This is for the tracing purpose
+ x = jnp.zeros_like(this_sample)
+
+ x_t = this_sample
+
+ model_t = this_model_output
+
+ sigma_t_val = state.sigmas[state.step_index]
+ sigma_s0_val = state.sigmas[state.step_index - 1]
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val)
+
+ lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10)
+ lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10)
+
+ h = lambda_t - lambda_s0
+
+ def rk_d1_loop_body(i, carry):
+ # Loop from i = 0 to order-1.
+ rks, D1s = carry
+
+ # Get history from state buffer
+ history_idx = self.config.solver_order - (i + 2)
+ mi = state.model_outputs[history_idx]
+ si_val = state.timestep_list[history_idx] # This is the actual timestep value
+
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)])
+ lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10)
+
+ rk = (lambda_si - lambda_s0) / h
+ Di = (mi - m0) / rk
+
+ # Update pre-allocated arrays
+ rks = rks.at[i].set(rk)
+ D1s = D1s.at[i].set(Di)
+ return rks, D1s
+
+ # Pre-allocate arrays to max possible size
+ rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype)
+ D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype)
+ if self.config.solver_order == 1:
+ # Dummy D1s array. It will not be used if order == 1. This is for tracing.
+ D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype)
+
+ # Run the loop up to `order - 1`
+ rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init))
+
+ rks = rks.at[order - 1].set(1.0)
+
+ hh = -h if self.config.predict_x0 else h
+ h_phi_1 = jnp.expm1(hh)
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = jnp.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ def rb_loop_body(i, carry):
+ # Loop from i = 0 to order-1
+ R, b, current_h_phi_k, factorial_val = carry
+
+ R = R.at[i].set(jnp.power(rks, i))
+ b = b.at[i].set(current_h_phi_k * factorial_val / B_h)
+
+ # Conditionally update phi_k and factorial for the next iteration
+ def update_fn(vals):
+ # This branch is taken if i < order - 1
+ _h_phi_k, _fac = vals
+ next_fac = _fac * (i + 2)
+ next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac
+ return next_h_phi_k, next_fac
+
+ current_h_phi_k, factorial_val = jax.lax.cond(
+ i < order - 1,
+ update_fn, # If true, update values
+ lambda vals: vals, # If false, pass through
+ (current_h_phi_k, factorial_val),
+ )
+ return R, b, current_h_phi_k, factorial_val
+
+ # Pre-allocate R and b to max size
+ R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype)
+ b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype)
+
+ # Initialize loop carriers
+ init_h_phi_k = h_phi_1 / hh - 1.0
+ init_factorial = 1.0
+
+ R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial))
+
+ if len(D1s) > 0:
+ D1s = jnp.stack(D1s, axis=1) # (B, K, C, H, W)
+
+ def solve_for_rhos(R_mat, b_vec, current_order):
+ # Create a mask to select the first `current_order` elements
+ mask = jnp.arange(self.config.solver_order) < current_order
+ mask_2d = mask[:, None] & mask[None, :]
+
+ # Pad R with identity and b with zeros to create a safe, full-sized system
+ R_safe = jnp.where(mask_2d, R_mat, jnp.eye(self.config.solver_order, dtype=R_mat.dtype))
+ b_safe = jnp.where(mask, b_vec, 0.0)
+
+ # Solve the full-size system and mask the result
+ solved_rhos = jnp.linalg.solve(R_safe, b_safe)
+ return jnp.where(mask, solved_rhos, 0.0)
+
+ rhos_c_order1 = jnp.zeros(self.config.solver_order, dtype=x_t.dtype).at[0].set(0.5)
+ rhos_c_general = solve_for_rhos(R, b, order)
+ rhos_c = jnp.where(order == 1, rhos_c_order1, rhos_c_general)
+
+ D1_t = model_t - m0
+
+ corr_res = jax.lax.cond(
+ order > 1,
+ lambda _: (jnp.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)),
+ lambda _: jnp.zeros_like(D1_t),
+ operand=None,
+ )
+
+ final_rho = jnp.dot(
+ rhos_c,
+ jax.nn.one_hot(order - 1, self.config.solver_order, dtype=rhos_c.dtype),
+ )
+
+ if self.config.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ x_t = x_t_ - alpha_t * B_h * (corr_res + final_rho * D1_t)
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ x_t = x_t_ - sigma_t * B_h * (corr_res + final_rho * D1_t)
+
+ return x_t.astype(x.dtype)
+
+ def index_for_timestep(
+ self,
+ state: UniPCMultistepSchedulerState,
+ timestep: Union[int, jnp.ndarray],
+ schedule_timesteps: Optional[jnp.ndarray] = None,
+ ) -> int:
+ """ "Gets the step_index for timestep."""
+ if schedule_timesteps is None:
+ schedule_timesteps = state.timesteps
+
+        # Keep the timestep as-is: inside the jitted `step` it may be a traced array,
+        # where `.item()` is not available.
+        timestep_val = timestep
+
+ index_candidates = jnp.where(schedule_timesteps == timestep_val, size=1, fill_value=-1)[0]
+
+ step_index = jnp.where(
+ index_candidates[0] == -1, # No match found
+ len(schedule_timesteps) - 1, # Default to last index
+ index_candidates[0],
+ )
+ return step_index
+
+ def _init_step_index(
+ self, state: UniPCMultistepSchedulerState, timestep: Union[int, jnp.ndarray]
+ ) -> UniPCMultistepSchedulerState:
+ """Initializes the step_index counter for the scheduler."""
+ if state.begin_index is None:
+ step_index_val = self.index_for_timestep(state, timestep)
+ return state.replace(step_index=step_index_val)
+ else:
+ return state.replace(step_index=state.begin_index)
+
+ @partial(jax.jit, static_argnums=(0, 5)) # self is static_argnum=0
+ def step(
+ self,
+ state: UniPCMultistepSchedulerState,
+ model_output: jnp.ndarray, # This is the direct output from the diffusion model (e.g., noise prediction)
+ timestep: Union[int, jnp.ndarray], # Current discrete timestep from the scheduler's sequence
+ sample: jnp.ndarray, # Current noisy sample (latent)
+ return_dict: bool = True,
+ generator: Optional[jax.random.PRNGKey] = None, # JAX random key
+ ):
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
+ the multistep UniPC.
+ """
+
+ sample = sample.astype(jnp.float32)
+
+ if state.timesteps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep_scalar = jnp.array(timestep)
+
+ # Initialize step_index if it's the first step
+ if state.step_index is None:
+ state = self._init_step_index(state, timestep_scalar)
+
+ # Determine if corrector should be used
+ use_corrector = (
+ (state.step_index > 0)
+ & (~jnp.isin(state.step_index - 1, jnp.array(self.config.disable_corrector)))
+ & (state.last_sample is not None)
+ )
+
+ # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type
+ model_output_for_history = self.convert_model_output(state, model_output, sample)
+
+ # Apply corrector if applicable
+ sample = jax.lax.cond(
+ use_corrector,
+ lambda: self.multistep_uni_c_bh_update(
+ state=state,
+ this_model_output=model_output_for_history,
+ last_sample=state.last_sample,
+ this_sample=sample,
+ order=state.this_order,
+ ),
+ lambda: sample,
+ )
+
+ # Update history buffers (model_outputs and timestep_list)
+ # Shift existing elements to the left and add new one at the end.
+ # `state.model_outputs` and `state.timestep_list` are fixed-size arrays.
+ # Example:
+ # t0:[None,...,model_output0]
+ # t1:[None,..model_output0,model_output1]
+ # ...
+ # tn:[model_output0,model_output1,...,model_output_n]
+ def step_idx0_branch():
+ updated_model_outputs_history = state.model_outputs.at[-1].set(model_output_for_history)
+ updated_timestep_list_history = state.timestep_list.at[-1].set(timestep_scalar)
+ return updated_model_outputs_history, updated_timestep_list_history
+
+ def non_step_idx0_branch():
+ updated_model_outputs_history = jnp.roll(state.model_outputs, shift=-1, axis=0)
+ updated_model_outputs_history = updated_model_outputs_history.at[-1].set(model_output_for_history)
+
+ updated_timestep_list_history = jnp.roll(state.timestep_list, shift=-1)
+ updated_timestep_list_history = updated_timestep_list_history.at[-1].set(timestep_scalar)
+ return updated_model_outputs_history, updated_timestep_list_history
+
+ updated_model_outputs_history, updated_timestep_list_history = jax.lax.cond(
+ state.step_index == 0, step_idx0_branch, non_step_idx0_branch
+ )
+ state = state.replace(
+ model_outputs=updated_model_outputs_history,
+ timestep_list=updated_timestep_list_history,
+ )
+
+ # Determine the order for the current step (warmup phase logic)
+ this_order = jnp.where(
+ self.config.lower_order_final,
+ jnp.minimum(self.config.solver_order, len(state.timesteps) - state.step_index),
+ self.config.solver_order,
+ )
+
+ # Warmup for multistep: `this_order` can't exceed `lower_order_nums + 1`
+ new_this_order = jnp.minimum(this_order, state.lower_order_nums + 1)
+ state = state.replace(this_order=new_this_order)
+
+ # Store current sample as `last_sample` for the *next* step's corrector
+ state = state.replace(last_sample=sample)
+
+ # UniP predictor step
+ prev_sample = self.multistep_uni_p_bh_update(
+ state=state,
+ model_output=model_output,
+ sample=sample,
+ order=state.this_order,
+ )
+
+ # Update lower_order_nums for warmup
+ new_lower_order_nums = jnp.where(
+ state.lower_order_nums < self.config.solver_order,
+ state.lower_order_nums + 1,
+ state.lower_order_nums,
+ )
+ state = state.replace(lower_order_nums=new_lower_order_nums)
+ # Upon completion, increase step index by one
+ state = state.replace(step_index=state.step_index + 1)
+
+ # Return the updated sample and state
+ if not return_dict:
+ return (prev_sample, state)
+
+ return prev_sample, state
+
+ def scale_model_input(
+ self, state: UniPCMultistepSchedulerState, sample: jnp.ndarray, *args, **kwargs
+ ) -> jnp.ndarray:
+ """
+ UniPC does not scale model input, so it returns the sample unchanged.
+ """
+ return sample
+
+ def add_noise(
+ self,
+ state: UniPCMultistepSchedulerState,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ return add_noise_common(state.common, original_samples, noise, timesteps)
+
+ def _sigma_to_alpha_sigma_t(self, sigma):
+ if self.config.use_flow_sigmas:
+ alpha_t = 1 - sigma
+ sigma_t = sigma
+ else:
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+ sigma_t = sigma * alpha_t
+
+ return alpha_t, sigma_t
+
+ def __len__(self) -> int:
+ return self.config.num_train_timesteps
diff --git a/bonsai/models/wan2/vae_wan.py b/bonsai/models/wan2/vae_wan.py
new file mode 100644
index 00000000..d1d9ba8b
--- /dev/null
+++ b/bonsai/models/wan2/vae_wan.py
@@ -0,0 +1,738 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Wan-VAE: Video Variational Autoencoder for Wan2.1-T2V-1.3B.
+
+This module provides a JAX/Flax implementation of the Wan-VAE decoder,
+which converts latent representations to RGB video frames.
+
+Architecture (based on reference implementation):
+- Denormalization with learned mean/std
+- Frame-by-frame decoding with temporal upsampling
+- CausalConv3d for temporal coherence
+- Spatial upsampling from 104x60 to 832x480
+- Temporal upsampling from 21 to 81 frames
+"""
+
+from dataclasses import dataclass
+from typing import Tuple, Union
+
+import imageio
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.lax import Precision
+from jaxtyping import Array
+
+CACHE_T = 2
+
+
+@dataclass
+class VAEConfig:
+ """Configuration for Wan-VAE decoder.
+
+ Latent denormalization constants from reference implementation.
+ These are fixed constants computed during VAE training.
+ """
+
+ latent_mean: Tuple[float, ...] = (
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921,
+ )
+
+ latent_std: Tuple[float, ...] = (
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.9160,
+ )
+
+
+class CausalConv3d(nnx.Module):
+ """Causal 3D convolution that doesn't look into the future."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Tuple[int, int, int] = (3, 3, 3),
+ *,
+ rngs: nnx.Rngs,
+ padding: Tuple[int, int, int] = (0, 0, 0),
+ ):
+ self.kernel_size = kernel_size
+ self.temporal_padding = padding[0] # Save for cache size calculation
+ self.conv = nnx.Conv(
+ in_features=in_channels,
+ out_features=out_channels,
+ kernel_size=kernel_size,
+ padding="VALID", # We'll handle padding manually
+ rngs=rngs,
+ precision=Precision.HIGHEST,
+ )
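+        # Causal padding: pad 2 * padding[0] frames on the temporal left (past) and none
+        # on the right (future); spatial dims are padded symmetrically, channels untouched.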
+ self.padding = (
+ (0, 0),
+ (2 * padding[0], 0),
+ (padding[1], padding[1]),
+ (padding[2], padding[2]),
+ (0, 0),
+ )
+
+ def __call__(self, x: Array, cache: Array | None = None) -> tuple[Array, Array | None]:
+ """Forward pass with optional caching.
+
+ Args:
+ x: [B, T, H, W, C] input (JAX channel-last format)
+ cache: [B, CACHE_T, H, W, C] cached frames from previous call, or None
+
+ Returns:
+ out: [B, T_out, H_out, W_out, C_out] output
+ new_cache: [B, CACHE_T, H, W, C] cache for next call, or None
+ """
+ # Cache size is 2*padding because we pad left by (2*padding, 0) for causality
+ cache_t = 2 * self.temporal_padding
+ if cache is not None and cache_t > 0:
+ x = jnp.concatenate([cache, x], axis=1) # [B, T+CACHE_T, H, W, C]
+ padding = list(self.padding)
+ padding[1] = (max(0, self.padding[1][0] - cache.shape[1]), 0) # Reduce left padding
+ padding = tuple(padding)
+ else:
+ padding = self.padding
+
+ x_padded = jnp.pad(x, padding, mode="constant")
+        out = self.conv(x_padded)
+
+ # Extract cache for next iteration: last cache_t frames of INPUT (before conv)
+ # Always create cache if we have temporal padding (even on first frame)
+ if cache_t > 0:
+ new_cache = x[:, -cache_t:, :, :, :] # [B, <=CACHE_T, H, W, C]
+ # Pad on the left if we do not yet have cache_t frames (e.g., first call with T=1).
+ if new_cache.shape[1] < cache_t:
+ pad_t = cache_t - new_cache.shape[1]
+ new_cache = jnp.pad(new_cache, ((0, 0), (pad_t, 0), (0, 0), (0, 0), (0, 0)), mode="constant")
+ else:
+ new_cache = None
+
+ return out, new_cache
+
+
+class RMSNorm(nnx.Module):
+ """RMS Normalization with L2 normalize and learned scale.
+
+ Based on F.normalize approach: normalize to unit norm, then scale.
+ For videos (images=False), uses 3D spatial+temporal normalization.
+ """
+
+ def __init__(self, dim: int, *, rngs: nnx.Rngs):
+ self.scale_factor = dim**0.5
+ # gamma shape: (dim,) will broadcast to [B, T, H, W, C] or [B, H, W, C]
+ self.scale = nnx.Param(jnp.ones(dim))
+ self.eps = 1e-12
+
+ def __call__(self, x: Array) -> Array:
+ # x: [B, T, H, W, C] for 3D or [B, H, W, C] for 2D
+ # Normalize to unit RMS along the channel dimension manually since jax.nn.normalize is unavailable.
+ rms = jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True) + self.eps)
+        x_normalized = x / rms
+ return x_normalized * self.scale_factor * self.scale.value
+
+
+class ResidualBlock(nnx.Module):
+ """Residual block with RMSNorm and SiLU activation."""
+
+ def __init__(self, in_channels: int, out_channels: int, *, rngs: nnx.Rngs):
+ self.norm1 = RMSNorm(in_channels, rngs=rngs)
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1), rngs=rngs)
+ self.norm2 = RMSNorm(out_channels, rngs=rngs)
+ self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1), rngs=rngs)
+
+ if in_channels != out_channels:
+ self.skip_conv = CausalConv3d(in_channels, out_channels, kernel_size=(1, 1, 1), rngs=rngs)
+ else:
+ self.skip_conv = None
+
+ def __call__(
+ self, x: Array, cache_list: tuple[Array | None, ...] | None = None, cache_idx: list[int] | None = None
+ ) -> tuple[Array, tuple[Array | None, ...] | None]:
+ residual = x
+ x = self.norm1(x)
+        x = nnx.silu(x)
+
+ if cache_list is not None:
+ idx = cache_idx[0]
+ x, new_cache = self.conv1(x, cache_list[idx])
+ cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :])
+ cache_idx[0] += 1
+        else:
+            x, _ = self.conv1(x, None)
+
+ x = self.norm2(x)
+        x = nnx.silu(x)
+
+ if cache_list is not None:
+ idx = cache_idx[0]
+ x, new_cache = self.conv2(x, cache_list[idx])
+ cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :])
+ cache_idx[0] += 1
+        else:
+            x, _ = self.conv2(x, None)
+
+ if self.skip_conv is not None:
+ residual, _ = self.skip_conv(residual, None)
+ # jax.debug.print("Residual conv output has nan:{}", jnp.isnan(residual).any(), ordered=True)
+
+ # jax.debug.print("Residual block output has nan:{}", jnp.isnan(x).any(), ordered=True)
+
+ return x + residual, cache_list
+
+
+class AttentionBlock(nnx.Module):
+ """Spatial attention block with batched frame processing."""
+
+ def __init__(self, channels: int, *, rngs: nnx.Rngs):
+ self.norm = RMSNorm(channels, rngs=rngs)
+ self.qkv = nnx.Conv(
+ in_features=channels,
+ out_features=channels * 3,
+ kernel_size=(1, 1),
+ use_bias=True,
+ rngs=rngs,
+ precision=Precision.HIGHEST,
+ )
+ self.proj = nnx.Conv(
+ in_features=channels,
+ out_features=channels,
+ kernel_size=(1, 1),
+ use_bias=True,
+ rngs=rngs,
+ precision=Precision.HIGHEST,
+ )
+
+ def __call__(self, x: Array) -> Array:
+ # x: [B, T, H, W, C]
+ b, t, h, w, c = x.shape
+ residual = x
+
+ x = x.reshape(b * t, h, w, c)
+ x = self.norm(x)
+ # QKV projection: [B*T, H, W, C] -> [B*T, H, W, 3*C]
+ qkv = self.qkv(x)
+
+ # Reshape for attention: [B*T, H, W, 3*C] -> [B*T, H*W, 3*C] -> split to Q, K, V
+ qkv = qkv.reshape(b * t, h * w, 3 * c)
+ q, k, v = jnp.split(qkv, 3, axis=-1) # Each: [B*T, H*W, C]
+
+ # Scaled dot-product attention
+ scale = c**-0.5
+ attn = jax.nn.softmax(jnp.einsum("bic,bjc->bij", q, k) * scale, axis=-1) # [B*T, H*W, H*W]
+ out = jnp.einsum("bij,bjc->bic", attn, v) # [B*T, H*W, C]
+
+ # Reshape back to spatial: [B*T, H*W, C] -> [B*T, H, W, C]
+ out = out.reshape(b * t, h, w, c)
+
+ # Output projection
+ out = self.proj(out)
+
+ # Reshape back to video: [B*T, H, W, C] -> [B, T, H, W, C]
+ out = out.reshape(b, t, h, w, c)
+
+ return out + residual
+
+
+class Upsample2D(nnx.Module):
+ """Spatial 2x upsample that also halves channels, mirroring torch Resample."""
+
+ def __init__(self, in_channels: int, out_channels: int, *, rngs: nnx.Rngs):
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.conv = nnx.Conv(
+ in_features=in_channels,
+ out_features=out_channels,
+ kernel_size=(3, 3),
+ padding=1,
+ rngs=rngs,
+ precision=Precision.HIGHEST,
+ )
+
+ def __call__(self, x: Array) -> Array:
+ # x: [B, T, H, W, Cin]
+ b, t, h, w, _ = x.shape
+ orig_dtype = x.dtype
+ x = x.reshape(b * t, h, w, self.in_channels)
+ x = jax.image.resize(x.astype(jnp.float32), (b * t, h * 2, w * 2, self.in_channels), method="nearest").astype(
+ orig_dtype
+ )
+ x = self.conv(x)
+ return x.reshape(b, t, h * 2, w * 2, self.out_channels)
+
+
+class Upsample3D(nnx.Module):
+ """Temporal+spatial 2x upsample with channel reduction (like torch Resample)."""
+
+ def __init__(self, in_channels: int, out_channels: int, *, rngs: nnx.Rngs):
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.time_conv = CausalConv3d(in_channels, in_channels * 2, kernel_size=(3, 1, 1), padding=(1, 0, 0), rngs=rngs)
+ self.spatial_conv = nnx.Conv(
+ in_features=in_channels,
+ out_features=out_channels,
+ kernel_size=(3, 3),
+ padding=1,
+ rngs=rngs,
+ precision=Precision.HIGHEST,
+ )
+
+ def __call__(
+ self, x: Array, cache_list: tuple[Array | None, ...] | None = None, cache_idx: list[int] | None = None
+ ) -> tuple[Array, tuple[Array | None, ...] | None]:
+ b, t, h, w, _ = x.shape
+ orig_dtype = x.dtype
+
+ if cache_list is not None:
+ idx = cache_idx[0]
+
+ # First frame: skip time_conv, only do spatial upsampling
+ if cache_list[idx] is None:
+ # Use zero array as sentinel with SAME shape as real cache
+ # This ensures consistent pytree structure for JIT
+ # We use zeros with shape [B, 2, H, W, C] where 2 = cache size for 3x1x1 conv
+ sentinel = jnp.zeros((b, 2, h, w, self.in_channels), dtype=orig_dtype)
+ cache_list = (*cache_list[:idx], sentinel, *cache_list[idx + 1 :])
+ cache_idx[0] += 1
+ t_out = t
+ else:
+ # Always pass the cached features (including the zero sentinel) so the
+ # time_conv sees a length-2 cache and returns a length-2 cache, matching
+ # the torch behavior where the sentinel seeds the cache.
+ x, new_cache = self.time_conv(x, cache_list[idx])
+
+ cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :])
+ cache_idx[0] += 1
+
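+                # time_conv doubled the channel dim; split it into a factor-2 temporal
+                # axis so that every input frame becomes two consecutive output frames.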
+ x = x.reshape(b, t, h, w, 2, self.in_channels)
+                x = jnp.moveaxis(x, 4, 2)  # [B, T, H, W, 2, Cin] -> [B, T, 2, H, W, Cin]
+ t_out = t * 2
+ x = x.reshape(b, t_out, h, w, self.in_channels)
+
+ # Spatial upsampling (always applied)
+ bt = b * t_out
+ x = x.reshape(bt, h, w, self.in_channels)
+ x = jax.image.resize(x.astype(jnp.float32), (bt, h * 2, w * 2, self.in_channels), method="nearest").astype(
+ orig_dtype
+ )
+ x = self.spatial_conv(x)
+ return x.reshape(b, t_out, h * 2, w * 2, self.out_channels), cache_list
+
+
+class Decoder3D(nnx.Module):
+ """
+ 3D Decoder matching reference implementation.
+ Upsamples from [B, 1, 104, 60, 16] -> [B, 4, 832, 480, 3] (JAX format)
+ """
+
+ def __init__(self, *, rngs: nnx.Rngs):
+ # Initial convolution: 16 -> 384
+ self.conv_in = CausalConv3d(16, 384, kernel_size=(3, 3, 3), rngs=rngs, padding=(1, 1, 1))
+
+ # Middle blocks (at lowest resolution)
+ self.mid_block1 = ResidualBlock(384, 384, rngs=rngs)
+ self.mid_attn = AttentionBlock(384, rngs=rngs)
+ self.mid_block2 = ResidualBlock(384, 384, rngs=rngs)
+
+ # Upsample stages (match torch checkpoint shapes)
+ # Stage 0: stay at 384, then upsample to 192 channels
+ self.up_blocks_0 = nnx.List(
+ [
+ ResidualBlock(384, 384, rngs=rngs),
+ ResidualBlock(384, 384, rngs=rngs),
+ ResidualBlock(384, 384, rngs=rngs),
+ ]
+ )
+ self.up_sample_0 = Upsample3D(384, 192, rngs=rngs)
+
+ # Stage 1: 192 -> 384 (first block), remain at 384, then upsample to 192
+ self.up_blocks_1 = nnx.List(
+ [
+ ResidualBlock(192, 384, rngs=rngs),
+ ResidualBlock(384, 384, rngs=rngs),
+ ResidualBlock(384, 384, rngs=rngs),
+ ]
+ )
+ self.up_sample_1 = Upsample3D(384, 192, rngs=rngs)
+
+ # Stage 2: stay at 192, then spatial-only upsample to 96
+ self.up_blocks_2 = nnx.List(
+ [
+ ResidualBlock(192, 192, rngs=rngs),
+ ResidualBlock(192, 192, rngs=rngs),
+ ResidualBlock(192, 192, rngs=rngs),
+ ]
+ )
+ self.up_sample_2 = Upsample2D(192, 96, rngs=rngs)
+
+ # Stage 3: 96 -> 96, no upsample
+ self.up_blocks_3 = nnx.List(
+ [
+ ResidualBlock(96, 96, rngs=rngs),
+ ResidualBlock(96, 96, rngs=rngs),
+ ResidualBlock(96, 96, rngs=rngs),
+ ]
+ )
+
+ # Output head: 96 -> 3
+ self.norm_out = RMSNorm(96, rngs=rngs)
+ self.conv_out = CausalConv3d(96, 3, kernel_size=(3, 3, 3), padding=(1, 1, 1), rngs=rngs)
+
+ def __call__(
+ self, z: Array, cache_list: tuple[Array | None, ...] | None = None, cache_idx: list[int] | None = None
+ ) -> tuple[Array, tuple[Array | None, ...] | None]:
+ """Forward pass with optional caching.
+
+ Args:
+ z: [B, T, H, W, C] latent (e.g., [1, 1, 104, 60, 16])
+ cache_list: Tuple of cached features for all conv layers, or None
+ cache_idx: List containing current index in cache_list (mutable), or None
+
+ Returns:
+ x: [B, T_out, H_out, W_out, 3] RGB video (e.g., [1, 4, 832, 480, 3])
+ cache_list: Updated cache tuple
+ """
+ # Initial convolution
+ if cache_list is not None:
+ idx = cache_idx[0]
+ x, new_cache = self.conv_in(z, cache_list[idx])
+ cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :])
+ cache_idx[0] += 1
+ else:
+ x, _ = self.conv_in(z, None)
+
+ # jax.debug.print("Decoder3D conv_in output has nan:{}", jnp.isnan(x).any(), ordered=True)
+
+ # Middle blocks
+ x, cache_list = self.mid_block1(x, cache_list, cache_idx)
+ x = self.mid_attn(x) # Attention doesn't use cache
+ x, cache_list = self.mid_block2(x, cache_list, cache_idx)
+
+ # jax.debug.print("Decoder3D mid output has nan:{}", jnp.isnan(x).any(), ordered=True)
+
+ # Upsample stage 0
+ for block in self.up_blocks_0:
+ x, cache_list = block(x, cache_list, cache_idx)
+ x, cache_list = self.up_sample_0(x, cache_list, cache_idx)
+
+ # jax.debug.print("Decoder3D upsample0 output has nan:{}", jnp.isnan(x).any(), ordered=True)
+
+ # Upsample stage 1
+ for block in self.up_blocks_1:
+ x, cache_list = block(x, cache_list, cache_idx)
+ x, cache_list = self.up_sample_1(x, cache_list, cache_idx)
+
+ # jax.debug.print("Decoder3D upsample1 output has nan:{}", jnp.isnan(x).any(), ordered=True)
+
+ # Upsample stage 2
+ for block in self.up_blocks_2:
+ x, cache_list = block(x, cache_list, cache_idx)
+ x = self.up_sample_2(x) # Spatial-only upsample, no cache
+
+ # jax.debug.print("Decoder3D upsample2 output has nan:{}", jnp.isnan(x).any(), ordered=True)
+
+ # Upsample stage 3 (no spatial upsample)
+ for block in self.up_blocks_3:
+ x, cache_list = block(x, cache_list, cache_idx)
+
+ # jax.debug.print("Decoder3D upsample3 output has nan:{}", jnp.isnan(x).any(), ordered=True)
+
+ # Output
+ x = self.norm_out(x)
+ x = nnx.silu(x)
+
+ # jax.debug.print("Decoder3D norm_out output has nan:{}", jnp.isnan(x).any(), ordered=True)
+
+ if cache_list is not None:
+ idx = cache_idx[0]
+ x, new_cache = self.conv_out(x, cache_list[idx])
+ cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :])
+ cache_idx[0] += 1
+ else:
+ x, _ = self.conv_out(x, None)
+
+ # jax.debug.print("Decoder3D conv_out output has nan:{}", jnp.isnan(x).any(), ordered=True)
+
+ return x, cache_list
+
+
+class WanVAEDecoder(nnx.Module):
+ """
+ Wan-VAE Decoder: Converts video latents to RGB frames.
+
+ Architecture matches reference (wan/modules/vae.py:544-568):
+ 1. Denormalize latents with learned mean/std
+ 2. Conv 1x1 projection (16 -> 16 channels)
+ 3. Frame-by-frame decode with Decoder3D
+ 4. Concatenate and clamp output
+
+ Input: [B, T, H, W, C] = [1, 21, 104, 60, 16]
+ Output: [B, T_out, H_out, W_out, 3] = [1, 81, 832, 480, 3]
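+
+    Example (illustrative sketch; randomly initialized weights and a tiny latent,
+    so the output is only useful for checking shapes):
+
+        import jax.numpy as jnp
+        from flax import nnx
+
+        decoder = WanVAEDecoder(rngs=nnx.Rngs(0))
+        latents = jnp.zeros((1, 2, 16, 16, 16))  # [B, T, H, W, C]
+        video = decoder.decode(latents)          # [1, 5, 128, 128, 3], values in [-1, 1]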
+ """
+
+ def __init__(self, cfg: VAEConfig = VAEConfig(), *, rngs: nnx.Rngs):
+ # Store config tuples as Python values (not JAX arrays!)
+ # They'll be converted to JAX arrays at runtime in decode()
+ # This avoids ShapeDtypeStruct issues during nnx.eval_shape()
+ self.latent_mean_tuple = cfg.latent_mean
+ self.latent_std_tuple = cfg.latent_std
+
+ # 1x1 conv projection
+ self.conv2 = CausalConv3d(16, 16, kernel_size=(1, 1, 1), rngs=rngs)
+
+ # 3D decoder
+ self.decoder = Decoder3D(rngs=rngs)
+
+ def decode(self, latents: Array) -> Array:
+ """
+ Decode latents to RGB video with feature caching.
+
+ Args:
+ latents: [B, T, H, W, C] latent representation (JAX format)
+ e.g., [1, 21, 104, 60, 16]
+
+ Returns:
+ video: [B, T_out, H_out, W_out, 3] RGB video (values in [-1, 1])
+ e.g., [1, 81, 832, 480, 3]
+ """
+ # Step 1: Denormalize
+ # Convert Python tuples to JAX arrays at runtime (JIT treats them as static constants)
+ latent_mean = jnp.array(self.latent_mean_tuple).reshape(1, 1, 1, 1, 16)
+ latent_std = jnp.array(self.latent_std_tuple).reshape(1, 1, 1, 1, 16)
+ z = latents * latent_std + latent_mean
+
+ z, _ = self.conv2(z, None)
+
+ # Scan over time dimension: z is [B, T, H, W, C], transpose to [T, B, H, W, C]
+ z_frames = jnp.moveaxis(z, 1, 0) # [T, B, H, W, C]
+ # Add singleton time dimension for each frame: [T, B, 1, H, W, C]
+ z_frames = z_frames[:, :, None, :, :, :]
+
+ # jax.debug.print("z_frames has nan:{}", jnp.isnan(z_frames).any())
+
+ # Warm-up pass: process first frame to initialize cache with correct shapes
+ # This ensures consistent pytree structure for jax.lax.scan
+ cache_idx = [0]
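+ # Fixed-size cache of None placeholders; must hold at least as many slots as Decoder3D consumes (unused slots stay None)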
+ cache_tuple = (None,) * 50
+ first_frame_out, cache_tuple = self.decoder(z_frames[0], cache_tuple, cache_idx)
+
+ # JIT-compiled scan function for remaining frames (now cache has concrete shapes)
+ @jax.jit
+ def scan_frames(cache_tuple, frame_latent):
+ """Process single frame with caching (JIT-compiled)."""
+ cache_idx = [0]
+ frame_out, new_cache_tuple = self.decoder(frame_latent, cache_tuple, cache_idx)
+ return new_cache_tuple, frame_out
+
+ # Process remaining frames with jax.lax.scan (scan_frames above is JIT-compiled)
+ if z_frames.shape[0] > 1:
+ _final_cache, remaining_outputs = jax.lax.scan(scan_frames, cache_tuple, z_frames[1:])
+
+ # Frame 0 outputs 1 frame: [B, 1, H, W, 3]
+ # Frames 1+ each output 4 frames: [T-1, B, 4, H, W, 3]
+ # Flatten temporal dimensions before concatenating
+ b, h_out, w_out, c = (
+ first_frame_out.shape[0],
+ first_frame_out.shape[2],
+ first_frame_out.shape[3],
+ first_frame_out.shape[4],
+ )
+
+ # Move the first frame's time axis to the front: [B, 1, H, W, 3] -> [1, B, H, W, 3]
+ first_flat = first_frame_out.transpose(1, 0, 2, 3, 4) # [1, B, H, W, 3]
+
+ # Flatten remaining frames: [T-1, B, 4, H, W, 3] -> [(T-1)*4, B, H, W, 3]
+ t_minus_1 = remaining_outputs.shape[0]
+ t_out_per_frame = remaining_outputs.shape[2]
+ remaining_flat = remaining_outputs.transpose(0, 2, 1, 3, 4, 5).reshape(
+ t_minus_1 * t_out_per_frame, b, h_out, w_out, c
+ )
+ print(f"remaining flat shape:{remaining_flat.shape}")
+ print(f"remaining flat mean:{remaining_flat[:, 0, :, 235:, :].mean()} ")
+
+ # Concatenate first frame with remaining frames along time: [1 + (T-1)*4, B, H, W, 3]
+ x = jnp.concatenate([first_flat, remaining_flat], axis=0).transpose(1, 0, 2, 3, 4)
+ else:
+ x = first_frame_out
+
+ # Clamp to [-1, 1]
+ x = jnp.clip(x, -1.0, 1.0)
+
+ return x
+
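+
+# Illustrative sketch (not part of the model): the warm-up-then-scan pattern used in
+# WanVAEDecoder.decode. jax.lax.scan requires the carry to keep a fixed pytree structure,
+# so one eager call first turns the None cache into concrete arrays before scanning.
+def _example_warmup_then_scan() -> jax.Array:
+    def step(cache, x):
+        # The warm-up call materializes the cache; inside the scan, cache is already an array.
+        cache = jnp.zeros_like(x) if cache is None else cache
+        new_cache = cache + x
+        return new_cache, new_cache
+
+    xs = jnp.arange(5.0)
+    cache, y0 = step(None, xs[0])  # eager warm-up: carry now has concrete shape/dtype
+    cache, ys = jax.lax.scan(step, cache, xs[1:])  # carry structure is fixed from here on
+    return jnp.concatenate([y0[None], ys])
+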
+
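+# Usage sketch (illustrative only): decode a dummy latent with freshly initialized (random)
+# weights to check the shapes stated in the class docstring; the pixels are meaningless
+# until real weights are loaded.
+def _example_decode_shapes() -> None:
+    decoder = WanVAEDecoder(rngs=nnx.Rngs(0))
+    dummy_latents = jnp.zeros((1, 21, 104, 60, 16), dtype=jnp.float32)  # [B, T, H, W, C]
+    video = decoder.decode(dummy_latents)  # expected: [1, 81, 832, 480, 3], values in [-1, 1]
+    print(video.shape)
+
+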
+def load_vae_from_checkpoint(checkpoint_path: str, rngs: nnx.Rngs) -> WanVAEDecoder:
+ """
+ Load Wan-VAE decoder from checkpoint.
+
+ Args:
+ checkpoint_path: Path to the VAE checkpoint directory
+ rngs: Random number generators for initialization
+
+ Returns:
+ WanVAEDecoder with loaded weights
+ """
+ # Create VAE decoder structure
+ vae_decoder = WanVAEDecoder(rngs=rngs)
+
+ # TODO: Implement checkpoint loading
+ # This will require mapping PyTorch VAE weights to JAX
+ # Similar to what's done in params.py for the DiT model
+
+ print("Warning: VAE checkpoint loading not yet implemented")
+ print("Returning VAE with random weights")
+
+ return vae_decoder
+
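+
+# Sketch of the missing weight mapping (illustrative only, untested). Assumptions: the
+# reference VAE checkpoint is a PyTorch state dict, and CausalConv3d stores its kernel in
+# the Flax layout (D, H, W, in, out) while PyTorch Conv3d weights are (out, in, D, H, W),
+# hence the transpose. Any key used with it (e.g. "decoder.conv_out.weight") is a
+# hypothetical placeholder, not the real naming scheme.
+def _convert_torch_conv3d_weight(torch_weight) -> jax.Array:
+    # (O, I, D, H, W) -> (D, H, W, I, O)
+    return jnp.asarray(torch_weight.detach().cpu().numpy()).transpose(2, 3, 4, 1, 0)
+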
+
+def decode_latents_to_video(vae_decoder: WanVAEDecoder, latents: Array, normalize: bool = True) -> Array:
+ """
+ Helper function to decode latents and post-process to video.
+
+ Args:
+ vae_decoder: WanVAEDecoder instance
+ latents: [B, T, H, W, C] latent representation
+ normalize: If True, normalize output from [-1, 1] to [0, 255] uint8
+
+ Returns:
+ video: [B, T, H_out, W_out, 3] video frames
+ """
+ # Decode
+ video = vae_decoder.decode(latents)
+ print(f"video shape:{video.shape}")
+ print(f"video mean:{video[0, 1:, :, 235:, :].mean()}")
+
+ if normalize:
+ video = (video + 1.0) / 2.0
+ video = jnp.clip(video, 0.0, 1.0)
+
+ video = jnp.round(video * 255.0)
+ video = jnp.clip(video, 0, 255).astype(jnp.uint8)
+
+ return video
+
+
+def save_video(
+ video: Array,
+ save_path: str,
+ fps: int = 30,
+ codec: str = "libx264",
+ quality: int = 8,
+) -> str | None:
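+ """Save a [B, T, H, W, 3] video array (first batch element only) to save_path via imageio; return save_path, or None on failure."""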
+ try:
+ # Handle batch dimension: take first video if batched
+ assert video.ndim == 5
+ video = video[0] # [T, H, W, C]
+
+ video_np = jax.device_get(video)
+
+ # Write video
+ writer = imageio.get_writer(save_path, fps=fps, codec=codec, quality=quality)
+ for frame in video_np:
+ writer.append_data(frame)
+ writer.close()
+
+ return save_path
+
+ except Exception as e:
+ print(f"Failed to save video: {e}")
+ return None
+
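+
+# End-to-end sketch (illustrative only): latents -> uint8 video -> encoded file on disk.
+# Assumes `latents` of shape [B, T, H, W, C] produced upstream and that real VAE weights
+# are available (load_vae_from_checkpoint currently returns random weights, so the output
+# would be noise); the checkpoint path below is a placeholder.
+def _example_decode_and_save(latents: Array) -> str | None:
+    vae = load_vae_from_checkpoint("/path/to/vae_checkpoint", rngs=nnx.Rngs(0))
+    video_uint8 = decode_latents_to_video(vae, latents, normalize=True)
+    return save_video(video_uint8, "decoded.mp4", fps=30)
+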
+
+__all__ = ["VAEConfig", "WanVAEDecoder", "decode_latents_to_video", "load_vae_from_checkpoint", "save_video"]
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..9f89b3a2
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,132 @@
+absl-py==2.3.1
+accelerate==1.12.0
+aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.13.2
+aiosignal==1.4.0
+annotated-doc==0.0.4
+annotated-types==0.7.0
+anyio==4.11.0
+attrs==25.4.0
+brotli==1.2.0
+certifi==2025.11.12
+cffi==2.0.0
+charset-normalizer==3.4.4
+chex==0.1.91
+click==8.3.1
+cryptography==46.0.3
+dashscope==1.25.2
+diffusers==0.35.2
+easydict==1.13
+einops==0.8.1
+etils==1.13.0
+fastapi==0.122.0
+ffmpy==1.0.0
+filelock==3.20.0
+flax==0.12.1
+frozenlist==1.8.0
+fsspec==2025.10.0
+ftfy==6.3.1
+gradio==6.0.1
+gradio-client==2.0.0
+groovy==0.1.2
+h11==0.16.0
+hf-xet==1.2.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.36.0
+humanize==4.14.0
+idna==3.11
+imageio==2.37.2
+imageio-ffmpeg==0.6.0
+importlib-metadata==8.7.0
+importlib-resources==6.5.2
+inquirerpy==0.3.4
+jax==0.8.1
+jaxlib==0.8.1
+jaxtyping==0.3.3
+jinja2==3.1.6
+libtpu==0.0.30
+markdown-it-py==4.0.0
+markupsafe==3.0.3
+mdurl==0.1.2
+ml-dtypes==0.5.4
+modelscope==1.32.0
+mpmath==1.3.0
+msgpack==1.1.2
+multidict==6.7.0
+nest-asyncio==1.6.0
+networkx==3.6
+numpy==2.3.5
+nvidia-cublas-cu12==12.8.4.1
+nvidia-cuda-cupti-cu12==12.8.90
+nvidia-cuda-nvrtc-cu12==12.8.93
+nvidia-cuda-runtime-cu12==12.8.90
+nvidia-cudnn-cu12==9.10.2.21
+nvidia-cufft-cu12==11.3.3.83
+nvidia-cufile-cu12==1.13.1.3
+nvidia-curand-cu12==10.3.9.90
+nvidia-cusolver-cu12==11.7.3.90
+nvidia-cusparse-cu12==12.5.8.93
+nvidia-cusparselt-cu12==0.7.1
+nvidia-nccl-cu12==2.27.5
+nvidia-nvjitlink-cu12==12.8.93
+nvidia-nvshmem-cu12==3.3.20
+nvidia-nvtx-cu12==12.8.90
+opencv-python==4.11.0.86
+opt-einsum==3.4.0
+optax==0.2.6
+orbax-checkpoint==0.11.28
+orjson==3.11.4
+packaging==25.0
+pandas==2.3.3
+pfzy==0.3.4
+pillow==12.0.0
+prompt-toolkit==3.0.52
+propcache==0.4.1
+protobuf==6.33.1
+psutil==7.1.3
+pycparser==2.23
+pydantic==2.12.4
+pydantic-core==2.41.5
+pydub==0.25.1
+pygments==2.19.2
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytz==2025.2
+pyyaml==6.0.3
+regex==2025.11.3
+requests==2.32.5
+rich==14.2.0
+safehttpx==0.1.7
+safetensors==0.7.0
+scipy==1.16.3
+semantic-version==2.10.0
+setuptools==80.9.0
+shellingham==1.5.4
+simplejson==3.20.2
+six==1.17.0
+sniffio==1.3.1
+starlette==0.50.0
+sympy==1.14.0
+tensorstore==0.1.79
+tokenizers==0.22.1
+tomlkit==0.13.3
+toolz==1.1.0
+torch==2.9.1
+torchvision==0.24.1
+tqdm==4.67.1
+transformers==4.57.3
+treescope==0.1.10
+triton==3.5.1
+typer==0.20.0
+typing-extensions==4.15.0
+typing-inspection==0.4.2
+tzdata==2025.2
+urllib3==2.5.0
+uvicorn==0.38.0
+wadler-lindig==0.1.7
+wcwidth==0.2.14
+websocket-client==1.9.0
+yarl==1.22.0
+zipp==3.23.0
\ No newline at end of file