diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1a7a883c..2756ac23 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,12 +8,14 @@
 # 'pre-commit run --all'
 repos:
-- repo: https://github.com/astral-sh/ruff-pre-commit
-  # Ruff version.
-  rev: v0.14.1
-  hooks:
-    # Run the linter.
-    - id: ruff-check
-      args: [ --fix ]
-    # Run the formatter.
-    - id: ruff-format
\ No newline at end of file
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.14.1
+    hooks:
+      # Run the linter.
+      - id: ruff-check
+        args: [--fix]
+        exclude: bonsai/models/wan2/tests
+      # Run the formatter.
+      - id: ruff-format
+        exclude: bonsai/models/wan2/tests
diff --git a/bonsai/models/umt5/README.md b/bonsai/models/umt5/README.md
new file mode 100644
index 00000000..6e97832e
--- /dev/null
+++ b/bonsai/models/umt5/README.md
@@ -0,0 +1,30 @@
+# UMT5 in JAX
+
+This directory contains a pure JAX implementation of the [UMT5 model](https://arxiv.org/abs/2304.09151), using the [Flax NNX](https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html) API.
+
+
+## Model Configuration Support Status
+
+| Model Name | Config Support Status |
+| :--- | :--- |
+| **Dense Models** | |
+| [umt5-small](https://huggingface.co/google/umt5-small) | **✅ Supported** |
+| [umt5-base](https://huggingface.co/google/umt5-base) | **✅ Supported** |
+| [umt5-xl](https://huggingface.co/google/umt5-xl) | **✅ Supported** |
+| [umt5-xxl](https://huggingface.co/google/umt5-xxl) | **✅ Supported** |
+
+
+### Running this model
+
+Run UMT5 in action, implemented in pure JAX in [modeling.py](modeling.py):
+
+```sh
+python3 -m bonsai.models.umt5.tests.run_model
+```
+
+
+## How to contribute to this model
+
+We welcome contributions! You can contribute to this model via the following:
+* Add a model config variant marked `🟑 Not started` in the table above to `class UMT5Config` in [modeling.py](modeling.py). Make sure your code runs on at least one hardware platform before creating a PR.
+* Got some hardware? Run [run_model.py](tests/run_model.py) with the existing configs above on hardware marked `❔ Needs check`. Mark it as `✅ Runs` or `⛔️ Not supported`.
diff --git a/bonsai/models/umt5/modeling.py b/bonsai/models/umt5/modeling.py
new file mode 100644
index 00000000..0082a281
--- /dev/null
+++ b/bonsai/models/umt5/modeling.py
@@ -0,0 +1,734 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
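For orientation, the `python3 -m bonsai.models.umt5.tests.run_model` entry point in the README above corresponds to roughly the following programmatic flow. This is a sketch built on the `load_model_config`/`create_model` helpers and the encoder model added later in this diff; the checkpoint name and prompt are placeholders.

```python
import jax.numpy as jnp
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from bonsai.models.umt5.modeling import UMT5EncoderModel
from bonsai.models.umt5.params import create_model, load_model_config

# Fetch the HuggingFace checkpoint and build the JAX encoder from its weights.
ckpt_path = snapshot_download("google/umt5-small")
cfg = load_model_config(ckpt_path)
encoder = create_model(UMT5EncoderModel, file_dir=ckpt_path, cfg=cfg)

# Tokenize a prompt and run one forward pass; the result is the final hidden states.
tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
batch = tokenizer(["A quick smoke test."], padding=True, return_tensors="np")
hidden_states = encoder(
    input_ids=jnp.array(batch.input_ids),
    attention_mask=jnp.array(batch.attention_mask),
)
print(hidden_states.shape)  # (batch, seq_len, d_model)
```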
+ +import copy +import dataclasses +import logging +import math +from typing import Optional, Union + +import jax +import jax.numpy as jnp +from flax import nnx +from jax._src.typing import DTypeLike +from jax.lax import Precision + +ACT_FN = { + "gelu": nnx.gelu, + "relu": nnx.relu, +} + + +def fp16_clamp(x: jax.Array): + if x.dtype == jnp.float16 and jnp.isinf(x).any(): + clamp = jnp.finfo(x.dtype).max - 1000 + x = jax.lax.clamp(x=x, min=-clamp, max=clamp) + return x + + +@dataclasses.dataclass +class UMT5Config: + """Configuration for UMT5 model.""" + + vocab_size: int = (250112,) + d_model: int = (512,) + d_kv: int = (64,) + d_ff: int = (1024,) + num_layers: int = (8,) + num_decoder_layers: int = (None,) + num_heads: int = (6,) + relative_attention_num_buckets: int = (32,) + relative_attention_max_distance: int = (128,) + dropout_rate: float = (0.1,) + layer_norm_epsilon: float = (1e-6,) + initializer_factor: float = (1.0,) + feed_forward_proj: str = ("gated-gelu",) + is_encoder_decoder: bool = (True,) + use_cache: bool = (True,) + tokenizer_class: str = ("T5Tokenizer",) + tie_word_embeddings: bool = (True,) + pad_token_id: int = (0,) + eos_token_id: int = (1,) + decoder_start_token_id: int = (0,) + is_decoder: bool = (False,) + dtype: DTypeLike = (jnp.float32,) + + def __post_init__(self): + self.num_decoder_layers = ( + self.num_decoder_layers if self.num_decoder_layers is not None else self.num_layers + ) # default = symmetry + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {self.feed_forward_proj} is not a valid activation function of the dense layer. " + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " + "'gated-gelu' or 'relu'" + ) + + if self.dense_act_fn not in ACT_FN: + raise ValueError( + f"`feed_forward_proj`: {self.feed_forward_proj} is not a valid activation function of the dense layer. 
" + f"Supported activation functions are: {', '.join(ACT_FN.keys())}" + ) + + +class T5LayerNorm(nnx.Module): + def __init__( + self, + dim: int, + *, + eps=1e-6, + param_dtype: jnp.dtype | None = jnp.float32, + ): + super().__init__() + self.dim = dim + self.eps = eps + self.scale = nnx.Param(jnp.ones(dim), dtype=param_dtype) + + def __call__(self, hidden_states: jax.Array): + # RMS normalization: hidden_states / sqrt(mean(hidden_states^2)) + variance = jnp.mean(hidden_states.astype(jnp.float32) ** 2, axis=-1, keepdims=True) + hidden_states = hidden_states * jax.lax.rsqrt(variance + self.eps) + weight_dtype = self.scale.get_value().dtype + if weight_dtype in [jnp.float16, jnp.bfloat16]: + hidden_states = hidden_states.astype(weight_dtype) + return self.scale.get_value() * hidden_states + + +class UMT5DenseActDense(nnx.Module): + def __init__(self, config: UMT5Config, *, param_dtype: jnp.dtype | None = jnp.float32, rngs: nnx.Rngs): + super().__init__() + self.param_dtype = param_dtype + self.wi = nnx.Linear( + config.d_model, config.d_ff, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs + ) + self.wo = nnx.Linear( + config.d_ff, config.d_model, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs + ) + self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs) + self.act = ACT_FN[config.dense_act_fn] + + def __call__(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class UMT5DenseGatedActDense(nnx.Module): + def __init__(self, config: UMT5Config, *, param_dtype: jnp.dtype | None = jnp.float32, rngs: nnx.Rngs): + super().__init__() + self.param_dtype = param_dtype + self.wi_0 = nnx.Linear( + config.d_model, config.d_ff, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs + ) + self.wi_1 = nnx.Linear( + config.d_model, config.d_ff, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs + ) + self.wo = nnx.Linear( + config.d_ff, config.d_model, precision=Precision.HIGHEST, param_dtype=param_dtype, use_bias=False, rngs=rngs + ) + self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs) + self.act = ACT_FN[config.dense_act_fn] + + def __call__(self, hidden_states: jax.Array): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class UMT5LayerFF(nnx.Module): + def __init__(self, config: UMT5Config, *, param_dtype: jnp.dtype | None = jnp.float32, rngs: nnx.Rngs): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = UMT5DenseGatedActDense(config, param_dtype=param_dtype, rngs=rngs) + else: + self.DenseReluDense = UMT5DenseActDense(config, param_dtype=param_dtype, rngs=rngs) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, param_dtype=param_dtype) + self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs) + + def __call__(self, hidden_states: jax.Array): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class UMT5Attention(nnx.Module): + """ + T5's attention using relative_attention_bias. 
+ """ + + def __init__( + self, + config: UMT5Config, + *, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + param_dtype: jnp.dtype | None = jnp.float32, + rngs: nnx.Rngs, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.inner_dim = self.n_heads * self.key_value_proj_dim + + self.q = nnx.Linear( + self.d_model, + self.inner_dim, + precision=Precision.HIGHEST, + param_dtype=param_dtype, + use_bias=False, + rngs=rngs, + ) + self.k = nnx.Linear( + self.d_model, + self.inner_dim, + precision=Precision.HIGHEST, + param_dtype=param_dtype, + use_bias=False, + rngs=rngs, + ) + self.v = nnx.Linear( + self.d_model, + self.inner_dim, + precision=Precision.HIGHEST, + param_dtype=param_dtype, + use_bias=False, + rngs=rngs, + ) + self.o = nnx.Linear( + self.inner_dim, + self.d_model, + precision=Precision.HIGHEST, + param_dtype=param_dtype, + use_bias=False, + rngs=rngs, + ) + + self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nnx.Embed( + self.relative_attention_num_buckets, self.n_heads, param_dtype=param_dtype, rngs=rngs + ) + + def _relative_position_bucket(self, rel_pos): + """Convert relative positions to bucket indices.""" + if not self.is_decoder: + num_buckets = self.relative_attention_num_buckets // 2 + rel_buckets = (rel_pos > 0).astype(jnp.int32) * num_buckets + rel_pos = jnp.abs(rel_pos) + else: + num_buckets = self.relative_attention_num_buckets + rel_buckets = 0 + rel_pos = -jnp.minimum(rel_pos, jnp.zeros_like(rel_pos)) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = rel_pos < max_exact + + # Logarithmic bucketing for large positions + rel_pos_large = max_exact + ( + jnp.log(rel_pos.astype(jnp.float32) / max_exact) + / math.log(self.relative_attention_max_distance / max_exact) + * (num_buckets - max_exact) + ).astype(jnp.int32) + rel_pos_large = jnp.minimum(rel_pos_large, num_buckets - 1) + + rel_buckets = rel_buckets + jnp.where(is_small, rel_pos, rel_pos_large) + return rel_buckets + + def compute_bias(self, q_len, k_len): + """Compute binned relative position bias""" + ctx_pos = jnp.arange(q_len, dtype=jnp.int32)[:, None] + mem_pos = jnp.arange(k_len, dtype=jnp.int32)[None, :] + rel_pos = mem_pos - ctx_pos # shape (query_length, key_length) + rel_pos_bkt = self._relative_position_bucket(rel_pos) + values = self.relative_attention_bias(rel_pos_bkt) # shape (query_length, key_length, num_heads) + # shape (1, num_heads, query_length, key_length) + values = values.transpose(2, 0, 1)[None, :, :, :] + return values + + def __call__( + self, + hidden_states: jax.Array, + encoder_hidden_states: Optional[jax.Array] = None, + attention_mask: Optional[jax.Array] = None, + ): + b, n, c = hidden_states.shape[0], self.n_heads, self.key_value_proj_dim + + # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + + q = self.q(hidden_states).reshape(b, -1, n, c) + k = self.k(current_states).reshape(b, -1, n, c) 
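To make `_relative_position_bucket` above concrete, here is a standalone restatement (a sketch) of the bidirectional encoder case with the defaults `relative_attention_num_buckets=32` and `relative_attention_max_distance=128`: half the buckets cover positive vs. negative offsets, small offsets get exact buckets, and larger offsets share logarithmically spaced ones.

```python
import jax.numpy as jnp

def bidirectional_bucket(rel_pos, num_buckets=32, max_distance=128):
    num_buckets //= 2                               # split between positive and negative offsets
    buckets = (rel_pos > 0).astype(jnp.int32) * num_buckets
    rel_pos = jnp.abs(rel_pos)
    max_exact = num_buckets // 2                    # 8 exact buckets per direction
    is_small = rel_pos < max_exact
    large = max_exact + (
        jnp.log(rel_pos.astype(jnp.float32) / max_exact)
        / jnp.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype(jnp.int32)
    large = jnp.minimum(large, num_buckets - 1)
    return buckets + jnp.where(is_small, rel_pos, large)

print(bidirectional_bucket(jnp.array([-3, 0, 2, 9, 50])))  # roughly [3, 0, 18, 24, 29]
```

The decoder variant in the code keeps all 32 buckets and only distinguishes non-positive (past) offsets, since future positions are masked out anyway.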
+ v = self.v(current_states).reshape(b, -1, n, c) + + # Attention bias + q_len, k_len = q.shape[1], k.shape[1] + if not self.has_relative_attention_bias: + position_bias = jnp.zeros((1, n, q_len, k_len), dtype=q.dtype) + else: + position_bias = self.compute_bias(q_len, k_len) + position_bias = position_bias[:, :, -q_len:, :] + + if attention_mask is not None: + position_bias = position_bias + attention_mask + + attn = ( + jnp.einsum( + "binc,bjnc->bnij", + q, + k, + precision=Precision.HIGHEST, + ) + + position_bias + ) + + attn = jax.nn.softmax(attn.astype(jnp.float32), axis=-1).astype(attn.dtype) + + attn = self.dropout(attn) + + o_attn = jnp.einsum( + "bnij,bjnc->binc", + attn, + v, + precision=Precision.HIGHEST, + ) + + o_attn = o_attn.reshape(b, -1, n * c) + o_attn = self.o(o_attn) + o_attn = self.dropout(o_attn) + return o_attn + + +class UMT5LayerSelfAttention(nnx.Module): + def __init__( + self, + config: UMT5Config, + *, + layer_idx: Optional[int] = None, + param_dtype: jnp.dtype | None = jnp.float32, + rngs: nnx.Rngs, + ): + super().__init__() + self.SelfAttention = UMT5Attention( + config, + has_relative_attention_bias=True, + layer_idx=layer_idx, + param_dtype=param_dtype, + rngs=rngs, + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, param_dtype=param_dtype) + + def __call__( + self, + hidden_states: jax.Array, + attention_mask=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + ) + outputs = hidden_states + attention_output + return outputs + + +class UMT5LayerCrossAttention(nnx.Module): + def __init__( + self, + config: UMT5Config, + *, + layer_idx: Optional[int] = None, + param_dtype: jnp.dtype | None = jnp.float32, + rngs: nnx.Rngs, + ) -> None: + super().__init__() + self.EncDecAttention = UMT5Attention( + config, has_relative_attention_bias=False, layer_idx=layer_idx, param_dtype=param_dtype, rngs=rngs + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, param_dtype=param_dtype) + self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs) + + def __call__( + self, + hidden_states: jax.Array, + encoder_hidden_states: jax.Array = None, + attention_mask: jax.Array = None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + return hidden_states + self.dropout(attention_output) + + +class UMT5Block(nnx.Module): + def __init__( + self, + config: UMT5Config, + *, + layer_idx: Optional[int] = None, + param_dtype: jnp.dtype | None = jnp.float32, + rngs: nnx.Rngs, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nnx.List() + self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx, param_dtype=param_dtype, rngs=rngs)) + if self.is_decoder: + self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx, param_dtype=param_dtype, rngs=rngs)) + + self.layer.append(UMT5LayerFF(config, param_dtype=param_dtype, rngs=rngs)) + + def __call__( + self, + hidden_states: jax.Array, + attention_mask: jax.Array = None, + encoder_hidden_states: jax.Array = None, + encoder_attention_mask: jax.Array = None, + ): + # Apply self-attention layer + hidden_states = fp16_clamp( + self.layer[0]( + hidden_states, + attention_mask=attention_mask, + ) + ) + # Cross-Attention Block + do_cross_attention = self.is_decoder and 
encoder_hidden_states is not None + if do_cross_attention: + hidden_states = fp16_clamp( + self.layer[1]( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + ) + + # Apply Feed Forward layer + hidden_states = fp16_clamp(self.layer[-1](hidden_states)) + + return (hidden_states,) + + +class UMT5Stack(nnx.Module): + def __init__(self, config: UMT5Config, *, param_dtype: jnp.dtype | None = jnp.float32, rngs: nnx.Rngs): + super().__init__() + self.embed_tokens = nnx.Embed(config.vocab_size, config.d_model, param_dtype=param_dtype, rngs=rngs) + self.is_decoder = config.is_decoder + self.block = nnx.List( + [UMT5Block(config, layer_idx=i, param_dtype=param_dtype, rngs=rngs) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, param_dtype=param_dtype) + self.dropout = nnx.Dropout(config.dropout_rate, rngs=rngs) + + def _prepare_4d_causal_attention_mask_for_decoder( + self, + attention_mask: jax.Array, + batch_size: int, + q_len: int, + dtype: DTypeLike, + ): + if attention_mask is None: + causal_mask = jnp.tril(jnp.ones((batch_size, 1, q_len, q_len))) + elif attention_mask.ndim == 4: + causal_mask = attention_mask + elif attention_mask.ndim == 2: + # shape of causal_mask is (batch_size, q_len), expand to (batch_size, 1, q_len, q_len) + causal_mask = jnp.tril(attention_mask[:, None, None, :].repeat(q_len, axis=2)) + else: + raise ValueError(f"Invalid attention mask ndim, expected ndim: 2, actual ndim: {attention_mask.ndim}") + causal_mask = (1.0 - causal_mask) * jnp.finfo(dtype).min + return causal_mask + + def _prepare_padding_mask( + self, + padding_mask: jax.Array, + dtype: DTypeLike, + ): + """ + For decoder, padding mask is encoder attention mask. For encoder, padding mask is input mask from tokenizer. + """ + assert padding_mask.ndim in [2, 3], f"Invalid padding mask ndim: {padding_mask.ndim}" + + # expand dim if needed + if padding_mask.ndim == 3: + # shape of encoder_attention_mask is (batch_size, seq_len, seq_len), expand to (batch_size, 1, seq_len, seq_len) + padding_mask = padding_mask[:, None, :, :] + elif padding_mask.ndim == 2: + # shape of encoder_attention_mask is (batch_size, seq_len), expand to (batch_size, 1, 1, seq_len) + padding_mask = padding_mask[:, None, None, :] + + # convert to additive biases + padding_mask = (1.0 - padding_mask) * jnp.finfo(dtype).min + return padding_mask + + def _prepare_attention_mask( + self, + input_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + dtype: DTypeLike, + ): + """ + Prepare attention mask. Only SelfAttention mask is needed in encoder. Both SelfAttention mask and CrossAttention mask are needed in decoder. + Args: + input_ids: Indices of input sequence tokens in the vocabulary. shape: (batch_size, seq_len). + attention_mask: causal mask of input_ids. + encoder_hidden_states: The last hidden states of encoder output. 
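As a concrete illustration of the two mask helpers above, here is a small float32 sketch: a 0/1 mask becomes an additive bias that is hugely negative on disallowed key positions, and the decoder variant additionally applies a lower-triangular (causal) structure.

```python
import jax.numpy as jnp

mask = jnp.array([[1, 1, 1, 0]], dtype=jnp.float32)   # one sequence of length 4, last token padded
neg = jnp.finfo(jnp.float32).min

# Encoder-style padding bias, as in _prepare_padding_mask: (b, s) -> (b, 1, 1, s)
padding_bias = (1.0 - mask[:, None, None, :]) * neg   # 0 where allowed, ~-3.4e38 on the padded key

# Decoder-style causal bias, as in _prepare_4d_causal_attention_mask_for_decoder: (b, s) -> (b, 1, s, s)
causal = jnp.tril(mask[:, None, None, :].repeat(4, axis=2))
causal_bias = (1.0 - causal) * neg

print(padding_bias.shape, causal_bias.shape)          # (1, 1, 1, 4) (1, 1, 4, 4)
```

These biases are added to the attention logits before the softmax, so disallowed positions receive effectively zero weight.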
+ encoder_attention_mask: The causal mask of encoder_hidden_states + """ + if self.is_decoder: + b, s = input_ids.shape + # prepare self-attention causal mask for decoder + causal_mask = self._prepare_4d_causal_attention_mask_for_decoder(attention_mask, b, s, dtype) + # prepare cross-attention causal mask for decoder + if encoder_hidden_states is not None: + b, s, _ = encoder_hidden_states.shape + # new mask if not provided + if encoder_attention_mask is None: + encoder_attention_mask = jnp.ones((b, s), dtype=jnp.int32) + encoder_attention_mask = self._prepare_padding_mask(encoder_attention_mask, dtype) + else: + encoder_attention_mask = None + elif attention_mask is not None: + # prepare padding mask for encoder + causal_mask = self._prepare_padding_mask(attention_mask, dtype) + else: + causal_mask = None + + return causal_mask, encoder_attention_mask + + def __call__( + self, + input_ids: jax.Array = None, + attention_mask: jax.Array = None, + encoder_hidden_states: jax.Array = None, + encoder_attention_mask: jax.Array = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.dropout(inputs_embeds) + + # prepare attention mask for encoder and decoder + causal_mask, encoder_attention_mask = self._prepare_attention_mask( + input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds.dtype + ) + + for _, layer_module in enumerate(self.block): + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + hidden_states = layer_outputs[0] + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class UMT5EncoderModel(nnx.Module): + def __init__( + self, + config: UMT5Config, + *, + param_dtype: jnp.dtype | None = jnp.float32, + rngs: nnx.Rngs, + ): + super().__init__() + config.is_decoder = False + self.encoder = UMT5Stack(config, param_dtype=param_dtype, rngs=rngs) + + def __call__( + self, + input_ids: jax.Array = None, + attention_mask: jax.Array = None, + ) -> Union[jax.Array]: + r""" + input_ids (`jax.Array` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. + + To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training). 
+ ```""" + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + return encoder_outputs + + +class UMT5Model(nnx.Module): + def __init__( + self, + config: UMT5Config, + *, + param_dtype: jnp.dtype | None = jnp.float32, + rngs: nnx.Rngs, + ): + super().__init__() + self.config = config + self.model_dim = config.d_model + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + self.encoder = UMT5Stack(encoder_config, param_dtype=param_dtype, rngs=rngs) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.num_layers = config.num_decoder_layers + self.decoder = UMT5Stack(decoder_config, param_dtype=param_dtype, rngs=rngs) + + self.lm_head = nnx.Linear( + config.d_model, + config.vocab_size, + use_bias=False, + precision=Precision.HIGHEST, + param_dtype=param_dtype, + rngs=rngs, + ) + + def __call__( + self, + input_ids: Optional[jax.Array] = None, + attention_mask: Optional[jax.Array] = None, + decoder_input_ids: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs: Optional[jax.Array] = None, + ) -> jax.Array: + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + hidden_states = encoder_outputs + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + ) + + return decoder_outputs + + def generate( + self, + input_ids: jax.Array, + attention_mask: jax.Array = None, + max_tokens: Optional[int] = None, + max_new_tokens: Optional[int] = None, + ) -> jax.Array: + """Generate sequences using greedy decoding. + + Args: + input_ids: Encoder input ids from tokenizer, shape (batch_size, seq_length) + attention_mask: Encoder attention mask, shape (batch_size, seq_length) + max_tokens: Maximum total length of decoder sequence (including start token). + Takes precedence over max_new_tokens if both are provided. 
+ max_new_tokens: Maximum number of new tokens to generate (excluding start token) + + Returns: + Generated token ids, shape (batch_size, generated_length) + """ + # Determine maximum generation length + if max_tokens is not None: + max_length = max_tokens + elif max_new_tokens is not None: + max_length = max_new_tokens + 1 # +1 for decoder_start_token + else: + max_length = 512 # default value + + # Encode input + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + # Initialize decoder input with start token + batch_size = input_ids.shape[0] + decoder_input_ids = jnp.full((batch_size, 1), self.config.decoder_start_token_id, dtype=jnp.int32) + + # Autoregressive generation loop + for _ in range(max_length - 1): + # Decoder forward pass + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_outputs, + ) + + # Get logits and select next token (greedy) + logits = self.lm_head(decoder_outputs) + # here use simple greedy, but beem search is recommended + next_token = jnp.argmax(logits[:, -1, :], axis=-1, keepdims=True) + + # Append to decoder input + decoder_input_ids = jnp.concatenate([decoder_input_ids, next_token], axis=1) + + # Stop if all sequences generated EOS + if jnp.all(next_token == self.config.eos_token_id): + break + + return decoder_input_ids + + +@jax.jit +def forward( + graphdef: nnx.GraphDef, + state: nnx.State, + *, + input_ids: Optional[jax.Array] = None, + attention_mask: Optional[jax.Array] = None, + decoder_input_ids: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs: Optional[jax.Array] = None, +) -> jax.Array: + model = nnx.merge(graphdef, state) + return model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + ) + + +__all__ = ["UMT5EncoderModel", "UMT5Model"] diff --git a/bonsai/models/umt5/params.py b/bonsai/models/umt5/params.py new file mode 100644 index 00000000..1df18582 --- /dev/null +++ b/bonsai/models/umt5/params.py @@ -0,0 +1,395 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
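A minimal usage sketch for the greedy `generate` helper above, assuming `model` is a `UMT5Model` already loaded via `create_model` from `params.py` below; the prompt and token budget are placeholders.

```python
import jax.numpy as jnp
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
batch = tokenizer(["translate English to German: The house is wonderful."],
                  padding=True, return_tensors="np")

output_ids = model.generate(                       # `model` is an already-loaded UMT5Model
    input_ids=jnp.array(batch.input_ids),
    attention_mask=jnp.array(batch.attention_mask),
    max_new_tokens=20,
)
print(tokenizer.batch_decode(output_ids.tolist(), skip_special_tokens=True))
```

Note that the loop in `generate` re-runs the decoder over the whole prefix at every step (there is no KV cache), so it is fine for short demo generations but will be slow for long outputs.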
+ +import dataclasses +import gc +import json +import logging +import re +from enum import Enum, auto + +import jax +import jax.numpy as jnp +import safetensors +import torch +from etils import epath +from flax import nnx + +from bonsai.models.umt5 import modeling as model_lib + + +def _get_key_and_transform_mapping(cls, cfg: model_lib.UMT5Config): + """Define mapping from HuggingFace UMT5 keys to JAX UMT5 keys.""" + + class Transform(Enum): + """Transformations for UMT5 parameters""" + + NONE = None + # For linear layers: (out, in) -> (in, out) + TRANSPOSE = ((1, 0), None, False) + + # T5/UMT5 uses standard HuggingFace naming + encoder_mapping = { + r"encoder\.embed_tokens\.weight": ("encoder.embed_tokens.embedding", Transform.NONE), + # Shared token embeddings + r"shared\.weight": ("encoder.embed_tokens.embedding", Transform.NONE), + # Encoder + # Encoder blocks - Self attention + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.q\.weight": ( + r"encoder.block.\1.layer.0.SelfAttention.q.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.k\.weight": ( + r"encoder.block.\1.layer.0.SelfAttention.k.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.v\.weight": ( + r"encoder.block.\1.layer.0.SelfAttention.v.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.o\.weight": ( + r"encoder.block.\1.layer.0.SelfAttention.o.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.relative_attention_bias\.weight": ( + r"encoder.block.\1.layer.0.SelfAttention.relative_attention_bias.embedding", + Transform.NONE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.layer_norm\.weight": ( + r"encoder.block.\1.layer.0.layer_norm.scale", + Transform.NONE, + ), + # Encoder blocks - Feed forward + r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wi_0\.weight": ( + r"encoder.block.\1.layer.1.DenseReluDense.wi_0.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wi_1\.weight": ( + r"encoder.block.\1.layer.1.DenseReluDense.wi_1.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wo\.weight": ( + r"encoder.block.\1.layer.1.DenseReluDense.wo.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.1\.layer_norm\.weight": ( + r"encoder.block.\1.layer.1.layer_norm.scale", + Transform.NONE, + ), + # Encoder Final layer norm + r"encoder\.final_layer_norm\.weight": ("encoder.final_layer_norm.scale", Transform.NONE), + } + decoder_mapping = { + # Decoder + # Decoder embedding + r"decoder\.embed_tokens\.weight": ("decoder.embed_tokens.embedding", Transform.NONE), + # Decoder blocks - Self attention + r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.q\.weight": ( + r"decoder.block.\1.layer.0.SelfAttention.q.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.k\.weight": ( + r"decoder.block.\1.layer.0.SelfAttention.k.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.v\.weight": ( + r"decoder.block.\1.layer.0.SelfAttention.v.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.o\.weight": ( + r"decoder.block.\1.layer.0.SelfAttention.o.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.relative_attention_bias\.weight": ( + r"decoder.block.\1.layer.0.SelfAttention.relative_attention_bias.embedding", + 
Transform.NONE, + ), + r"decoder\.block\.([0-9]+)\.layer\.0\.layer_norm\.weight": ( + r"decoder.block.\1.layer.0.layer_norm.scale", + Transform.NONE, + ), + # Decoder blocks - Cross attention + r"decoder\.block\.([0-9]+)\.layer\.1\.EncDecAttention\.q\.weight": ( + r"decoder.block.\1.layer.1.EncDecAttention.q.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.1\.EncDecAttention\.k\.weight": ( + r"decoder.block.\1.layer.1.EncDecAttention.k.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.1\.EncDecAttention\.v\.weight": ( + r"decoder.block.\1.layer.1.EncDecAttention.v.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.1\.EncDecAttention\.o\.weight": ( + r"decoder.block.\1.layer.1.EncDecAttention.o.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.1\.layer_norm\.weight": ( + r"decoder.block.\1.layer.1.layer_norm.scale", + Transform.NONE, + ), + # Decoder blocks - Feed forward + r"decoder\.block\.([0-9]+)\.layer\.2\.DenseReluDense\.wi_0\.weight": ( + r"decoder.block.\1.layer.2.DenseReluDense.wi_0.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.2\.DenseReluDense\.wi_1\.weight": ( + r"decoder.block.\1.layer.2.DenseReluDense.wi_1.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.2\.DenseReluDense\.wo\.weight": ( + r"decoder.block.\1.layer.2.DenseReluDense.wo.kernel", + Transform.TRANSPOSE, + ), + r"decoder\.block\.([0-9]+)\.layer\.2\.layer_norm\.weight": ( + r"decoder.block.\1.layer.2.layer_norm.scale", + Transform.NONE, + ), + # Decoder Final layer norm + r"decoder\.final_layer_norm\.weight": ("decoder.final_layer_norm.scale", Transform.NONE), + # lm head + r"lm_head\.weight": ("lm_head.kernel", Transform.TRANSPOSE), + } + + if cls == model_lib.UMT5EncoderModel: + return encoder_mapping + + full_mapping = encoder_mapping.copy() + full_mapping.update(decoder_mapping) + return full_mapping + + +def _torch_key_to_jax_key(mapping, source_key): + subs = [ + (re.sub(pat, repl, source_key), reshape) + for pat, (repl, reshape) in mapping.items() + if re.match(pat, source_key) + ] + if len(subs) > 1: + raise ValueError(f"Only one key should be found: {subs[0]}") + if len(subs) == 0: + return (None, None) + return subs[0] + + +def _assign_weights(keys, tensor, state_dict, st_key, transform, sharding_dict, dtype): + """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor.""" + key, *rest = keys + if not rest: + if transform is not None: + permute, reshape, reshape_first = transform + if reshape_first and reshape is not None: + tensor = tensor.reshape(reshape) + if permute: + tensor = tensor.transpose(permute) + if not reshape_first and reshape is not None: + tensor = tensor.reshape(reshape) + if tensor.shape != state_dict[key].shape: + raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}") + # Only apply sharding if sharding_dict is provided + if sharding_dict is not None: + state_dict[key] = jax.device_put(tensor, sharding_dict[key]).astype(dtype) + else: + state_dict[key] = jax.device_put(tensor).astype(dtype) + else: + next_sharding = sharding_dict[key] if sharding_dict is not None else None + _assign_weights(rest, tensor, state_dict[key], st_key, transform, next_sharding, dtype) + + +def _stoi(s): + try: + return int(s) + except ValueError: + return s + + +class WeightFileType(Enum): + ST = auto() # safetensor + BIN = auto() + PT = auto() + PTH = auto() + + +def 
search_available_weight_file(file_path): + p = epath.Path(file_path).expanduser() + if p.is_file(): + # if file_path is a file, return [p], type + file_ext = p.suffix.lower() + if file_ext == ".safetensors": + print(f"Using {p} to load weight") + return [p], WeightFileType.ST + elif file_ext == ".bin": + print(f"Using {p} to load weight") + return [p], WeightFileType.BIN + elif file_ext == ".pt": + print(f"Using {p} to load weight") + return [p], WeightFileType.PT + elif file_ext == ".pth": + print(f"Using {p} to load weight") + return [p], WeightFileType.PTH + else: + raise ValueError(f"Unsupported file extension: {file_ext}") + else: + # if file_path is dir, return all matching files + files = list(p.glob("*.safetensors")) + if not files: + logging.warning(f"No *.safetensors found in {file_path}, try to search others") + else: + print("Using *.safetensors to load weight") + return files, WeightFileType.ST + + files = list(p.glob("*.bin")) + if not files: + logging.warning(f"No *.bin found in {file_path}, try to search others") + else: + print("Using *.bin to load weight") + return files, WeightFileType.BIN + + files = list(p.glob("*.pth")) + if not files: + logging.warning(f"No *.pth found in {file_path}, try to search others") + else: + print("Using *.pth to load weight") + return files, WeightFileType.PTH + + raise ValueError(f"No weight file found in {file_path}") + + +def open_weight_file(f, file_type): + if file_type in [WeightFileType.BIN, WeightFileType.PTH]: + sf = torch.load(f, map_location="cpu") + elif file_type == WeightFileType.ST: + sf = safetensors.safe_open(f, framework="numpy") + else: + raise ValueError(f"invalid file type: {file_type}") + return sf + + +def get_tensor(sf, torch_key, file_type): + if file_type in [WeightFileType.BIN, WeightFileType.PTH]: + tensor = sf[torch_key] + elif file_type == WeightFileType.ST: + tensor = sf.get_tensor(torch_key) + else: + raise ValueError(f"invalid file type: {file_type}") + + return tensor + + +def create_model( + cls, + file_dir: str, + cfg: model_lib.UMT5Config, + key_mapping=None, + param_dtype: jnp.dtype | None = jnp.float32, + mesh: jax.sharding.Mesh | None = None, +) -> model_lib.UMT5Model | model_lib.UMT5EncoderModel: + """Load weight and create a UMT5Encoder model (memory-optimized). + + Args: + cls: model class. UMT5Model and UMT5EncoderModel is available. + file_dir: model weight path. + cfg: model config. Use 'load_model_config' to get it. + key_mapping: model weight key map. Used in unofficial umt5 model, such as Wan2.1 + param_dtype: model weight dtype. + mesh: model weight mesh. + Returns: + The instance of model defined in cls. 
+ """ + files, file_type = search_available_weight_file(file_dir) + + umt5 = nnx.eval_shape(lambda: cls(cfg, param_dtype=param_dtype, rngs=nnx.Rngs(params=0, dropout=0))) + graph_def, abs_state = nnx.split(umt5) + state_dict = abs_state.to_pure_dict() + # Only use sharding if mesh is provided + sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None + + if not key_mapping: + key_mapping = _get_key_and_transform_mapping(cls, cfg) + conversion_errors = [] + + print(f"Loading Weight: {cfg.dtype=}") + for f in files: + sf = open_weight_file(f, file_type) + + for torch_key in sf.keys(): + ts = get_tensor(sf, torch_key, file_type) + if isinstance(ts, torch.Tensor): + npy = ts.numpy() if ts.dtype != torch.bfloat16 else ts.to(dtype=torch.float32).numpy() + + jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key) + if jax_key is None: + continue + + keys = [_stoi(k) for k in jax_key.split(".")] + try: + _assign_weights(keys, npy, state_dict, torch_key, transform.value, sharding, cfg.dtype) + except Exception as e: + full_jax_key = ".".join([str(k) for k in keys]) + conversion_errors.append(f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}") + gc.collect() + + if conversion_errors: + full_error_log = "\n".join(conversion_errors) + raise RuntimeError(f"Encountered {len(conversion_errors)} weight conversion errors. Log:\n{full_error_log}") + + if cls == model_lib.UMT5Model: + state_dict["decoder"]["embed_tokens"]["embedding"] = state_dict["encoder"]["embed_tokens"]["embedding"] + + gc.collect() + m = nnx.merge(graph_def, state_dict) + m.eval() + return m + + +def get_weight_dtype_from_config(conf_dict): + def get_dtype(dtype_str): + if dtype_str in ["float32", "fp32"]: + return jnp.float32 + elif dtype_str in ["bloat16", "bf16"]: + return jnp.bfloat16 + elif dtype_str in ["float16", "fp16"]: + return jnp.float16 + else: + logging.warning(f"Unrecognized dtype: {dtype_str}") + return jnp.float32 + + if "dtype" in conf_dict: + return get_dtype(conf_dict["dtype"]) + elif "torch_dtype" in conf_dict: + return get_dtype(conf_dict["torch_dtype"]) + else: + logging.warning("No 'dtype' config found in config file") + return jnp.float32 + + +def load_model_config(model_path: str) -> model_lib.UMT5Config: + """Load the model config from the model path.""" + model_dir = epath.Path(model_path).expanduser() + config_path = model_dir / "config.json" + + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found at {config_path}") + + with config_path.open("r") as f: + config_dict = json.load(f) + + dtype = get_weight_dtype_from_config(config_dict) + + # Filter config_dict to only include fields defined in UMT5Config + config_fields = {f.name for f in dataclasses.fields(model_lib.UMT5Config)} + filtered_config = {k: v for k, v in config_dict.items() if k in config_fields} + + return model_lib.UMT5Config(**filtered_config, dtype=dtype) diff --git a/bonsai/models/umt5/tests/__init__.py b/bonsai/models/umt5/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bonsai/models/umt5/tests/run_model.py b/bonsai/models/umt5/tests/run_model.py new file mode 100644 index 00000000..84d081a4 --- /dev/null +++ b/bonsai/models/umt5/tests/run_model.py @@ -0,0 +1,75 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Demo script to run UMT5 model inference.""" + +import jax +import jax.numpy as jnp +from flax import nnx +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer + +from bonsai.models.umt5.modeling import UMT5Model, forward +from bonsai.models.umt5.params import create_model, load_model_config + + +def main(): + """Run UMT5 model inference demo.""" + print("=" * 80) + print("UMT5 Model Demo - JAX Implementation") + print("=" * 80) + + # Model configuration + model_name = "google/umt5-base" + model_ckpt_path = snapshot_download(model_name) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Load model config and create model + model_conf = load_model_config(model_ckpt_path) + + jax_model = create_model( + UMT5Model, + file_dir=model_ckpt_path, + cfg=model_conf, + ) + graphdef, state = nnx.split(jax_model) + + # Prepare input + prompts = [ + "A beautiful sunset over the ocean with waves crashing on the shore", + "translate to French: I love cat", + ] + + # Tokenize input + inputs = tokenizer(prompts, padding=True, return_tensors="np") + input_ids = jnp.array(inputs.input_ids) + attention_mask = jnp.array(inputs.attention_mask) + + # forward + bs = len(prompts) + decoder_input_ids = jnp.full((bs, 1), model_conf.decoder_start_token_id, dtype=jnp.int32) + decoder_output = forward( + graphdef, + state, + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + ) + print(f"Decoder output shape: {decoder_output.shape}") + + +if __name__ == "__main__": + main() diff --git a/bonsai/models/umt5/tests/test_umt5.py b/bonsai/models/umt5/tests/test_umt5.py new file mode 100644 index 00000000..5903fa9b --- /dev/null +++ b/bonsai/models/umt5/tests/test_umt5.py @@ -0,0 +1,160 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test accuracy of jax impelement vs torch transformer impelement.""" + +import jax +import jax.numpy as jnp +import numpy as np +import torch +from absl.testing import absltest +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer +from transformers import UMT5EncoderModel as TorchUMT5EncoderModel +from transformers import UMT5Model as TorchUMT5Model + +from bonsai.models.umt5.modeling import UMT5Config, UMT5EncoderModel, UMT5Model +from bonsai.models.umt5.params import create_model, load_model_config + + +def compare_outputs(jax_output: jax.Array, torch_output, name: str, rtol: float = 1e-3, atol: float = 1e-5): + """Compare JAX and PyTorch outputs and report differences. 
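One note on the demo above: `forward` (and `UMT5Model.__call__`) return the decoder's final hidden states rather than vocabulary logits, which is why `main()` only prints a shape. To turn that output into token predictions you can project through `lm_head` yourself, roughly as follows (a sketch reusing `jax_model`, `decoder_output`, and `tokenizer` from `main()`):

```python
logits = jax_model.lm_head(decoder_output)           # (batch, dec_len, vocab_size)
next_tokens = jnp.argmax(logits[:, -1, :], axis=-1)  # greedy pick for the next position
print(tokenizer.batch_decode(next_tokens[:, None].tolist()))
```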
+ Args: + jax_output: Output from JAX model + torch_output: Output from PyTorch model (torch.Tensor) + name: Name of the output being compared + rtol: Relative tolerance + atol: Absolute tolerance + """ + if torch_output.dtype == torch.bfloat16: + torch_output = torch_output.float() + if jax_output.dtype == jnp.bfloat16: + jax_output = jax_output.astype(jnp.float32) + + # Convert PyTorch to numpy + if isinstance(torch_output, torch.Tensor): + torch_np = torch_output.detach().cpu().numpy() + else: + torch_np = np.array(torch_output) + + # Convert JAX to numpy + jax_np = np.array(jax_output) + + # Check shapes match + if jax_np.shape != torch_np.shape: + print(f"❌ Shape mismatch! jax shape: {jax_np.shape}, torch shape: {torch_np.shape}") + return False + + if jax_np.dtype != torch_np.dtype: + print(f"❌ Shape mismatch! jax dtype: {jax_np.dtype}, torch dtype: {torch_np.dtype}") + return False + + # Check if within tolerance + close = np.allclose(jax_np, torch_np, rtol=rtol, atol=atol) + + if not close: + print(f"\n❌ Outputs do NOT match (rtol={rtol}, atol={atol})") + # Show some mismatched locations + mismatch_mask = ~np.isclose(jax_np, torch_np, rtol=rtol, atol=atol) + n_mismatches = np.sum(mismatch_mask) + print(f" Number of mismatches: {n_mismatches} / {jax_np.size} ({100 * n_mismatches / jax_np.size:.2f}%)") + + return close + + +class UMT5Test(absltest.TestCase): + def setUp(self): + super().setUp() + self.model_name = "google/umt5-base" + self.tokenizer_name = "google/umt5-base" + + self.model_ckpt_path = snapshot_download(self.model_name) + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + self.model_config = load_model_config(self.model_ckpt_path) + + def test_t5_encoder_accuracy(self): + jax_t5 = create_model( + UMT5EncoderModel, + file_dir=self.model_ckpt_path, + cfg=self.model_config, + ) + + hf_t5 = TorchUMT5EncoderModel.from_pretrained(self.model_ckpt_path) + + prompts = [ + "A beautiful sunset over the ocean with waves crashing on the shore", + "translate to French: I love cat", + ] + torch_inputs = self.tokenizer(prompts, padding=True, return_tensors="pt") + jax_inputs = self.tokenizer(prompts, padding=True, return_tensors="np") + + # test encoder accuracy + pytorch_output = hf_t5.encoder(input_ids=torch_inputs.input_ids, attention_mask=torch_inputs.attention_mask) + jax_output = jax_t5.encoder( + input_ids=jnp.array(jax_inputs.input_ids), attention_mask=jnp.array(jax_inputs.attention_mask) + ) + + torch_embeddings = pytorch_output.last_hidden_state + + seq_lens = torch_inputs.attention_mask.gt(0).sum(dim=1).long() + for i in range(jax_output.shape[0]): + self.assertTrue( + compare_outputs( + jax_output[i, : seq_lens[i], :], + torch_embeddings[i, : seq_lens[i], :], + f"UMT5 Encoder For Prompt: {prompts[i]}", + ) + ) + + def test_t5_decoder_accuracy(self): + jax_t5 = create_model( + UMT5Model, + file_dir=self.model_ckpt_path, + cfg=self.model_config, + ) + + hf_t5 = TorchUMT5Model.from_pretrained(self.model_ckpt_path) + + prompts = [ + "A beautiful sunset over the ocean with waves crashing on the shore", + "translate to French: I love cat", + ] + torch_inputs = self.tokenizer(prompts, padding=True, return_tensors="pt") + + # test encoder accuracy + pytorch_output = hf_t5.encoder(input_ids=torch_inputs.input_ids, attention_mask=torch_inputs.attention_mask) + encoder_hidden_states = pytorch_output.last_hidden_state + + # test decoder accuracy + # use torch encoder output as docoder input + bs = encoder_hidden_states.shape[0] + decoder_input_ids = [0, 1, 2, 3] + 
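The per-sequence slicing in `test_t5_encoder_accuracy` above deliberately compares only positions covered by the attention mask; a tiny self-contained sketch of why that matters for padded batches:

```python
import numpy as np

attention_mask = np.array([[1, 1, 1, 1],
                           [1, 1, 0, 0]])
seq_lens = attention_mask.sum(axis=1)                 # [4, 2]

jax_out = np.random.randn(2, 4, 8).astype(np.float32)
torch_out = jax_out.copy()
torch_out[1, 2:] += 1.0                               # differences confined to padded positions

for i, n in enumerate(seq_lens):
    assert np.allclose(jax_out[i, :n], torch_out[i, :n])  # comparison still passes
```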
torch_input_ids = torch.tensor(decoder_input_ids, dtype=torch.int32).unsqueeze(0).repeat(bs, 1) + pytorch_decoder_output = hf_t5.decoder( + input_ids=torch_input_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=torch_inputs.attention_mask, + ) + decoder_hidden_states = pytorch_decoder_output.last_hidden_state + + jax_decoder_output = jax_t5.decoder( + input_ids=jnp.array(torch_input_ids.numpy()), + encoder_hidden_states=jnp.array(encoder_hidden_states.detach().numpy()), + encoder_attention_mask=jnp.array(torch_inputs.attention_mask.detach().numpy()), + ) + self.assertTrue(compare_outputs(jax_decoder_output, decoder_hidden_states, "UMT5 Decoder")) + + +if __name__ == "__main__": + absltest.main() diff --git a/bonsai/models/wan2/__init__.py b/bonsai/models/wan2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bonsai/models/wan2/params.py b/bonsai/models/wan2/params.py new file mode 100644 index 00000000..957fa6e5 --- /dev/null +++ b/bonsai/models/wan2/params.py @@ -0,0 +1,706 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Weight loading utilities for Wan2.1-T2V-1.3B model.""" + +import gc +import re +from enum import Enum + +import jax +import jax.numpy as jnp +import safetensors +from etils import epath +from flax import nnx + +from bonsai.models.wan2 import transformer_wan as model_lib +from bonsai.models.wan2 import umt5 as t5_lib +from bonsai.models.wan2 import vae_wan as vae_lib + + +def cast_with_exclusion(path, x, dtype_to_cast): + """ + Casts arrays to dtype_to_cast, but keeps params from any 'norm' layer in float32. 
+ """ + + exclusion_keywords = [ + "norm", # For all LayerNorm/GroupNorm layers + "time_embed", # The entire time conditioning module + "text_proj", # The entire text conditioning module + "scale_shift_table", # Catches both the final and the AdaLN tables + ] + + path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path) + + if any(keyword in path_str.lower() for keyword in exclusion_keywords): + # Keep LayerNorm/GroupNorm weights and biases in full precision + return x.astype(jnp.float32) + else: + # Cast everything else to dtype_to_cast + return x.astype(dtype_to_cast) + + +def _get_dit_mapping(cfg: model_lib.TransformerWanModelConfig): + class Transform(Enum): + """Transformations for model parameters""" + + NONE = None + TRANSPOSE = ((1, 0), None) # For linear layers: (out, in) -> (in, out) + TRANSPOSE_CONV = ((2, 3, 4, 1, 0), None) # For 3D conv: (out, in, t, h, w) -> (t, h, w, in, out) + + mapping = { + # Patch embedding (input projection) + r"patch_embedding\.weight": ("patch_embed.kernel", Transform.TRANSPOSE_CONV), + r"patch_embedding\.bias": ("patch_embed.bias", Transform.NONE), + # Time embedder - Sequential uses integer indices (0, 1, 2), not layers_0 + r"condition_embedder\.time_embedder\.linear_1\.weight": ( + "time_embed.time_embedding.layers.0.kernel", + Transform.TRANSPOSE, + ), + r"condition_embedder\.time_embedder\.linear_1\.bias": ( + "time_embed.time_embedding.layers.0.bias", + Transform.NONE, + ), + r"condition_embedder\.time_embedder\.linear_2\.weight": ( + "time_embed.time_embedding.layers.2.kernel", + Transform.TRANSPOSE, + ), + r"condition_embedder\.time_embedder\.linear_2\.bias": ( + "time_embed.time_embedding.layers.2.bias", + Transform.NONE, + ), + r"condition_embedder\.time_proj\.weight": ("time_embed.time_projection.layers.1.kernel", Transform.TRANSPOSE), + r"condition_embedder\.time_proj\.bias": ("time_embed.time_projection.layers.1.bias", Transform.NONE), + # Text embedder (projects UMT5 embeddings to hidden dim) + r"condition_embedder\.text_embedder\.linear_1\.weight": ("text_proj.layers.0.kernel", Transform.TRANSPOSE), + r"condition_embedder\.text_embedder\.linear_1\.bias": ("text_proj.layers.0.bias", Transform.NONE), + r"condition_embedder\.text_embedder\.linear_2\.weight": ("text_proj.layers.2.kernel", Transform.TRANSPOSE), + r"condition_embedder\.text_embedder\.linear_2\.bias": ("text_proj.layers.2.bias", Transform.NONE), + # Transformer blocks - Self attention (attn1) + r"blocks\.([0-9]+)\.attn1\.norm_q\.weight": (r"blocks.\1.self_attn.q_norm.scale", Transform.NONE), + r"blocks\.([0-9]+)\.attn1\.norm_k\.weight": (r"blocks.\1.self_attn.k_norm.scale", Transform.NONE), + r"blocks\.([0-9]+)\.attn1\.to_q\.weight": (r"blocks.\1.self_attn.q_proj.kernel", Transform.TRANSPOSE), + r"blocks\.([0-9]+)\.attn1\.to_q\.bias": (r"blocks.\1.self_attn.q_proj.bias", Transform.NONE), + r"blocks\.([0-9]+)\.attn1\.to_k\.weight": (r"blocks.\1.self_attn.k_proj.kernel", Transform.TRANSPOSE), + r"blocks\.([0-9]+)\.attn1\.to_k\.bias": (r"blocks.\1.self_attn.k_proj.bias", Transform.NONE), + r"blocks\.([0-9]+)\.attn1\.to_v\.weight": (r"blocks.\1.self_attn.v_proj.kernel", Transform.TRANSPOSE), + r"blocks\.([0-9]+)\.attn1\.to_v\.bias": (r"blocks.\1.self_attn.v_proj.bias", Transform.NONE), + r"blocks\.([0-9]+)\.attn1\.to_out\.0\.weight": (r"blocks.\1.self_attn.out_proj.kernel", Transform.TRANSPOSE), + r"blocks\.([0-9]+)\.attn1\.to_out\.0\.bias": (r"blocks.\1.self_attn.out_proj.bias", Transform.NONE), + # Transformer blocks - Cross attention 
(attn2) + # Note: CrossAttention only has q_norm, not k_norm; norm_k is skipped + r"blocks\.([0-9]+)\.attn2\.norm_q\.weight": (r"blocks.\1.cross_attn.q_norm.scale", Transform.NONE), + r"blocks\.([0-9]+)\.attn2\.norm_k\.weight": (r"blocks.\1.cross_attn.k_norm.scale", Transform.NONE), + r"blocks\.([0-9]+)\.attn2\.to_q\.weight": (r"blocks.\1.cross_attn.q_proj.kernel", Transform.TRANSPOSE), + r"blocks\.([0-9]+)\.attn2\.to_q\.bias": (r"blocks.\1.cross_attn.q_proj.bias", Transform.NONE), + # Note: to_k and to_v need special handling - they're fused into kv_proj in JAX + # See _load_fused_kv_weights() below + r"blocks\.([0-9]+)\.attn2\.to_out\.0\.weight": (r"blocks.\1.cross_attn.out_proj.kernel", Transform.TRANSPOSE), + r"blocks\.([0-9]+)\.attn2\.to_out\.0\.bias": (r"blocks.\1.cross_attn.out_proj.bias", Transform.NONE), + # Transformer blocks - Feed forward (Sequential creates 'layers' dict with 0, 2 keys) + r"blocks\.([0-9]+)\.ffn\.net\.0\.proj\.weight": (r"blocks.\1.mlp.layers.0.kernel", Transform.TRANSPOSE), + r"blocks\.([0-9]+)\.ffn\.net\.0\.proj\.bias": (r"blocks.\1.mlp.layers.0.bias", Transform.NONE), + r"blocks\.([0-9]+)\.ffn\.net\.2\.weight": (r"blocks.\1.mlp.layers.2.kernel", Transform.TRANSPOSE), + r"blocks\.([0-9]+)\.ffn\.net\.2\.bias": (r"blocks.\1.mlp.layers.2.bias", Transform.NONE), + # Transformer blocks - Norm and modulation + r"blocks\.([0-9]+)\.norm2\.weight": (r"blocks.\1.norm2.scale", Transform.NONE), + r"blocks\.([0-9]+)\.norm2\.bias": (r"blocks.\1.norm2.bias", Transform.NONE), + r"blocks\.([0-9]+)\.scale_shift_table": (r"blocks.\1.scale_shift_table", Transform.NONE), + # Output projection + r"scale_shift_table": ("final_layer.scale_shift_table", Transform.NONE), + r"proj_out\.weight": ("final_layer.linear.kernel", Transform.TRANSPOSE), + r"proj_out\.bias": ("final_layer.linear.bias", Transform.NONE), + r"norm_out\.weight": ("final_layer.norm.scale", Transform.NONE), + } + + return mapping + + +def _get_vae_key_mapping(): + """Define mapping from PyTorch VAE keys to JAX VAE keys.""" + + class Transform(Enum): + """Transformations for VAE parameters""" + + NONE = None + TRANSPOSE_2D_CONV = ((2, 3, 1, 0), None) # For 2D conv: (out, in, h, w) -> (h, w, in, out) + TRANSPOSE_3D = ((2, 3, 4, 1, 0), None) # For 3D conv: (out, in, t, h, w) -> (t, h, w, in, out) + SQUEEZE = (None, (-1,)) # Squeeze to 1D: (C, 1, 1, 1) -> (C,) + + # PyTorch format: (out_channels, in_channels, kernel_size...) 
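These transforms (`TRANSPOSE_2D_CONV`, `TRANSPOSE_3D`, and `TRANSPOSE_CONV` in the DiT mapping above) are pure axis permutations between the two layouts described here; a quick standalone check with made-up shapes:

```python
import numpy as np

# PyTorch Conv3d weight: (out_ch, in_ch, t, h, w)  ->  Flax kernel: (t, h, w, in_ch, out_ch)
torch_w3d = np.zeros((16, 3, 1, 2, 2))
print(torch_w3d.transpose((2, 3, 4, 1, 0)).shape)    # (1, 2, 2, 3, 16)

# PyTorch Conv2d weight: (out_ch, in_ch, h, w)      ->  Flax kernel: (h, w, in_ch, out_ch)
torch_w2d = np.zeros((16, 3, 3, 3))
print(torch_w2d.transpose((2, 3, 1, 0)).shape)       # (3, 3, 3, 16)
```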
+ # JAX format: (kernel_size..., in_channels, out_channels) + mapping = { + # Post-quantization conv: 1x1x1 conv + r"post_quant_conv\.weight": ("conv2.conv.kernel", Transform.TRANSPOSE_3D), + r"post_quant_conv\.bias": ("conv2.conv.bias", Transform.NONE), + # Decoder input conv + r"decoder\.conv_in\.weight": ("decoder.conv_in.conv.kernel", Transform.TRANSPOSE_3D), + r"decoder\.conv_in\.bias": ("decoder.conv_in.conv.bias", Transform.NONE), + # Mid block resnets + r"decoder\.mid_block\.resnets\.0\.norm1\.gamma": ("decoder.mid_block1.norm1.scale", Transform.SQUEEZE), + r"decoder\.mid_block\.resnets\.0\.conv1\.weight": ( + "decoder.mid_block1.conv1.conv.kernel", + Transform.TRANSPOSE_3D, + ), + r"decoder\.mid_block\.resnets\.0\.conv1\.bias": ("decoder.mid_block1.conv1.conv.bias", Transform.NONE), + r"decoder\.mid_block\.resnets\.0\.norm2\.gamma": ("decoder.mid_block1.norm2.scale", Transform.SQUEEZE), + r"decoder\.mid_block\.resnets\.0\.norm2\.bias": ("decoder.mid_block1.norm2.scale", Transform.NONE), + r"decoder\.mid_block\.resnets\.0\.conv2\.weight": ( + "decoder.mid_block1.conv2.conv.kernel", + Transform.TRANSPOSE_3D, + ), + r"decoder\.mid_block\.resnets\.0\.conv2\.bias": ("decoder.mid_block1.conv2.conv.bias", Transform.NONE), + r"decoder\.mid_block\.resnets\.1\.norm1\.gamma": ("decoder.mid_block2.norm1.scale", Transform.SQUEEZE), + r"decoder\.mid_block\.resnets\.1\.conv1\.weight": ( + "decoder.mid_block2.conv1.conv.kernel", + Transform.TRANSPOSE_3D, + ), + r"decoder\.mid_block\.resnets\.1\.conv1\.bias": ("decoder.mid_block2.conv1.conv.bias", Transform.NONE), + r"decoder\.mid_block\.resnets\.1\.norm2\.gamma": ("decoder.mid_block2.norm2.scale", Transform.SQUEEZE), + r"decoder\.mid_block\.resnets\.1\.conv2\.weight": ( + "decoder.mid_block2.conv2.conv.kernel", + Transform.TRANSPOSE_3D, + ), + r"decoder\.mid_block\.resnets\.1\.conv2\.bias": ("decoder.mid_block2.conv2.conv.bias", Transform.NONE), + # Mid attention block + r"decoder\.mid_block\.attentions\.0\.norm\.gamma": ("decoder.mid_attn.norm.scale", Transform.SQUEEZE), + r"decoder\.mid_block\.attentions\.0\.to_qkv\.weight": ( + "decoder.mid_attn.qkv.kernel", + Transform.TRANSPOSE_2D_CONV, + ), + r"decoder\.mid_block\.attentions\.0\.to_qkv\.bias": ("decoder.mid_attn.qkv.bias", Transform.NONE), + r"decoder\.mid_block\.attentions\.0\.proj\.weight": ( + "decoder.mid_attn.proj.kernel", + Transform.TRANSPOSE_2D_CONV, + ), + r"decoder\.mid_block\.attentions\.0\.proj\.bias": ("decoder.mid_attn.proj.bias", Transform.NONE), + # Up blocks - resnets (pattern for all 4 stages, 3 resnets each) + r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.norm1\.gamma": ( + r"decoder.up_blocks_\1.\2.norm1.scale", + Transform.SQUEEZE, + ), + r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.conv1\.weight": ( + r"decoder.up_blocks_\1.\2.conv1.conv.kernel", + Transform.TRANSPOSE_3D, + ), + r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.conv1\.bias": ( + r"decoder.up_blocks_\1.\2.conv1.conv.bias", + Transform.NONE, + ), + r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.norm2\.gamma": ( + r"decoder.up_blocks_\1.\2.norm2.scale", + Transform.SQUEEZE, + ), + r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.conv2\.weight": ( + r"decoder.up_blocks_\1.\2.conv2.conv.kernel", + Transform.TRANSPOSE_3D, + ), + r"decoder\.up_blocks\.([0-3])\.resnets\.([0-2])\.conv2\.bias": ( + r"decoder.up_blocks_\1.\2.conv2.conv.bias", + Transform.NONE, + ), + # Skip connections (only in block 1, resnet 0) + r"decoder\.up_blocks\.1\.resnets\.0\.conv_shortcut\.weight": ( + 
"decoder.up_blocks_1.0.skip_conv.conv.kernel", + Transform.TRANSPOSE_3D, + ), + r"decoder\.up_blocks\.1\.resnets\.0\.conv_shortcut\.bias": ( + "decoder.up_blocks_1.0.skip_conv.conv.bias", + Transform.NONE, + ), + # Upsamplers for blocks 0, 1, 2 (block 3 has no upsampler) + # Block 0: Upsample3D (time_conv + spatial_conv) + r"decoder\.up_blocks\.0\.upsamplers\.0\.time_conv\.weight": ( + "decoder.up_sample_0.time_conv.conv.kernel", + Transform.TRANSPOSE_3D, + ), + r"decoder\.up_blocks\.0\.upsamplers\.0\.time_conv\.bias": ( + "decoder.up_sample_0.time_conv.conv.bias", + Transform.NONE, + ), + r"decoder\.up_blocks\.0\.upsamplers\.0\.resample\.1\.weight": ( + "decoder.up_sample_0.spatial_conv.kernel", + Transform.TRANSPOSE_2D_CONV, + ), + r"decoder\.up_blocks\.0\.upsamplers\.0\.resample\.1\.bias": ( + "decoder.up_sample_0.spatial_conv.bias", + Transform.NONE, + ), + # Block 1: Upsample3D (time_conv + spatial_conv) + r"decoder\.up_blocks\.1\.upsamplers\.0\.time_conv\.weight": ( + "decoder.up_sample_1.time_conv.conv.kernel", + Transform.TRANSPOSE_3D, + ), + r"decoder\.up_blocks\.1\.upsamplers\.0\.time_conv\.bias": ( + "decoder.up_sample_1.time_conv.conv.bias", + Transform.NONE, + ), + r"decoder\.up_blocks\.1\.upsamplers\.0\.resample\.1\.weight": ( + "decoder.up_sample_1.spatial_conv.kernel", + Transform.TRANSPOSE_2D_CONV, + ), + r"decoder\.up_blocks\.1\.upsamplers\.0\.resample\.1\.bias": ( + "decoder.up_sample_1.spatial_conv.bias", + Transform.NONE, + ), + # Block 2: Upsample2D (conv only, no time_conv) + r"decoder\.up_blocks\.2\.upsamplers\.0\.resample\.1\.weight": ( + "decoder.up_sample_2.conv.kernel", + Transform.TRANSPOSE_2D_CONV, + ), + r"decoder\.up_blocks\.2\.upsamplers\.0\.resample\.1\.bias": ("decoder.up_sample_2.conv.bias", Transform.NONE), + # Output layers + r"decoder\.norm_out\.gamma": ("decoder.norm_out.scale", Transform.SQUEEZE), + r"decoder\.conv_out\.weight": ("decoder.conv_out.conv.kernel", Transform.TRANSPOSE_3D), + r"decoder\.conv_out\.bias": ("decoder.conv_out.conv.bias", Transform.NONE), + } + + return mapping + + +def _torch_key_to_jax_key(mapping, source_key): + """Convert a PyTorch/Diffusers key to JAX key with transform info.""" + subs = [ + (re.sub(pat, repl, source_key), transform) + for pat, (repl, transform) in mapping.items() + if re.match(pat, source_key) + ] + if len(subs) == 0: + # Key not found in mapping, might be OK (e.g., VAE weights) + return None, None + if len(subs) > 1: + raise ValueError(f"Multiple patterns matched for key {source_key}: {subs}") + return subs[0] + + +def _assign_weights(keys, tensor, state_dict, st_key, transform, sharding_dict=None): + """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor.""" + key, *rest = keys + if not rest: + if transform is not None and transform.value is not None: + permute, reshape = transform.value + if reshape is not None: + tensor = tensor.reshape(reshape) + if permute: + tensor = tensor.transpose(permute) + + if key not in state_dict: + raise KeyError(f"Key {key} not found in state_dict. 
Available keys: {list(state_dict.keys())[:10]}...") + + if tensor.shape != state_dict[key].shape: + raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}") + + # Assign with or without sharding + if sharding_dict is not None and key in sharding_dict: + state_dict[key] = jax.device_put(tensor, sharding_dict[key]) + else: + state_dict[key] = jax.device_put(tensor) + else: + next_sharding = sharding_dict[key] if sharding_dict is not None and key in sharding_dict else None + _assign_weights(rest, tensor, state_dict[key], st_key, transform, next_sharding) + + +def _stoi(s): + """Convert string to int if possible, otherwise return string.""" + try: + return int(s) + except ValueError: + return s + + +def create_model_from_safe_tensors( + file_dir: str, + cfg: model_lib.TransformerWanModelConfig, + mesh: jax.sharding.Mesh | None = None, +) -> model_lib.Wan2DiT: + """ + Load Wan2.1-T2V-1.3B DiT model from safetensors checkpoint. + + Args: + file_dir: Directory containing .safetensors files or path to transformer directory + cfg: Model configuration + mesh: Optional JAX mesh for sharding + load_transformer_only: If True, only load transformer weights (not VAE/text encoder) + + Returns: + Wan2DiT model with loaded weights + """ + # Check if file_dir is the model root or transformer subdirectory + file_path = epath.Path(file_dir).expanduser() + transformer_path = file_path / "transformer" + + if transformer_path.exists(): + # Look in transformer subdirectory + files = sorted(list(transformer_path.glob("diffusion_pytorch_model-*.safetensors"))) + else: + # Look in provided directory + files = sorted(list(file_path.glob("diffusion_pytorch_model-*.safetensors"))) + if not files: + files = sorted(list(file_path.glob("*.safetensors"))) + + if not files: + raise ValueError(f"No safetensors found in {file_dir} or {file_dir}/transformer") + + print(f"Found {len(files)} DiT transformer safetensors file(s)") + + # Create model structure + wan2_dit = nnx.eval_shape(lambda: model_lib.Wan2DiT(cfg, rngs=nnx.Rngs(params=0))) + graph_def, abs_state = nnx.split(wan2_dit) + state_dict = abs_state.to_pure_dict() + + # Setup sharding if mesh provided + sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None + + key_mapping = _get_dit_mapping(cfg) + conversion_errors = [] + loaded_keys = [] + skipped_keys = [] + + # Collect K/V weights for fusion into kv_proj + kv_weights = {} # {block_idx: {'k_weight': ..., 'k_bias': ..., 'v_weight': ..., 'v_bias': ...}} + + for f in files: + print(f"Loading weights from {f.name}...") + with safetensors.safe_open(f, framework="numpy") as sf: + for torch_key in sf.keys(): + tensor = sf.get_tensor(torch_key) + + # Special handling for cross-attention K/V fusion + kv_match = re.match(r"blocks\.([0-9]+)\.attn2\.to_([kv])\.(weight|bias)", torch_key) + if kv_match: + block_idx = int(kv_match.group(1)) + kv_type = kv_match.group(2) # 'k' or 'v' + param_type = kv_match.group(3) # 'weight' or 'bias' + + if block_idx not in kv_weights: + kv_weights[block_idx] = {} + kv_weights[block_idx][f"{kv_type}_{param_type}"] = tensor + loaded_keys.append(torch_key) + continue + + jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key) + + if jax_key is None: + # Skip keys not in our mapping (e.g., VAE, text encoder, attn2.norm_k) + skipped_keys.append(torch_key) + continue + + keys = [_stoi(k) for k in jax_key.split(".")] + try: + _assign_weights(keys, tensor, state_dict, torch_key, transform, sharding) + 
loaded_keys.append(torch_key) + except Exception as e: + full_jax_key = ".".join([str(k) for k in keys]) + conversion_errors.append( + f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}" + ) + gc.collect() + + # Fuse collected K/V weights into kv_proj + import jax.numpy as jnp + + for block_idx, weights in kv_weights.items(): + if all(k in weights for k in ["k_weight", "k_bias", "v_weight", "v_bias"]): + # Transpose and concatenate: (out, in) -> (in, out) then concat -> (in, 2*out) + k_weight = weights["k_weight"].T # (in, out) + v_weight = weights["v_weight"].T # (in, out) + kv_kernel = jnp.concatenate([k_weight, v_weight], axis=1) # (in, 2*out) + + kv_bias = jnp.concatenate([weights["k_bias"], weights["v_bias"]]) # (2*out,) + + # Assign to state dict + state_dict["blocks"][block_idx]["cross_attn"]["kv_proj"]["kernel"] = jax.device_put(kv_kernel) + state_dict["blocks"][block_idx]["cross_attn"]["kv_proj"]["bias"] = jax.device_put(kv_bias) + + print(f"Loaded {len(loaded_keys)} weight tensors") + print(f"Skipped {len(skipped_keys)} weight tensors (VAE/text encoder/attn2.norm_k)") + + state_dict = jax.tree_util.tree_map_with_path( + lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=cfg.weights_dtype), state_dict + ) + + if conversion_errors: + print(f"\n Warning: {len(conversion_errors)} conversion errors occurred:") + for err in conversion_errors: # Show first 5 errors + print(f" {err}") + # if len(conversion_errors) > 5: + # print(f" ... and {len(conversion_errors) - 5} more") + + gc.collect() + return nnx.merge(graph_def, state_dict) + + +def create_vae_decoder_from_safe_tensors( + file_dir: str, + mesh: jax.sharding.Mesh | None = None, +) -> vae_lib.WanVAEDecoder: + """ + Load Wan-VAE decoder from safetensors checkpoint. 
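+
+    Minimal usage sketch (the checkpoint path is illustrative; either the model root or its
+    vae/ subdirectory can be passed):
+
+        vae_decoder = create_vae_decoder_from_safe_tensors("~/Wan2.1-T2V-1.3B-Diffusers", mesh=None)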
+ + Args: + file_dir: Directory containing .safetensors files or path to VAE directory + mesh: Optional JAX mesh for sharding + + Returns: + WanVAEDecoder with loaded weights + """ + # Check if file_dir is the model root or VAE subdirectory + file_path = epath.Path(file_dir).expanduser() + vae_path = file_path / "vae" + + if vae_path.exists(): + # Look in vae subdirectory + files = list(vae_path.glob("*.safetensors")) + else: + # Look in provided directory + files = list(file_path.glob("*.safetensors")) + + if not files: + raise ValueError(f"No safetensors found in {file_dir} or {file_dir}/vae") + + print(f"Found {len(files)} VAE safetensors file(s)") + + # Create VAE decoder structure + vae_decoder = nnx.eval_shape(lambda: vae_lib.WanVAEDecoder(rngs=nnx.Rngs(params=0))) + graph_def, abs_state = nnx.split(vae_decoder) + state_dict = abs_state.to_pure_dict() + + # Setup sharding if mesh provided + sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None + + key_mapping = _get_vae_key_mapping() + conversion_errors = [] + loaded_keys = [] + skipped_keys = [] + + for f in files: + print(f"Loading VAE weights from {f.name}...") + with safetensors.safe_open(f, framework="numpy") as sf: + for torch_key in sf.keys(): + tensor = sf.get_tensor(torch_key) + + jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key) + + if jax_key is None: + skipped_keys.append(torch_key) + # print(f"{torch_key} is not mapped") + continue + + keys = [_stoi(k) for k in jax_key.split(".")] + try: + _assign_weights(keys, tensor, state_dict, torch_key, transform, sharding) + loaded_keys.append(torch_key) + except Exception as e: + full_jax_key = ".".join([str(k) for k in keys]) + conversion_errors.append( + f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}" + ) + gc.collect() + + print(f"Loaded {len(loaded_keys)} VAE weight tensors") + print(f"Skipped {len(skipped_keys)} weight tensors") + + if conversion_errors: + print(f"\nWarning: {len(conversion_errors)} conversion errors occurred:") + for err in conversion_errors: # Show first 10 errors + print(f" {err}") + # if len(conversion_errors) > 10: + # print(f" ... and {len(conversion_errors) - 10} more") + + if len(loaded_keys) == 0: + raise ValueError("No VAE weights were loaded! 
Check the checkpoint structure and key mapping.") + + gc.collect() + return nnx.merge(graph_def, state_dict) + + +def _get_t5_key_mapping(): + """Define mapping from HuggingFace UMT5 keys to JAX UMT5 keys.""" + + class Transform(Enum): + """Transformations for UMT5 parameters""" + + NONE = None + TRANSPOSE = ((1, 0), None) # For linear layers: (out, in) -> (in, out) + + # UMT5/UMT5 uses standard HuggingFace naming + mapping = { + # Shared token embeddings + r"shared\.weight": ("encoder.token_embedding.embedding", Transform.NONE), + r"encoder\.embed_tokens\.weight": ("encoder.token_embedding.embedding", Transform.NONE), + # Encoder blocks - Self attention + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.q\.weight": ( + r"encoder.blocks.\1.attn.q.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.k\.weight": ( + r"encoder.blocks.\1.attn.k.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.v\.weight": ( + r"encoder.blocks.\1.attn.v.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.o\.weight": ( + r"encoder.blocks.\1.attn.o.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.SelfAttention\.relative_attention_bias\.weight": ( + r"encoder.blocks.\1.pos_embedding.embedding.embedding", + Transform.NONE, + ), + r"encoder\.block\.([0-9]+)\.layer\.0\.layer_norm\.weight": (r"encoder.blocks.\1.norm1.weight", Transform.NONE), + # Encoder blocks - Feed forward + r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wi_0\.weight": ( + r"encoder.blocks.\1.ffn.gate.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wi_1\.weight": ( + r"encoder.blocks.\1.ffn.fc1.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.1\.DenseReluDense\.wo\.weight": ( + r"encoder.blocks.\1.ffn.fc2.kernel", + Transform.TRANSPOSE, + ), + r"encoder\.block\.([0-9]+)\.layer\.1\.layer_norm\.weight": (r"encoder.blocks.\1.norm2.weight", Transform.NONE), + # Final layer norm + r"encoder\.final_layer_norm\.weight": ("encoder.norm.weight", Transform.NONE), + } + + return mapping + + +def create_t5_encoder_from_safe_tensors( + file_dir: str, + mesh: jax.sharding.Mesh | None = None, + is_sf: bool = True, + config: t5_lib.T5Config | None = None, +) -> t5_lib.T5EncoderModel: + """ + Load UMT5 encoder from safetensors checkpoint. + + Args: + file_dir: Directory containing .safetensors files or path to text_encoder directory + mesh: Optional JAX mesh for sharding + is_sf: Whether to load from safetensors (True) or PyTorch checkpoint (False) + config: T5Config to use. 
If None, defaults to UMT5-XXL + + Returns: + T5EncoderModel with loaded weights + """ + from bonsai.models.wan2 import umt5 + + # Use provided config or default to UMT5-XXL + if config is None: + config = umt5.T5Config.umt5_xxl() + + t5_encoder = nnx.eval_shape(lambda: umt5.T5EncoderModel(config, rngs=nnx.Rngs(params=0, dropout=0))) + graph_def, abs_state = nnx.split(t5_encoder) + state_dict = abs_state.to_pure_dict() + + sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None + + key_mapping = _get_t5_key_mapping() + conversion_errors = [] + loaded_keys = [] + skipped_keys = [] + + # Check if file_dir is the model root or text_encoder subdirectory + file_path = epath.Path(file_dir).expanduser() + text_encoder_path = file_path / "text_encoder" + + def load_pytorch_weights(file_dir): + from transformers import UMT5ForConditionalGeneration + + model = UMT5ForConditionalGeneration.from_pretrained(file_dir) + encoder_state = {k: v for k, v in model.state_dict().items() if k.startswith("encoder.")} + return encoder_state + + if is_sf: + if text_encoder_path.exists(): + files = sorted(list(text_encoder_path.glob("model-*.safetensors"))) + else: + files = sorted(list(file_path.glob("*.safetensors"))) + if not files: + raise ValueError(f"No safetensors found in {file_dir} or {file_dir}/text_encoder") + print(f"Found {len(files)} UMT5 encoder safetensors file(s)") + + for f in files: + print(f"Loading UMT5 weights from {f.name}...") + with safetensors.safe_open(f, framework="numpy") as sf: + for torch_key in sf.keys(): + tensor = sf.get_tensor(torch_key) + + jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key) + + if jax_key is None: + # Skip keys not in our mapping + skipped_keys.append(torch_key) + # print(f"{torch_key} is not mapped") + continue + + keys = [_stoi(k) for k in jax_key.split(".")] + try: + _assign_weights(keys, tensor, state_dict, torch_key, transform, sharding) + loaded_keys.append(torch_key) + except Exception as e: + full_jax_key = ".".join([str(k) for k in keys]) + conversion_errors.append( + f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}" + ) + gc.collect() + else: + print(f"Loading UMT5 weights from PyTorch checkpoint in {file_dir}...") + pt_state = load_pytorch_weights(file_dir) + for torch_key, tensor in pt_state.items(): + jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key) + + if jax_key is None: + # Skip keys not in our mapping + skipped_keys.append(torch_key) + # print(f"{torch_key} is not mapped") + continue + + keys = [_stoi(k) for k in jax_key.split(".")] + try: + _assign_weights(keys, tensor.numpy(), state_dict, torch_key, transform, sharding) + loaded_keys.append(torch_key) + except Exception as e: + full_jax_key = ".".join([str(k) for k in keys]) + conversion_errors.append(f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}") + gc.collect() + + print(f"Loaded {len(loaded_keys)} UMT5 weight tensors") + print(f"Skipped {len(skipped_keys)} weight tensors") + + if conversion_errors: + print(f"\nWarning: {len(conversion_errors)} conversion errors occurred:") + for err in conversion_errors: # Show first 10 errors + print(f" {err}") + # if len(conversion_errors) > 10: + # print(f" ... and {len(conversion_errors) - 10} more") + + if len(loaded_keys) == 0: + raise ValueError("No UMT5 weights were loaded! 
Check the checkpoint structure and key mapping.") + + gc.collect() + return nnx.merge(graph_def, state_dict) + + +__all__ = [ + "create_model_from_safe_tensors", + "create_t5_encoder_from_safe_tensors", + "create_vae_decoder_from_safe_tensors", +] diff --git a/bonsai/models/wan2/pipeline_wan.py b/bonsai/models/wan2/pipeline_wan.py new file mode 100644 index 00000000..78bab7f5 --- /dev/null +++ b/bonsai/models/wan2/pipeline_wan.py @@ -0,0 +1,245 @@ +"""Example script for running Wan2.1-T2V-1.3B text-to-video generation.""" + +import argparse +import time +import traceback +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import nnx +from huggingface_hub import snapshot_download +from jaxtyping import Array +from transformers import AutoTokenizer + +from bonsai.models.wan2 import params, transformer_wan, umt5, vae_wan +from bonsai.models.wan2.transformer_wan import TransformerWanModelConfig, Wan2DiT +from bonsai.models.wan2.unipc_multistep_scheduler import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState + +jax.config.update("jax_debug_nans", True) + + +def get_t5_text_embeddings( + prompt: str, + tokenizer: AutoTokenizer = None, + text_encoder: umt5.T5EncoderModel = None, + max_length: int = 512, + dtype: jnp.dtype = jnp.float32, +): + try: + inputs = tokenizer( + prompt, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="np", # Return numpy arrays + ) + input_ids = jnp.array(inputs["input_ids"]) + attention_mask = jnp.array(inputs["attention_mask"]) + seq_lens = jnp.sum(attention_mask, axis=1).astype(jnp.int32) + print(f"seq_lens: {seq_lens}") + embeddings = text_encoder(input_ids, attention_mask, deterministic=True) + prompt_embeds = jnp.asarray(embeddings, dtype=dtype) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = jnp.stack( + [ + jnp.concatenate( + [u, jnp.zeros((max_length - u.shape[0], u.shape[1]), dtype=u.dtype)], + axis=0, + ) + for u in prompt_embeds + ], + axis=0, + ) + return prompt_embeds + except Exception as e: + print(f"Error in text encoding: {e}") + traceback.print_exc() + + +def decode_video_latents(latents: jax.Array, vae_decoder: vae_wan.WanVAEDecoder): + """ + Decode video latents to RGB frames using Wan-VAE. + + Args: + latents: [B, T, H, W, C] video latents + vae_decoder: Optional WanVAEDecoder instance. If None, returns dummy video. + + Returns: + video: [B, T, H_out, W_out, 3] RGB video (uint8) + """ + # Decode using VAE + video = vae_wan.decode_latents_to_video(vae_decoder, latents, normalize=True) + return video + + +def generate_video( + model: Wan2DiT, + latents: Array, + text_embeds: Array, + negative_embeds: Array, + num_steps: int = 50, + guidance_scale: float = 5.5, + scheduler: Optional[FlaxUniPCMultistepScheduler] = None, + scheduler_state: Optional[UniPCMultistepSchedulerState] = None, +) -> Array: + """ + Generate video from text embeddings using the diffusion model. 
+ + Args: + model: Wan2DiT model + text_embeds: [B, seq_len, text_dim] text embeddings from UMT5 + num_frames: Number of frames to generate + latent_size: Spatial size of latents + num_steps: Number of denoising steps + guidance_scale: Classifier-free guidance scale (5-6 recommended) + + Returns: + latents: [B, T, H, W, C] generated video latents + """ + b = text_embeds.shape[0] + + # Initialize random noise + scheduler_state = scheduler.set_timesteps( + scheduler_state, num_inference_steps=num_steps, shape=latents.transpose(0, 4, 1, 2, 3).shape + ) + + for t_idx in range(num_steps): + # Scheduler needs scalar timestep, model needs batched timestep + t_scalar = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[t_idx] + t_batch = jnp.full((b,), t_scalar, dtype=jnp.int32) + + # Classifier-free guidance + if guidance_scale != 1.0: + noise_pred_cond = model.forward(latents, text_embeds, t_batch, deterministic=True) + noise_pred_uncond = model.forward(latents, negative_embeds, t_batch, deterministic=True) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = model.forward(latents, text_embeds, t_batch, deterministic=True) + + latents, scheduler_state = scheduler.step( + scheduler_state, noise_pred.transpose(0, 4, 1, 2, 3), t_scalar, latents.transpose(0, 4, 1, 2, 3) + ) + latents = latents.transpose(0, 2, 3, 4, 1) # back to channel-last + + return latents + + +def run_model(prompt: Optional[str] = None, neg_prompt: Optional[str] = None) -> jax.Array: + print("=" * 60) + print("Wan2.1-T2V-1.3B Text-to-Video Generation Demo") + print("=" * 60) + + model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + config = TransformerWanModelConfig() + + # For sharding (multi-GPU), uncomment: + # from jax.sharding import AxisType + # mesh = jax.make_mesh((2, 2), ("fsdp", "tp"), axis_types=(AxisType.Explicit, AxisType.Explicit)) + # jax.set_mesh(mesh) + + scheduler = FlaxUniPCMultistepScheduler( + num_train_timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + solver_order=2, # Order 2 for guided sampling + prediction_type="flow_prediction", + use_flow_sigmas=True, # Enable flow-based sigma schedule + flow_shift=3.0, # 5.0 for 720P, 3.0 for 480P + timestep_spacing="linspace", + predict_x0=True, + solver_type="bh2", + lower_order_final=True, + dtype=jnp.float32, + ) + scheduler_state = scheduler.create_state() + + if prompt is not None: + prompts = [prompt] + else: + prompts = [ + "A curious racoon", + ] + + print(f"\nPrompt: {prompts[0]}") + print(f"Model: Wan2.1-T2V-1.3B ({config.num_layers} layers, {config.hidden_dim} dim)") + print(f"Video: {config.num_frames} frames @ 480p") + + tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + umt5_encoder = params.create_t5_encoder_from_safe_tensors(model_ckpt_path, mesh=None) + + print("\n[1/4] Encoding text with UMT5...") + text_embeds = get_t5_text_embeddings(prompts[0], tokenizer, umt5_encoder, max_length=config.max_text_len) + if neg_prompt is not None: + negative_prompts = [neg_prompt] + else: + negative_prompts = ["blurry"] + negative_embeds = get_t5_text_embeddings( + negative_prompts[0], tokenizer, umt5_encoder, max_length=config.max_text_len + ) + + print("\n[2/5] Loading Diffusion Transformer weights...") + model = params.create_model_from_safe_tensors(model_ckpt_path, config, mesh=None) + print("\n[2.5/5] Loading VAE decoder...") + vae_decoder = params.create_vae_decoder_from_safe_tensors(model_ckpt_path, mesh=None) + print("Model loaded 
successfully") + + print("\n[3/4] Generating video latents...") + print(f"Using {config.num_inference_steps} diffusion steps") + print(f"Guidance scale: {config.guidance_scale}") + + key = jax.random.PRNGKey(42) + start_time = time.time() + + latents = jax.random.normal( + key, (1, config.num_frames, config.latent_size[0], config.latent_size[1], config.latent_input_dim) + ) + + latents = generate_video( + model=model, + latents=latents, + text_embeds=text_embeds, + negative_embeds=negative_embeds, + num_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + scheduler=scheduler, + scheduler_state=scheduler_state, + ) + + generation_time = time.time() - start_time + print(f"βœ“ Generated latents in {generation_time:.2f}s") + print(f"Latents shape: {latents.shape}") + print(latents[0, 1:, :, 25:, :].mean()) + print(f"Has NaN: {jnp.isnan(latents).any()}") + print(f"Has Inf: {jnp.isinf(latents).any()}") + + print("\n[4/5] Decoding latents to video...") + video = decode_video_latents(latents, vae_decoder) + generation_time = time.time() - start_time + print(f"Video shape: {video.shape}") + print(video[0, 1:, :, 235:, :].mean()) + + print("\n" + "=" * 60) + print("βœ“ Generation Complete!") + print("=" * 60) + print(f"Total time: {generation_time:.2f}s") + print(f"FPS: {config.num_frames / generation_time:.2f}") + vae_wan.save_video(video, "generated_video.mp4") + print("Video saved to generated_video.mp4") + + return video + + +def main(): + parser = argparse.ArgumentParser(description="Wan2.1-T2V-1.3B Text-to-Video Generation Demo") + parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation") + parser.add_argument("--neg_prompt", type=str, default=None, help="Negative text prompt for video generation") + args = parser.parse_args() + run_model(args.prompt, args.neg_prompt) + + +if __name__ == "__main__": + main() + +__all__ = ["run_model"] diff --git a/bonsai/models/wan2/scheduling_utils.py b/bonsai/models/wan2/scheduling_utils.py new file mode 100644 index 00000000..da722600 --- /dev/null +++ b/bonsai/models/wan2/scheduling_utils.py @@ -0,0 +1,318 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import math +import os +from dataclasses import dataclass +from enum import Enum +from typing import ClassVar, Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +SCHEDULER_CONFIG_NAME = "scheduler_config.json" + + +# NOTE: We make this type an enum because it simplifies usage in docs and prevents +# circular imports when used for `_compatibles` within the schedulers module. +# When it's used as a type in pipelines, it really is a Union because the actual +# scheduler instance is passed in. 
+class FlaxKarrasDiffusionSchedulers(Enum): + FlaxDDIMScheduler = 1 + FlaxDDPMScheduler = 2 + FlaxPNDMScheduler = 3 + FlaxLMSDiscreteScheduler = 4 + FlaxDPMSolverMultistepScheduler = 5 + + +@dataclass +class FlaxSchedulerOutput: + """ + Base class for the scheduler's step function output. + + Args: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: jnp.ndarray + + +class FlaxSchedulerMixin: + """ + Mixin containing common functions for the schedulers. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). + """ + + config_name = SCHEDULER_CONFIG_NAME + ignore_for_config: ClassVar = ["dtype"] + _compatibles: ClassVar = [] + has_compatibles = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON-file. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`], + e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + **kwargs, + ) + scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs) + + if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False): + state = scheduler.create_state() + + if return_unused_kwargs: + return scheduler, state, unused_kwargs + + return scheduler, state + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~FlaxSchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__, *cls._compatibles])) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes + + +def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray: + assert len(shape) >= x.ndim + return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape) + + +def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. 
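+        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the dtype of the returned betas array.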
+ + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=dtype) + + +def rescale_betas_zero_snr(betas): + """ + Rescales betas to have a zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + """ + + alphas = 1.0 - betas + alphas_cumprod = jnp.cumprod(alphas, axis=0) + alphas_bar_sqrt = jnp.sqrt(alphas_cumprod) + + # Store old values. + alphas_bar_sqrt_0 = jnp.copy(alphas_bar_sqrt[0]) + alphas_bar_sqrt_T = jnp.copy(alphas_bar_sqrt[-1]) + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = jnp.concatenate([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +@flax.struct.dataclass +class CommonSchedulerState: + alphas: jnp.ndarray + betas: jnp.ndarray + alphas_cumprod: jnp.ndarray + + @classmethod + def create(cls, scheduler): + config = scheduler.config + + if config.trained_betas is not None: + betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype) + elif config.beta_schedule == "linear": + betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype) + elif config.beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
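+            # betas are spaced linearly in sqrt(beta) space and then squared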
+ betas = ( + jnp.linspace( + config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype + ) + ** 2 + ) + elif config.beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype) + else: + raise NotImplementedError( + f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}" + ) + + if not config.rescale_zero_terminal_snr: + betas = rescale_betas_zero_snr(betas) + + alphas = 1.0 - betas + + alphas_cumprod = jnp.cumprod(alphas, axis=0) + + return cls( + alphas=alphas, + betas=betas, + alphas_cumprod=alphas_cumprod, + ) + + +def get_sqrt_alpha_prod( + state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray +): + alphas_cumprod = state.alphas_cumprod + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) + + return sqrt_alpha_prod, sqrt_one_minus_alpha_prod + + +def add_noise_common( + state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray +): + sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps) + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + +def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray): + sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps) + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/bonsai/models/wan2/tests/__init__.py b/bonsai/models/wan2/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bonsai/models/wan2/tests/test_transformer_output.py b/bonsai/models/wan2/tests/test_transformer_output.py new file mode 100644 index 00000000..590e5a96 --- /dev/null +++ b/bonsai/models/wan2/tests/test_transformer_output.py @@ -0,0 +1,792 @@ +"""Test output correctness by comparing JAX implementation with HuggingFace reference.""" + +import jax +import jax.numpy as jnp +import numpy as np +from huggingface_hub import snapshot_download +from bonsai.models.wan2 import params, transformer_wan +import torch +from diffusers import AutoModel +from jax.lax import Precision +from collections import OrderedDict + +def check_weight_loading(jax_model, torch_model): + + text_proj_torch = torch_model.condition_embedder.text_embedder.linear_1.weight.detach().cpu().float().numpy().T + text_proj_jax = np.array(jax_model.text_proj.layers[0].kernel.value) + print("Text projection weights:") + print(f" Shapes: torch={text_proj_torch.shape}, jax={text_proj_jax.shape}") + print(f" Max diff: {np.abs(text_proj_torch - text_proj_jax).max():.2e}") + print(f" Mean diff: {np.abs(text_proj_torch - text_proj_jax).mean():.2e}") + # torch :(out, in, t, h, w) + torch_emb = torch_model.patch_embedding.weight.detach().cpu().float().numpy() + # jax: (t, h, w, in, out) + jax_emb = np.array(jax_model.patch_embed.kernel.value).transpose(4,3,0,1,2) + + print("Embedding weights:") + 
print(f" Shapes: torch={torch_emb.shape}, jax={jax_emb.shape}") + print(f" Max diff: {np.abs(torch_emb - jax_emb).max():.2e}") + print(f" Mean diff: {np.abs(torch_emb - jax_emb).mean():.2e}") + + # check fused kv projection weights in first block + torch_k_weight = torch_model.blocks[0].attn2.to_k.weight.detach().cpu().float().numpy().T + torch_v_weight = torch_model.blocks[0].attn2.to_v.weight.detach().cpu().float().numpy().T + torch_kv = np.concatenate([torch_k_weight, torch_v_weight], axis=1) + + jax_kv_weight = np.array(jax_model.blocks[0].cross_attn.kv_proj.kernel.value) + print("First block cross-attention KV projection weights:") + print(f" Shapes: torch={torch_kv.shape}, jax={jax_kv_weight.shape}") + print(f" Max diff: {np.abs(torch_kv - jax_kv_weight).max():.2e}") + print(f" Mean diff: {np.abs(torch_kv - jax_kv_weight).mean():.2e}") + +def compare_outputs(jax_output: jax.Array, torch_output, name: str, rtol: float = 1e-2, atol: float = 1e-4): + + print(f"\n{'=' * 80}") + print(f"Comparing: {name}") + print(f"{'=' * 80}") + + print(f"before convert torch: {torch_output.dtype}, jax: {jax_output.dtype}") + if torch_output.dtype == torch.bfloat16: + torch_output = torch_output.float() + if jax_output.dtype == jnp.bfloat16: + jax_output = jax_output.astype(jnp.float32) + + if isinstance(torch_output, torch.Tensor): + torch_np = torch_output.detach().cpu().numpy() + else: + torch_np = np.array(torch_output) + + jax_np = np.array(jax_output) + + print(f"JAX shape: {jax_np.shape}") + print(f"Torch shape: {torch_np.shape}") + print(f"JAX dtype: {jax_np.dtype}") + print(f"Torch dtype: {torch_np.dtype}") + + if jax_np.shape != torch_np.shape: + print("Shape mismatch!") + return False + + abs_diff = np.abs(jax_np - torch_np) + rel_diff = abs_diff / (np.abs(torch_np) + 1e-10) + + max_abs_diff = np.max(abs_diff) + max_rel_diff = np.max(rel_diff) + mean_abs_diff = np.mean(abs_diff) + mean_rel_diff = np.mean(rel_diff) + + print("\nStatistics:") + print(f" Max absolute difference: {max_abs_diff:.2e}") + print(f" Max relative difference: {max_rel_diff:.2e}") + print(f" Mean absolute difference: {mean_abs_diff:.2e}") + print(f" Mean relative difference: {mean_rel_diff:.2e}") + + print(f"\nJAX output range: [{np.min(jax_np):.4f}, {np.max(jax_np):.4f}]") + print(f"Torch output range: [{np.min(torch_np):.4f}, {np.max(torch_np):.4f}]") + + close = np.allclose(jax_np, torch_np, rtol=rtol, atol=atol) + + if close: + print(f"\nβœ… Outputs match within tolerance (rtol={rtol}, atol={atol})") + else: + print(f"\n❌ Outputs do NOT match (rtol={rtol}, atol={atol})") + # Show some mismatched locations + mismatch_mask = ~np.isclose(jax_np, torch_np, rtol=rtol, atol=atol) + n_mismatches = np.sum(mismatch_mask) + print(f" Number of mismatches: {n_mismatches} / {jax_np.size} ({100 * n_mismatches / jax_np.size:.2f}%)") + + return close + +# e2e test +def test_dit_output(): + print("\n" + "=" * 80) + print("TEST 2: DiT") + print("=" * 80) + + model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + config = transformer_wan.TransformerWanModelConfig + + print("\n[1/2] Loading transformer") + transformer = AutoModel.from_pretrained(model_ckpt_path, subfolder="transformer", torch_dtype=torch.bfloat16) + + jax_dit = params.create_model_from_safe_tensors(model_ckpt_path, config,mesh=None) + + batch_size = 1 + num_channels = 16 # in_channels + num_frames = 9 + height = 30 + width = 30 + text_seq_len = 128 + text_dim = 4096 # UMT5 hidden dimension + + # Create dummy inputs + hidden_states = torch.randn( + 
batch_size, num_channels, num_frames, height, width, + dtype=torch.float32 + ) + hidden_states_jax = jnp.array(np.transpose(hidden_states.numpy(), (0, 2, 3, 4, 1))) + + timestep = torch.randint( + 0, 1000, + (batch_size,), + dtype=torch.long + ) + timestep_jax = jnp.array(timestep.numpy()) + + encoder_hidden_states = torch.randn( + batch_size, text_seq_len, text_dim, + dtype=torch.float32 + ) + encoder_hidden_states_jax = jnp.array(encoder_hidden_states.numpy()) + + print("\n[2/2] Running forward pass") + with torch.no_grad(): + output = transformer( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=None, # Only for I2V models + return_dict=True, + attention_kwargs=None, + ) + pred_noise = jax_dit.forward(hidden_states_jax, encoder_hidden_states_jax, timestep_jax, deterministic=True) + expected_output = np.transpose(output.sample.numpy(), (0, 2, 3, 4, 1)) + + # Compare final output + return compare_outputs(pred_noise, expected_output, "Final DiT Output", rtol=1e-3, atol=1e-4) + +def test_dit(): + print("\n" + "=" * 80) + print("TEST 2: DiT") + print("=" * 80) + + model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + config = transformer_wan.TransformerWanModelConfig + + print("\n[1/2] Loading transformer") + transformer = AutoModel.from_pretrained(model_ckpt_path, subfolder="transformer", torch_dtype=torch.bfloat16) + + jax_dit = params.create_model_from_safe_tensors(model_ckpt_path, config,mesh=None) + print("transformer loaded:", transformer, transformer.config) + + check_weight_loading(jax_dit, transformer) + + batch_size = 1 + num_channels = 16 # in_channels + num_frames = 9 + height = 30 + width = 30 + text_seq_len = 128 + text_dim = 4096 # UMT5 hidden dimension + + debugger = WanTransformerDebugger(transformer) + debugger.register_hooks() + debugger_attn = WanAttentionDebugger(transformer) + debugger_attn.register_attention_hooks(block_indices=[0]) + + # Create dummy inputs + hidden_states = torch.randn( + batch_size, num_channels, num_frames, height, width, + dtype=torch.float32 + ) + # jax channels last + hidden_states_jax = jnp.array(np.transpose(hidden_states.numpy(), (0, 2, 3, 4, 1))).astype(jnp.bfloat16) + timestep = torch.randint( + 0, 1000, + (batch_size,), + dtype=torch.long + ) + timestep_jax = jnp.array(timestep.numpy()) + encoder_hidden_states = torch.randn( + batch_size, text_seq_len, text_dim, + dtype=torch.float32 + ) + encoder_hidden_states_jax = jnp.array(encoder_hidden_states.numpy()).astype(jnp.bfloat16) + + print("\n[2/2] Running forward pass") + with torch.no_grad(): + output = transformer( + hidden_states=hidden_states.to(dtype=torch.bfloat16), + timestep=timestep, + encoder_hidden_states=encoder_hidden_states.to(dtype=torch.bfloat16), + encoder_hidden_states_image=None, # Only for I2V models + return_dict=True, + attention_kwargs=None, + ) + + # 5. Get intermediate outputs + intermediate_outputs = debugger.get_outputs() + states = debugger_attn.get_attention_states() + + print("=" * 80) + print("INTERMEDIATE OUTPUTS") + print("=" * 80) + + for name, tensor in states.items(): + print(f"{name:50s}: {tuple(tensor.shape)}") + + # Restore original processors + debugger_attn.restore_processors() + + # Manual forward pass with intermediate comparisons + print("\n" + "=" * 80) + print("STEP-BY-STEP FORWARD PASS WITH COMPARISONS") + print("=" * 80) + + # 1. 
Text projection + text_embeds_jax = jax_dit.text_proj(encoder_hidden_states_jax) + text_embeds_torch = intermediate_outputs['condition_encoder_hidden_states'] + compare_outputs(text_embeds_jax, text_embeds_torch, "Text Projection", rtol=1e-3, atol=1e-4) + + # 2. Patch embedding + x_jax = jax_dit.patch_embed(hidden_states_jax) + # PyTorch is BCTHW, need to convert to BTHWC for comparison + patch_embed_torch = intermediate_outputs['patch_embed_output'] + patch_embed_torch_channels_last = np.transpose(patch_embed_torch.float().numpy(), (0, 2, 3, 4, 1)) + compare_outputs(x_jax, patch_embed_torch_channels_last, "Patch Embedding", rtol=1e-3, atol=1e-4) + + # Reshape to sequence + b, t_out, h_out, w_out, d = x_jax.shape + x_jax = x_jax.reshape(b, t_out * h_out * w_out, d) + grid_sizes = (t_out, h_out, w_out) + + # 3. RoPE frequencies + max_seq = max(grid_sizes) + rope_freqs = tuple( + jax.lax.stop_gradient(arr) + for arr in transformer_wan.precompute_freqs_cis_3d( + dim=jax_dit.cfg.head_dim, + theta=jax_dit.rope_theta, + max_seq_len=max_seq + ) + ) + + # Build full RoPE frequency grid for comparison + freqs_t, freqs_h, freqs_w = rope_freqs + f, h, w = grid_sizes + head_dim = jax_dit.cfg.head_dim + dim_base = head_dim // 6 + dim_t, dim_h, dim_w = head_dim - 4 * dim_base, 2 * dim_base, 2 * dim_base + + freqs_grid = jnp.concatenate([ + jnp.broadcast_to(freqs_t[:f, None, None, :, :], (f, h, w, dim_t // 2, 2)), + jnp.broadcast_to(freqs_h[None, :h, None, :, :], (f, h, w, dim_h // 2, 2)), + jnp.broadcast_to(freqs_w[None, None, :w, :, :], (f, h, w, dim_w // 2, 2)), + ], axis=3).reshape(t_out * h_out * w_out, head_dim // 2, 2) + + rope_freqs_cos_jax = freqs_grid[..., 0] + rope_freqs_cos_jax = jnp.stack([rope_freqs_cos_jax, rope_freqs_cos_jax], axis=-1).reshape(1, -1, 1, head_dim) + + # PyTorch RoPE freqs are in BCHW format, convert to sequence format + rope_freqs_cos_torch = intermediate_outputs['rope_freqs_cos'] + compare_outputs(rope_freqs_cos_jax, rope_freqs_cos_torch, "RoPE Freqs Cos", rtol=1e-5, atol=1e-6) + + # 4. Time embeddings + time_emb_jax, time_proj_jax = jax_dit.time_embed(timestep_jax) + time_emb_torch = intermediate_outputs['condition_temb'] + time_proj_torch = intermediate_outputs['condition_timestep_proj'] + compare_outputs(time_emb_jax, time_emb_torch, "Time Embedding", rtol=1e-3, atol=1e-4) + compare_outputs(time_proj_jax, time_proj_torch, "Time Projection", rtol=1e-3, atol=1e-4) + + # 5. 
Process through transformer blocks with detailed attention comparison + for i, block in enumerate(jax_dit.blocks): + if i==0: + print(f"\n{'='*80}") + print(f"BLOCK {i} - DETAILED COMPARISON") + print(f"{'='*80}") + + # Get modulation parameters + b_size = time_proj_jax.shape[0] + d = jax_dit.cfg.hidden_dim + reshaped_time_emb = time_proj_jax.reshape(b_size, 6, d) + modulation = reshaped_time_emb + block.scale_shift_table.value + modulation = modulation.reshape(b_size, -1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(modulation, 6, axis=-1) + + # Self-attention with detailed steps + print(f"\n--- Self-Attention ---") + norm_x = block.norm1(x_jax) + norm_x_modulated = norm_x * (1 + scale_msa[:, None, :]) + shift_msa[:, None, :] + + attn_out = block.self_attn(norm_x_modulated, deterministic=True, rope_state=(rope_freqs, grid_sizes)) + + # # Q, K, V projections + num_heads = block.self_attn.num_heads + head_dim = block.self_attn.head_dim + b_size, n = norm_x_modulated.shape[:2] + + # Compare with PyTorch + attn1_output_torch = intermediate_outputs[f'block_{i}_attn1_output'] + compare_outputs(attn_out, attn1_output_torch, f"Block {i} Attn1 Output", rtol=1e-2, atol=1e-3) + + # Apply gate and residual + x_jax = x_jax + gate_msa[:, None, :] * attn_out + + # Cross-attention + print(f"\n--- Cross-Attention ---") + norm_x = block.norm2(x_jax) + compare_outputs(norm_x, intermediate_outputs[f'block_{i}_norm2_output'], f"Block {i} Norm2 Output", rtol=1e-3, atol=1e-4) + b, n, m = norm_x.shape[0], norm_x.shape[1], text_embeds_jax.shape[1] + q_norm = block.cross_attn.q_norm(block.cross_attn.q_proj(norm_x)) + compare_outputs(q_norm, intermediate_outputs[f'block_{i}_attn2_query_normed'], f"Block {i} Attn2 Q after Norm", rtol=1e-5, atol=1e-6) + kv = block.cross_attn.kv_proj(text_embeds_jax) + k, v = jnp.split(kv, 2, axis=-1) + k_norm = block.cross_attn.k_norm(k) + compare_outputs(k_norm, intermediate_outputs[f'block_{i}_attn2_key_normed'], f"Block {i} Attn2 K after Norm", rtol=1e-5, atol=1e-6) + + cross_out = block.cross_attn(norm_x, text_embeds_jax, deterministic=True) + + + attn2_output_torch = intermediate_outputs[f'block_{i}_attn2_output'] + compare_outputs(cross_out, attn2_output_torch, f"Block {i} Attn2 Output", rtol=1e-2, atol=1e-3) + + x_jax = x_jax + cross_out + + # MLP + print(f"\n--- MLP ---") + norm_x = block.norm3(x_jax) + norm_x_modulated = norm_x * (1 + scale_mlp[:, None, :]) + shift_mlp[:, None, :] + mlp_out = block.mlp(norm_x_modulated) + + ffn_output_torch = intermediate_outputs[f'block_{i}_ffn_output'] + compare_outputs(mlp_out, ffn_output_torch, f"Block {i} FFN Output", rtol=1e-2, atol=1e-3) + + x_jax = x_jax + gate_mlp[:, None, :] * mlp_out + + # Compare final block output + block_output_torch = intermediate_outputs[f'block_{i}_output'] + compare_outputs(x_jax, block_output_torch, f"Block {i} Final Output", rtol=1e-2, atol=1e-3) + + if i > 0: # Only compare first block in detail + x_jax = block(x_jax, text_embeds_jax, time_proj_jax, deterministic=True, rope_state=(rope_freqs, grid_sizes)) + compare_outputs(x_jax, intermediate_outputs[f'block_{i}_output'], f"Block {i} Output", rtol=1e-2, atol=1e-3) + + # 6. Final layer + jax_dit_output = jax_dit.final_layer(x_jax, time_emb_jax) + + compare_outputs(jax_dit_output, intermediate_outputs['proj_out_output'], "Final Projection Output", rtol=1e-3, atol=1e-4) + + # Reshape to video format + jax_dit_output = jax_dit.unpatchify(jax_dit_output, grid_sizes) + + # 4. 
Verify output shape + expected_shape = (batch_size, num_channels, num_frames, height, width) + assert output.sample.shape == expected_shape + + # change to channels last for comparison + expected_output = np.transpose(output.sample.float().numpy(), (0, 2, 3, 4, 1)) + + debugger.remove_hooks() + # Compare final output + return compare_outputs(jax_dit_output, expected_output, "Final DiT Output", rtol=1e-3, atol=1e-4) +class WanTransformerDebugger: + """Helper class to extract intermediate outputs from WanTransformer3DModel""" + + def __init__(self, model): + self.model = model + self.intermediate_outputs = OrderedDict() + self.hooks = [] + + def register_hooks(self): + """Register forward hooks to capture intermediate outputs""" + + # Hook for patch embedding + def patch_embed_hook(module, input, output): + self.intermediate_outputs['patch_embed_output'] = output.detach().cpu() + + # Hook for rotary embeddings + def rope_hook(module, input, output): + freqs_cos, freqs_sin = output + self.intermediate_outputs['rope_freqs_cos'] = freqs_cos.detach().cpu() + self.intermediate_outputs['rope_freqs_sin'] = freqs_sin.detach().cpu() + + # Hook for condition embedder + def condition_embedder_hook(module, input, output): + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = output + self.intermediate_outputs['condition_temb'] = temb.detach().cpu() + self.intermediate_outputs['condition_timestep_proj'] = timestep_proj.detach().cpu() + self.intermediate_outputs['condition_encoder_hidden_states'] = encoder_hidden_states.detach().cpu() + if encoder_hidden_states_image is not None: + self.intermediate_outputs['condition_encoder_hidden_states_image'] = encoder_hidden_states_image.detach().cpu() + + # Hook for each transformer block + for i, block in enumerate(self.model.blocks): + def make_block_hook(block_idx): + def block_hook(module, input, output): + self.intermediate_outputs[f'block_{block_idx}_output'] = output.detach().cpu() + return block_hook + + # Hook for block output + handle = block.register_forward_hook(make_block_hook(i)) + self.hooks.append(handle) + + # Hook for self-attention in each block + def make_attn1_hook(block_idx): + def attn1_hook(module, input, output): + self.intermediate_outputs[f'block_{block_idx}_attn1_output'] = output.detach().cpu() + return attn1_hook + + handle = block.attn1.register_forward_hook(make_attn1_hook(i)) + self.hooks.append(handle) + + # Hook for cross-attention in each block + def make_attn2_hook(block_idx): + def attn2_hook(module, input, output): + self.intermediate_outputs[f'block_{block_idx}_attn2_output'] = output.detach().cpu() + return attn2_hook + + def make_hook(name): + def hook(module, input, output): + if isinstance(output, tuple): + self.intermediate_outputs[name] = output[0].detach().cpu() + else: + self.intermediate_outputs[name] = output.detach().cpu() + return hook + + handle = block.attn2.register_forward_hook(make_attn2_hook(i)) + self.hooks.append(handle) + handle = block.norm2.register_forward_hook(make_hook(f'block_{i}_norm2_output')) + self.hooks.append(handle) + + attn = block.attn2 + # 1. Hook Q projection + h = attn.to_q.register_forward_hook( + make_hook(f'block_{i}_attn2_query') + ) + self.hooks.append(h) + + # 2. Hook K projection + h = attn.to_k.register_forward_hook( + make_hook(f'block_{i}_attn2_key') + ) + self.hooks.append(h) + + # 3. Hook V projection + h = attn.to_v.register_forward_hook( + make_hook(f'block_{i}_attn2_value') + ) + self.hooks.append(h) + + # 4. 
Hook Q norm + h = attn.norm_q.register_forward_hook( + make_hook(f'block_{i}_attn2_query_normed') + ) + self.hooks.append(h) + + # 5. Hook K norm + h = attn.norm_k.register_forward_hook( + make_hook(f'block_{i}_attn2_key_normed') + ) + self.hooks.append(h) + + # 6. Hook output projection + h = attn.to_out[0].register_forward_hook( + make_hook(f'block_{i}_attn2_output') + ) + self.hooks.append(h) + h = attn.register_forward_hook(make_hook(f'block_{i}_attn2_attention')) + self.hooks.append(h) + + + # Hook for FFN in each block + def make_ffn_hook(block_idx): + def ffn_hook(module, input, output): + self.intermediate_outputs[f'block_{block_idx}_ffn_output'] = output.detach().cpu() + return ffn_hook + + handle = block.ffn.register_forward_hook(make_ffn_hook(i)) + self.hooks.append(handle) + + # Hook for patch embedding + handle = self.model.patch_embedding.register_forward_hook(patch_embed_hook) + self.hooks.append(handle) + + # Hook for rope + handle = self.model.rope.register_forward_hook(rope_hook) + self.hooks.append(handle) + + # Hook for condition embedder + handle = self.model.condition_embedder.register_forward_hook(condition_embedder_hook) + self.hooks.append(handle) + + # Hook for final norm + def norm_out_hook(module, input, output): + self.intermediate_outputs['norm_out_output'] = output.detach().cpu() + + handle = self.model.norm_out.register_forward_hook(norm_out_hook) + self.hooks.append(handle) + + # Hook for final projection + def proj_out_hook(module, input, output): + self.intermediate_outputs['proj_out_output'] = output.detach().cpu() + + handle = self.model.proj_out.register_forward_hook(proj_out_hook) + self.hooks.append(handle) + + def remove_hooks(self): + """Remove all registered hooks""" + for hook in self.hooks: + hook.remove() + self.hooks = [] + + def get_outputs(self): + """Get all captured intermediate outputs""" + return self.intermediate_outputs + + def clear_outputs(self): + """Clear stored outputs""" + self.intermediate_outputs = OrderedDict() +class WanAttentionDebugger: + """Capture internal attention states (Q, K, V, attention scores, etc.)""" + + def __init__(self, model): + self.model = model + self.attention_states = OrderedDict() + self.hooks = [] + self.original_processors = {} + + def register_attention_hooks(self, block_indices=None): + """ + Register hooks to capture attention internal states. 
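+
+        The hooks work by swapping each attention module's processor for an instrumented wrapper
+        (the originals are kept in self.original_processors); call restore_processors() to undo.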
+ + Args: + block_indices: List of block indices to hook, or None for all blocks + """ + if block_indices is None: + block_indices = range(len(self.model.blocks)) + + for i in block_indices: + block = self.model.blocks[i] + + # Hook self-attention (attn1) + self._hook_attention_module(block.attn1, f'block_{i}_attn1') + + # Hook cross-attention (attn2) + self._hook_attention_module(block.attn2, f'block_{i}_attn2') + + def _hook_attention_module(self, attn_module, prefix): + """Hook a single attention module to capture Q, K, V, and attention outputs""" + + # Save original processor + self.original_processors[prefix] = attn_module.processor + + # Create custom processor that captures intermediates + original_processor = attn_module.processor + attention_states = self.attention_states + + class InstrumentedProcessor: + """Wrapper processor that captures intermediate values""" + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + rotary_emb=None, + **kwargs + ): + # Get encoder hidden states + encoder_hidden_states_img = None + if attn.add_k_proj is not None and encoder_hidden_states is not None: + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + # 1. Capture Q, K, V projections (before normalization) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + attention_states[f'{prefix}_query_raw'] = query.detach().cpu() + attention_states[f'{prefix}_key_raw'] = key.detach().cpu() + attention_states[f'{prefix}_value_raw'] = value.detach().cpu() + + # 2. Capture after normalization + query = attn.norm_q(query) + key = attn.norm_k(key) + + attention_states[f'{prefix}_query_normed'] = query.detach().cpu() + attention_states[f'{prefix}_key_normed'] = key.detach().cpu() + + # 3. Reshape to heads + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + attention_states[f'{prefix}_query_heads'] = query.detach().cpu() + attention_states[f'{prefix}_key_heads'] = key.detach().cpu() + attention_states[f'{prefix}_value_heads'] = value.detach().cpu() + + # 4. Capture after RoPE (if applied) + if rotary_emb is not None: + def apply_rotary_emb(hidden_states, freqs_cos, freqs_sin): + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + attention_states[f'{prefix}_query_rope'] = query.detach().cpu() + attention_states[f'{prefix}_key_rope'] = key.detach().cpu() + + # 5. 
Handle I2V additional K, V + hidden_states_img = None + if encoder_hidden_states_img is not None: + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + + key_img = attn.norm_added_k(key_img) + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + attention_states[f'{prefix}_key_img'] = key_img.detach().cpu() + attention_states[f'{prefix}_value_img'] = value_img.detach().cpu() + + # Compute image attention (for I2V) + from diffusers.models.attention_dispatch import dispatch_attention_fn + hidden_states_img = dispatch_attention_fn( + query, key_img, value_img, + attn_mask=None, dropout_p=0.0, is_causal=False, + backend=original_processor._attention_backend, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + + attention_states[f'{prefix}_img_attention_output'] = hidden_states_img.detach().cpu() + + # 6. Compute main attention + from diffusers.models.attention_dispatch import dispatch_attention_fn + + # Note: We can't easily capture attention weights with dispatch_attention_fn + # because it uses optimized kernels (flash attention, etc.) + # For debugging attention weights, we'd need to use manual computation + + hidden_states = dispatch_attention_fn( + query, key, value, + attn_mask=attention_mask, dropout_p=0.0, is_causal=False, + backend=original_processor._attention_backend, + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + attention_states[f'{prefix}_attention_output'] = hidden_states.detach().cpu() + + # 7. Combine with image attention if present + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + attention_states[f'{prefix}_combined_output'] = hidden_states.detach().cpu() + + # 8. Output projection + hidden_states = attn.to_out[0](hidden_states) + attention_states[f'{prefix}_output_proj'] = hidden_states.detach().cpu() + + hidden_states = attn.to_out[1](hidden_states) # Dropout + attention_states[f'{prefix}_final_output'] = hidden_states.detach().cpu() + + return hidden_states + + # Replace processor + attn_module.set_processor(InstrumentedProcessor()) + + def register_attention_weight_hooks(self, block_indices=None): + """ + Capture actual attention weights (scores). + Warning: This uses manual attention computation, not optimized kernels. + """ + if block_indices is None: + block_indices = range(len(self.model.blocks)) + + for i in block_indices: + block = self.model.blocks[i] + self._hook_attention_with_weights(block.attn1, f'block_{i}_attn1') + self._hook_attention_with_weights(block.attn2, f'block_{i}_attn2') + + def _hook_attention_with_weights(self, attn_module, prefix): + """Hook that computes attention manually to capture weights""" + + attention_states = self.attention_states + + class WeightCapturingProcessor: + def __call__(self, attn, hidden_states, encoder_hidden_states=None, + attention_mask=None, rotary_emb=None, **kwargs): + + # [Same Q, K, V projection code as above...] + # ... 
(omitted for brevity) + + # Manual attention computation + import math + + # Scaled dot-product attention + scale = 1.0 / math.sqrt(query.shape[-1]) + + # (B, seq_q, heads, head_dim) @ (B, seq_k, heads, head_dim).T + # -> (B, heads, seq_q, seq_k) + attn_weights = torch.einsum('bqhd,bkhd->bhqk', query, key) * scale + + attention_states[f'{prefix}_attention_scores'] = attn_weights.detach().cpu() + + # Apply mask if present + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # Softmax + attn_weights = torch.softmax(attn_weights, dim=-1) + + attention_states[f'{prefix}_attention_weights'] = attn_weights.detach().cpu() + + # Apply attention to values + # (B, heads, seq_q, seq_k) @ (B, seq_k, heads, head_dim) + hidden_states = torch.einsum('bhqk,bkhd->bqhd', attn_weights, value) + + # Continue with output projection... + hidden_states = hidden_states.flatten(2, 3) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + attn_module.set_processor(WeightCapturingProcessor()) + + def restore_processors(self): + """Restore original attention processors""" + for prefix, original_processor in self.original_processors.items(): + # Parse prefix to get module + parts = prefix.split('_') + block_idx = int(parts[1]) + attn_name = parts[2] + + if attn_name == 'attn1': + self.model.blocks[block_idx].attn1.set_processor(original_processor) + elif attn_name == 'attn2': + self.model.blocks[block_idx].attn2.set_processor(original_processor) + + def get_attention_states(self): + """Get all captured attention states""" + return self.attention_states + + def clear_states(self): + """Clear captured states""" + self.attention_states = OrderedDict() + +if __name__ == "__main__": + # test_dit_output() + test_dit() diff --git a/bonsai/models/wan2/tests/test_umt5_output.py b/bonsai/models/wan2/tests/test_umt5_output.py new file mode 100644 index 00000000..bcf5991f --- /dev/null +++ b/bonsai/models/wan2/tests/test_umt5_output.py @@ -0,0 +1,394 @@ +"""Test output correctness by comparing JAX implementation with HuggingFace reference.""" + +import jax +import jax.numpy as jnp +import numpy as np +from huggingface_hub import snapshot_download +from bonsai.models.wan2 import params, umt5 +import torch +from transformers import AutoTokenizer, UMT5EncoderModel, UMT5ForConditionalGeneration +import os + +def check_weight_loading(jax_model, torch_model): + torch_emb = torch_model.shared.weight.detach().cpu().numpy() + jax_emb = np.array(jax_model.encoder.token_embedding.embedding.value) + + print("Embedding weights:") + print(f" Shapes: torch={torch_emb.shape}, jax={jax_emb.shape}") + print(f" Max diff: {np.abs(torch_emb - jax_emb).max():.2e}") + print(f" Mean diff: {np.abs(torch_emb - jax_emb).mean():.2e}") + + torch_q = torch_model.encoder.block[0].layer[0].SelfAttention.q.weight.detach().cpu().numpy() + jax_q = np.array(jax_model.encoder.blocks[0].attn.q.kernel.value) + + print("\nFirst block query weight:") + print(f" Shapes: torch={torch_q.shape}, jax={jax_q.shape}") + print(f" Max diff: {np.abs(torch_q.T - jax_q).max():.2e}") + print(f" Mean diff: {np.abs(torch_q.T - jax_q).mean():.2e}") + + torch_ln_weight = torch_model.encoder.final_layer_norm.weight.detach().cpu().numpy() + jax_ln_weight = np.array(jax_model.encoder.norm.weight.value) + + print("\nFinal LayerNorm weight:") + print(f" Shapes: torch={torch_ln_weight.shape}, jax={jax_ln_weight.shape}") + print(f" Max diff: {np.abs(torch_ln_weight - jax_ln_weight).max():.2e}") 
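+    # NOTE: Flax nnx.Linear stores its kernel as [in_features, out_features], the transpose of
+    # torch.nn.Linear's [out_features, in_features] weight -- hence the `.T` in the query-weight
+    # comparison above. Embedding tables and norm weights have identical layouts in both frameworks.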
+ print(f" Mean diff: {np.abs(torch_ln_weight - jax_ln_weight).mean():.2e}") + +def compare_outputs(jax_output: jax.Array, torch_output, name: str, rtol: float = 1e-2, atol: float = 1e-4): + if torch_output.dtype == torch.bfloat16: + torch_output = torch_output.float() + + # Convert PyTorch to numpy + if isinstance(torch_output, torch.Tensor): + torch_np = torch_output.detach().cpu().numpy() + else: + torch_np = np.array(torch_output) + + # Convert JAX to numpy + jax_np = np.array(jax_output) + + print(f"\n{'=' * 80}") + print(f"Comparing: {name}") + print(f"{'=' * 80}") + print(f"JAX shape: {jax_np.shape}") + print(f"Torch shape: {torch_np.shape}") + print(f"JAX dtype: {jax_np.dtype}") + print(f"Torch dtype: {torch_np.dtype}") + + # Check shapes match + if jax_np.shape != torch_np.shape: + print("Shape mismatch!") + return False + + # Compute differences + abs_diff = np.abs(jax_np - torch_np) + rel_diff = abs_diff / (np.abs(torch_np) + 1e-10) + + max_abs_diff = np.max(abs_diff) + max_rel_diff = np.max(rel_diff) + mean_abs_diff = np.mean(abs_diff) + mean_rel_diff = np.mean(rel_diff) + + print("\nStatistics:") + print(f" Max absolute difference: {max_abs_diff:.2e}") + print(f" Max relative difference: {max_rel_diff:.2e}") + print(f" Mean absolute difference: {mean_abs_diff:.2e}") + print(f" Mean relative difference: {mean_rel_diff:.2e}") + + print(f"\nJAX output range: [{np.min(jax_np):.4f}, {np.max(jax_np):.4f}]") + print(f"Torch output range: [{np.min(torch_np):.4f}, {np.max(torch_np):.4f}]") + + # Check if within tolerance + close = np.allclose(jax_np, torch_np, rtol=rtol, atol=atol) + + if close: + print(f"\nβœ… Outputs match within tolerance (rtol={rtol}, atol={atol})") + else: + print(f"\n❌ Outputs do NOT match (rtol={rtol}, atol={atol})") + # Show some mismatched locations + mismatch_mask = ~np.isclose(jax_np, torch_np, rtol=rtol, atol=atol) + n_mismatches = np.sum(mismatch_mask) + print(f" Number of mismatches: {n_mismatches} / {jax_np.size} ({100 * n_mismatches / jax_np.size:.2f}%)") + + return close + +def test_t5_encoder(): + """Test UMT5 encoder output against Wan UMT5 reference implementation.""" + print("\n" + "=" * 80) + print("TEST 1: UMT5 Encoder (UMT5-XXL)") + print("=" * 80) + # Download checkpoint + model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + + # Test prompt + prompt = "A beautiful sunset over the ocean with waves crashing on the shore" + max_length = 512 + + print(f"\nTest prompt: {prompt}") + print(f"Max length: {max_length}") + + print("\nTokenizing...") + tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + inputs_j = tokenizer(prompt, return_tensors="np") + inputs_p = tokenizer(prompt, return_tensors="pt") + + print("\n[1/2] Loading UMT5 encoder...") + jax_t5 = params.create_t5_encoder_from_safe_tensors(model_ckpt_path, mesh=None) + hf_t5 = UMT5EncoderModel.from_pretrained(model_ckpt_path, subfolder="text_encoder", torch_dtype=torch.bfloat16) + + check_weight_loading(jax_t5, hf_t5) + + print("\nRunning model...") + input_ids_jax = jnp.array(inputs_j.input_ids) + jax_output = jax_t5(input_ids_jax, deterministic=True) + + pytorch_output = hf_t5(inputs_p.input_ids) + torch_embeddings = pytorch_output.last_hidden_state + + # Compare only the valid portion (ignore padding) + return compare_outputs(jax_output, torch_embeddings, "UMT5 Encoder Output", rtol=1e-3, atol=1e-4) + + +def test_t5_intermediate(): + """Compare intermediate layer outputs between JAX and PyTorch UMT5 encoder.""" + print("\n" + "=" * 80) + print("TEST 2: UMT5 
Encoder Intermediate Outputs") + print("=" * 80) + + # Download checkpoint + model_ckpt_path = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + + # Test prompt + prompt = "A beautiful sunset over the ocean with waves crashing on the shore" + + print(f"\nTest prompt: {prompt}") + + print("\nTokenizing...") + tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + inputs_j = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=512, truncation=True) + inputs_p = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=512, truncation=True) + + print("\n[1/3] Loading models...") + jax_t5 = params.create_t5_encoder_from_safe_tensors(model_ckpt_path, mesh=None) + pytorch_t5 = UMT5EncoderModel.from_pretrained( + model_ckpt_path, + subfolder="text_encoder", + torch_dtype=torch.float32 + ) + pytorch_t5.eval() + + # Register hooks to capture PyTorch intermediate outputs + print("\n[2/3] Capturing PyTorch intermediate outputs...") + pytorch_intermediates = {} + + def make_hook(name): + def hook(module, input, output): + if isinstance(output, tuple): + pytorch_intermediates[name] = output[0].detach().cpu() + else: + pytorch_intermediates[name] = output.detach().cpu() + return hook + + # Hook embedding + pytorch_t5.encoder.embed_tokens.register_forward_hook(make_hook("embeddings")) + + # Hook each block with detailed attention captures + for i, block in enumerate(pytorch_t5.encoder.block): + block.register_forward_hook(make_hook(f"block_{i}_output")) + block.layer[0].register_forward_hook(make_hook(f"block_{i}_attn_output")) + + # Hook attention Q, K, V projections + block.layer[0].layer_norm.register_forward_hook(make_hook(f"block_{i}_attn_norm")) + attn = block.layer[0].SelfAttention + attn.q.register_forward_hook(make_hook(f"block_{i}_q_proj")) + attn.k.register_forward_hook(make_hook(f"block_{i}_k_proj")) + attn.v.register_forward_hook(make_hook(f"block_{i}_v_proj")) + + if len(block.layer) > 1: + block.layer[1].register_forward_hook(make_hook(f"block_{i}_ffn_output")) + + # Hook final norm + pytorch_t5.encoder.final_layer_norm.register_forward_hook(make_hook("final_norm")) + + # Run PyTorch forward + with torch.no_grad(): + pytorch_output = pytorch_t5(inputs_p.input_ids) + + print("\n[3/3] Comparing intermediate outputs...") + + # Convert inputs to JAX + input_ids_jax = jnp.array(inputs_j.input_ids) + + # Manual forward pass for JAX to get intermediates + # 1. Embeddings + x_jax = jax_t5.encoder.token_embedding(input_ids_jax) + embeddings_torch = pytorch_intermediates["embeddings"] + compare_outputs(x_jax, embeddings_torch, "Token Embeddings", rtol=1e-4, atol=1e-6) + + # 2. Dropout (not comparing, just applying) + x_jax = jax_t5.encoder.dropout(x_jax, deterministic=True) + + # 3. Position bias setup (only once) + batch_size, seq_len = input_ids_jax.shape + position_bias_jax = None + + # 4. 
Process through blocks + num_layers = len(jax_t5.encoder.blocks) + print(f"\nComparing {num_layers} transformer blocks...") + + for i in range(min(3, num_layers)): # Compare first 3 blocks + print(f"\n{'='*80}") + print(f"BLOCK {i}") + print(f"{'='*80}") + + block = jax_t5.encoder.blocks[i] + + # Manual attention computation to capture Q, K, V + print(f"\n--- Attention Details ---") + + # Norm + normed_x_jax = block.norm1(x_jax) + + compare_outputs(normed_x_jax, pytorch_intermediates[f"block_{i}_attn_norm"], f"Block {i} Attention Norm", rtol=1e-5, atol=1e-6) + + # Q, K, V projections + q_jax = block.attn.q(normed_x_jax) + k_jax = block.attn.k(normed_x_jax) + v_jax = block.attn.v(normed_x_jax) + + # Compare with PyTorch + q_torch = pytorch_intermediates[f"block_{i}_q_proj"] + k_torch = pytorch_intermediates[f"block_{i}_k_proj"] + v_torch = pytorch_intermediates[f"block_{i}_v_proj"] + + compare_outputs(q_jax, q_torch, f"Block {i} Q after Linear", rtol=1e-5, atol=1e-6) + compare_outputs(k_jax, k_torch, f"Block {i} K after Linear", rtol=1e-5, atol=1e-6) + compare_outputs(v_jax, v_torch, f"Block {i} V after Linear", rtol=1e-5, atol=1e-6) + + # Now run full attention (for comparison) + if position_bias_jax is None: + attn_output = block.attn( + x_jax, + mask=None, + pos_bias=None, + deterministic=True + ) + # Compare position bias (only computed in first block) + if f"block_{i}_position_bias" in pytorch_intermediates: + pos_bias_torch = pytorch_intermediates[f"block_{i}_position_bias"] + compare_outputs(position_bias_jax, pos_bias_torch, f"Block {i} Position Bias", rtol=1e-5, atol=1e-6) + else: + attn_output = block.attn( + x_jax, + mask=None, + pos_bias=position_bias_jax, + deterministic=True + ) + + # PyTorch attention output (after layer[0] which includes norm + attn + dropout) + attn_torch = pytorch_intermediates[f"block_{i}_attn_output"] + compare_outputs(attn_output, attn_torch, f"Block {i} Attention Output", rtol=1e-3, atol=1e-5) + + x_jax = attn_output + + # FFN + if hasattr(block, 'ffn'): + ffn_output = block.ffn(x_jax, deterministic=True) + x_jax = ffn_output + + if f"block_{i}_ffn_output" in pytorch_intermediates: + ffn_torch = pytorch_intermediates[f"block_{i}_ffn_output"] + compare_outputs(ffn_output, ffn_torch, f"Block {i} FFN Output", rtol=1e-3, atol=1e-5) + + # Final block output + block_torch = pytorch_intermediates[f"block_{i}_output"] + compare_outputs(x_jax, block_torch, f"Block {i} Final Output", rtol=1e-3, atol=1e-5) + + # 5. Final layer norm + x_jax = jax_t5.encoder.norm(x_jax) + final_norm_torch = pytorch_intermediates["final_norm"] + compare_outputs(x_jax, final_norm_torch, "Final Layer Norm", rtol=1e-4, atol=1e-6) + + # 6. 
Final output + torch_final = pytorch_output.last_hidden_state + compare_outputs(x_jax, torch_final, "Final Output", rtol=1e-3, atol=1e-4) + + print("\n" + "="*80) + print("Intermediate comparison complete!") + print("="*80) + +def test_t5_e2e(): + """Test JAX UMT5 encoder with PyTorch decoder on end-to-end generation task.""" + print("\n" + "=" * 80) + print("TEST 2: UMT5 E2E (JAX Encoder + PyTorch Decoder)") + print("=" * 80) + # Test prompts + test_prompts = [ + "Studies have shown that good for you", + ] + + print("\n[1/3] Loading models...") + tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + model_ckpt_path = snapshot_download("google/umt5-xxl") + + # Load JAX encoder + jax_t5 = params.create_t5_encoder_from_safe_tensors(model_ckpt_path, mesh=None, is_sf=False,config=umt5.T5Config.umt5_xxl()) + + # Load full PyTorch model (encoder + decoder) + pytorch_full_model = UMT5ForConditionalGeneration.from_pretrained( + model_ckpt_path, + torch_dtype=torch.float32 + ) + pytorch_full_model.eval() + + print("\n[2/3] Running generation tests...") + + for i, prompt in enumerate(test_prompts): + print(f"\n{'='*80}") + print(f"Test Case {i+1}: {prompt}") + print(f"{'='*80}") + + # Tokenize + inputs_jax = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=512, truncation=True) + inputs_torch = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=512, truncation=True) + + # ============================================================ + # Baseline: Full PyTorch model + # ============================================================ + print("\n[Baseline] Full PyTorch model:") + with torch.no_grad(): + pytorch_outputs = pytorch_full_model.generate( + input_ids=inputs_torch.input_ids, + attention_mask=inputs_torch.attention_mask, + max_length=50, + num_beams=1, # Greedy decoding + do_sample=False, + ) + + pytorch_text = tokenizer.decode(pytorch_outputs[0]) + print(f" Output: {pytorch_text}") + + # ============================================================ + # Hybrid: JAX encoder + PyTorch decoder + # ============================================================ + print("\n[Hybrid] JAX encoder + PyTorch decoder:") + + # Get encoder hidden states from JAX + input_ids_jax = jnp.array(inputs_jax.input_ids) + jax_encoder_output = jax_t5(input_ids_jax, deterministic=True) + + # Convert to PyTorch + encoder_hidden_states = torch.from_numpy(np.array(jax_encoder_output)) + + print(f" JAX encoder output shape: {encoder_hidden_states.shape}") + print(f" JAX encoder output range: [{encoder_hidden_states.min():.4f}, {encoder_hidden_states.max():.4f}]") + + # Create encoder outputs object for decoder + from transformers.modeling_outputs import BaseModelOutput + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_hidden_states + ) + + # Generate using decoder with JAX encoder outputs + with torch.no_grad(): + hybrid_outputs = pytorch_full_model.generate( + encoder_outputs=encoder_outputs, + attention_mask=inputs_torch.attention_mask, + max_length=50, + num_beams=1, + do_sample=False, + ) + + hybrid_text = tokenizer.decode(hybrid_outputs[0]) + print(f" Output: {hybrid_text}") + + print("\n[3/3] Summary") + print("=" * 80) + print("E2E test complete. 
Check outputs above to verify encoder correctness.") + + +if __name__ == "__main__": + # Uncomment the test you want to run: + test_t5_encoder() # Test final outputs only + # test_t5_intermediate() # Test intermediate layer outputs (detailed) + test_t5_e2e() # End-to-end generation test diff --git a/bonsai/models/wan2/tests/test_vae_output.py b/bonsai/models/wan2/tests/test_vae_output.py new file mode 100644 index 00000000..2ef0333f --- /dev/null +++ b/bonsai/models/wan2/tests/test_vae_output.py @@ -0,0 +1,426 @@ +import jax +import jax.numpy as jnp +import numpy as np +from modelscope import snapshot_download as ms_snapshot_download +from huggingface_hub import snapshot_download as hf_snapshot_download +from bonsai.models.wan2 import params +from bonsai.models.wan2 import vae_wan as vae_lib +import torch +from diffusers import AutoencoderKLWan +from jax.lax import Precision +from collections import OrderedDict +from flax import nnx +import sys +import torch + +class WanVAEDecoderHooks: + """Extract intermediate outputs from Wan VAE Decoder using forward hooks""" + + def __init__(self, vae): + self.vae = vae + self.decoder = vae.decoder + self.outputs = OrderedDict() + self.hooks = [] + self.capture_enabled = True # Control when to capture + + def register_decoder_hooks(self): + """Register hooks on all decoder layers""" + + # 1. Hook post_quant_conv (before decoder) + h = self.vae.post_quant_conv.register_forward_hook( + self._make_conditional_hook('post_quant_conv') + ) + self.hooks.append(h) + + # 2. Hook conv_in + h = self.decoder.conv_in.register_forward_hook( + self._make_conditional_hook('conv_in') + ) + self.hooks.append(h) + + # 3. Hook mid_block + h = self.decoder.mid_block.register_forward_hook( + self._make_conditional_hook('mid_block') + ) + self.hooks.append(h) + + # 4. Hook each mid_block residual block + if hasattr(self.decoder.mid_block, 'resnets'): + for i, res_block in enumerate(self.decoder.mid_block.resnets): + h = res_block.register_forward_hook( + self._make_conditional_hook(f'mid_block_res_{i}') + ) + self.hooks.append(h) + if hasattr(self.decoder.mid_block, 'attentions'): + for i, attn_block in enumerate(self.decoder.mid_block.attentions): + h = attn_block.register_forward_hook( + self._make_conditional_hook(f'mid_block_attn_{i}') + ) + self.hooks.append(h) + + # 5. Hook each up_block + for i, up_block in enumerate(self.decoder.up_blocks): + h = up_block.register_forward_hook( + self._make_conditional_hook(f'up_block_{i}') + ) + self.hooks.append(h) + + # Hook residual blocks within up_block + if hasattr(up_block, 'resnets'): + for j, res_block in enumerate(up_block.resnets): + h = res_block.register_forward_hook( + self._make_conditional_hook(f'up_block_{i}_res_{j}') + ) + self.hooks.append(h) + + # Hook upsample layers + if hasattr(up_block, 'upsamplers') and up_block.upsamplers is not None: + h = up_block.upsamplers[0].register_forward_hook( + self._make_conditional_hook(f'up_block_{i}_upsample') + ) + self.hooks.append(h) + + # 6. Hook norm_out + h = self.decoder.norm_out.register_forward_hook( + self._make_conditional_hook('norm_out') + ) + self.hooks.append(h) + + # 7. Hook nonlinearity (after norm_out) + # We need to hook this differently since it's a function, not a module + # We'll hook conv_out input instead + + # 8. 
Hook conv_out (final output) + h = self.decoder.conv_out.register_forward_hook( + self._make_conditional_hook('conv_out') + ) + self.hooks.append(h) + + def _make_conditional_hook(self, name): + """Create hook that only captures when enabled""" + def hook(module, input, output): + if self.capture_enabled: + self.outputs[name] = output.detach().cpu() + return hook + + + def decode_first_frame_only(self, latents): + """Decode only first frame with hooks enabled""" + + # Extract first frame + z_first = latents[:, :, 0:1, :, :] # (B, C, 1, H, W) + + # Enable capture + self.capture_enabled = True + + # Run through decoder directly (bypass frame loop) + with torch.no_grad(): + x = self.vae.post_quant_conv(z_first) + out = self.decoder( + x, + feat_cache=None, # No cache for single frame + feat_idx=[0], + first_chunk=True + ) + + # Disable capture + self.capture_enabled = False + + return out + + + def _make_hook(self, name): + """Create a hook function with closure over name""" + def hook(module, input, output): + self.outputs[name] = output.detach().cpu() + return hook + + def remove_hooks(self): + """Remove all registered hooks""" + for h in self.hooks: + h.remove() + self.hooks = [] + + def clear_outputs(self): + """Clear captured outputs""" + self.outputs = OrderedDict() + + def get_outputs(self): + """Get all captured outputs""" + return self.outputs + +def test_vae_decoder_outpout(src:str="hf"): + # Load VAE model + print("Loading AutoencoderKLWan...") + if src == "ms": + model_ckpt_path = ms_snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers",allow_patterns='vae/*') + elif src == "hf": + model_ckpt_path = hf_snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers",allow_patterns='vae/*') + + vae_jax = params.create_vae_decoder_from_safe_tensors(model_ckpt_path, mesh=None) + vae = AutoencoderKLWan.from_pretrained( + model_ckpt_path, + subfolder="vae", + torch_dtype=torch.float32 + ) + vae.eval() + + # Register hooks + hook_manager = WanVAEDecoderHooks(vae) + hook_manager.register_decoder_hooks() + + # Create dummy latent input + torch.manual_seed(42) + batch_size = 1 + z_dim = 16 + num_frames = 9 + height = 30 # After spatial compression (8x) + width = 52 + + latents_mean = ( + torch.tensor(vae_lib.VAEConfig.latent_mean) + .view(1, 16, 1, 1, 1) + .to(dtype=torch.float32) + ) + latents_std = 1.0 / torch.tensor(vae_lib.VAEConfig.latent_std).view( + 1, 16, 1, 1, 1 + ).to(dtype=torch.float32) + + latents_original = torch.randn(batch_size, z_dim, num_frames, height, width,dtype=torch.float32) + latents = latents_original / latents_std + latents_mean + latents_jax = jnp.array(latents_original.numpy().transpose(0,2,3,4,1)) + + print(f"\nInput latents shape: {latents.shape}") + print("Running decoder forward pass...\n") + + decoded_jax = vae_jax.decode(latents_jax) + print(decoded_jax[0,1:,:,235:,:].mean()) + with torch.no_grad(): + decoded = vae.decode(latents).sample + compare_outputs(decoded_jax, decoded, "final_output", rtol=1e-2, atol=1e-4) + hook_manager.remove_hooks() + +def test_vae_decoder(src:str="hf"): + # Load VAE model + print("Loading AutoencoderKLWan...") + if src == "ms": + model_ckpt_path = ms_snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers",allow_patterns='vae/*') + elif src == "hf": + model_ckpt_path = hf_snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers",allow_patterns='vae/*') + + vae_jax = params.create_vae_decoder_from_safe_tensors(model_ckpt_path, mesh=None) + vae = AutoencoderKLWan.from_pretrained( + model_ckpt_path, + subfolder="vae", + torch_dtype=torch.float32 + ) + 
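+    # NOTE (sketch): the per-layer comparisons further down read from a dict named `outputs`,
+    # which is expected to hold the hook captures from a reference torch decode, e.g. (after
+    # the hook manager and `latents` below are set up):
+    #     with torch.no_grad():
+    #         decoded = vae.decode(latents).sample
+    #     outputs = hook_manager.get_outputs()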
vae.eval() + + # Register hooks + hook_manager = WanVAEDecoderHooks(vae) + hook_manager.register_decoder_hooks() + + # Create dummy latent input + torch.manual_seed(42) + batch_size = 1 + z_dim = 16 + num_frames = 9 + height = 30 # After spatial compression (8x) + width = 52 + + latents_mean = ( + torch.tensor(vae_lib.VAEConfig.latent_mean) + .view(1, 16, 1, 1, 1) + .to(dtype=torch.float32) + ) + latents_std = 1.0 / torch.tensor(vae_lib.VAEConfig.latent_std).view( + 1, 16, 1, 1, 1 + ).to(dtype=torch.float32) + + latents_original = torch.randn(batch_size, z_dim, num_frames, height, width,dtype=torch.float32) + latents = latents_original / latents_std + latents_mean + latents_jax = jnp.array(latents_original.numpy().transpose(0,2,3,4,1)) + + print(f"\nInput latents shape: {latents.shape}") + print("Running decoder forward pass...\n") + + print("=" * 80) + print("CAPTURED VAE DECODER INTERMEDIATE OUTPUTS") + print("=" * 80) + + output_jax = {} + z, _ = vae_jax.conv2(latents_jax, None) + compare_outputs(z, outputs['post_quant_conv'], 'post_quant_conv', rtol=1e-2, atol=1e-4) + output_jax['post_quant_conv'] = z + + t = z.shape[1] + frames = [] + decoder = vae_jax.decoder + + # Initialize cache list for feature caching + cache_list = [None] * 50 + + for i in range(t): + print(f"\n{'='*80}") + print(f"Processing frame {i+1}/{t}") + print(f"{'='*80}") + + # Reset cache index for each frame + cache_idx = [0] + frame_latent = z[:, i : i + 1, :, :, :] + + if i == 0: + idx = cache_idx[0] + x, cache_list[idx] = decoder.conv_in(frame_latent, cache_list[idx]) + cache_idx[0] += 1 + compare_outputs(x, outputs['conv_in'], 'conv_in', rtol=1e-2, atol=1e-4) + output_jax['conv_in'] = x + + # Mid block 1 + x, cache_list = decoder.mid_block1(x, cache_list, cache_idx) + compare_outputs(x, outputs['mid_block_res_0'], 'mid_block_res_0', rtol=1e-2, atol=1e-4) + output_jax['mid_block_res_0'] = x + + # Mid attention + x = decoder.mid_attn(x) + compare_outputs(x, outputs['mid_block_attn_0'], 'mid_block_attn_0', rtol=1e-2, atol=1e-4) + output_jax['mid_block_attn_0'] = x + + # Mid block 2 + x, cache_list = decoder.mid_block2(x, cache_list, cache_idx) + compare_outputs(x, outputs['mid_block_res_1'], 'mid_block_res_1', rtol=1e-2, atol=1e-4) + output_jax['mid_block_res_1'] = x + + # Upsample stage 0 + for j, block in enumerate(decoder.up_blocks_0): + x, cache_list = block(x, cache_list, cache_idx) + compare_outputs(x, outputs[f'up_block_0_res_{j}'], f'up_block_0_res_{j}', rtol=1e-2, atol=1e-4) + output_jax[f'up_block_0_res_{j}'] = x + x, cache_list = decoder.up_sample_0(x, cache_list, cache_idx) + compare_outputs(x, outputs['up_block_0_upsample'], 'up_block_0_upsample', rtol=1e-2, atol=1e-4) + output_jax['up_block_0_upsample'] = x + + # Upsample stage 1 + for j, block in enumerate(decoder.up_blocks_1): + x, cache_list = block(x, cache_list, cache_idx) + compare_outputs(x, outputs[f'up_block_1_res_{j}'], f'up_block_1_res_{j}', rtol=1e-2, atol=1e-4) + output_jax[f'up_block_1_res_{j}'] = x + x, cache_list = decoder.up_sample_1(x, cache_list, cache_idx) + compare_outputs(x, outputs['up_block_1_upsample'], 'up_block_1_upsample', rtol=1e-2, atol=1e-4) + output_jax['up_block_1_upsample'] = x + + # Upsample stage 2 (spatial only, no cache) + for j, block in enumerate(decoder.up_blocks_2): + x, cache_list = block(x, cache_list, cache_idx) + compare_outputs(x, outputs[f'up_block_2_res_{j}'], f'up_block_2_res_{j}', rtol=1e-2, atol=1e-4) + output_jax[f'up_block_2_res_{j}'] = x + x = decoder.up_sample_2(x) # Spatial-only, no cache + 
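+                # The JAX activations here are channel-last [B, T, H, W, C]; compare_outputs()
+                # below transposes them to torch's [B, C, T, H, W] layout before computing diffs.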
compare_outputs(x, outputs['up_block_2_upsample'], 'up_block_2_upsample', rtol=1e-2, atol=1e-4) + output_jax['up_block_2_upsample'] = x + + # Upsample stage 3 (no spatial upsample) + for j, block in enumerate(decoder.up_blocks_3): + x, cache_list = block(x, cache_list, cache_idx) + compare_outputs(x, outputs[f'up_block_3_res_{j}'], f'up_block_3_res_{j}', rtol=1e-2, atol=1e-4) + output_jax[f'up_block_3_res_{j}'] = x + + x = decoder.norm_out(x) + compare_outputs(x, outputs['norm_out'], 'norm_out', rtol=1e-2, atol=1e-4) + output_jax['norm_out'] = x + x = nnx.silu(x) + + idx = cache_idx[0] + x, cache_list[idx] = decoder.conv_out(x, cache_list[idx]) + cache_idx[0] += 1 + compare_outputs(x, outputs['conv_out'], 'conv_out', rtol=1e-2, atol=1e-4) + output_jax['conv_out'] = x + frames.append(x) + else: + # Subsequent frames: use cached features + print(f"Subsequent frame: using feature cache") + frame_out, cache_list = decoder(frame_latent, cache_list, cache_idx) + frames.append(frame_out) + + print("\n" + "=" * 80) + print(f"Final decoded output shape: {decoded.shape}") + print("=" * 80) + + # Save outputs for comparison + outputs_dict = { + 'inputs': { + 'latents': latents.cpu(), + }, + 'intermediate': outputs, + 'output': decoded.cpu(), + } + + outputs_dict_jax = { + 'intermediate': output_jax + } + + compare_with_jax_decoder(outputs_dict, outputs_dict_jax) + + torch.save(outputs_dict, 'wan_vae_decoder_outputs.pt') + print("\nβœ“ Saved outputs to 'wan_vae_decoder_outputs.pt'") + + # Clean up + hook_manager.remove_hooks() + +def compare_outputs(jax_output: jax.Array, torch_output, name: str, rtol: float = 1e-2, atol: float = 1e-4): + if torch_output.dtype == torch.bfloat16: + torch_output = torch_output.float() + + if isinstance(torch_output, torch.Tensor): + torch_np = torch_output.detach().cpu().numpy() + else: + torch_np = np.array(torch_output) + + jax_np = np.array(jax_output).transpose(0,4,1,2,3) # Convert JAX [B,T,H,W,C] to [B,C,T,H,W] + + print(f"\n{'=' * 80}") + print(f"Comparing: {name}") + print(f"{'=' * 80}") + print(f"JAX shape: {jax_np.shape}") + print(f"Torch shape: {torch_np.shape}") + print(f"JAX dtype: {jax_np.dtype}") + print(f"Torch dtype: {torch_np.dtype}") + + if jax_np.shape != torch_np.shape: + print("Shape mismatch!") + return False + + abs_diff = np.abs(jax_np - torch_np) + rel_diff = abs_diff / (np.abs(torch_np) + 1e-10) + + max_abs_diff = np.max(abs_diff) + max_rel_diff = np.max(rel_diff) + mean_abs_diff = np.mean(abs_diff) + mean_rel_diff = np.mean(rel_diff) + + print("\nStatistics:") + print(f" Max absolute difference: {max_abs_diff:.2e}") + print(f" Max relative difference: {max_rel_diff:.2e}") + print(f" Mean absolute difference: {mean_abs_diff:.2e}") + print(f" Mean relative difference: {mean_rel_diff:.2e}") + + print(f"\nJAX output range: [{np.min(jax_np):.4f}, {np.max(jax_np):.4f}]") + print(f"Torch output range: [{np.min(torch_np):.4f}, {np.max(torch_np):.4f}]") + + close = np.allclose(jax_np, torch_np, rtol=rtol, atol=atol) + + if close: + print(f"\nβœ… Outputs match within tolerance (rtol={rtol}, atol={atol})") + else: + print(f"\n❌ Outputs do NOT match (rtol={rtol}, atol={atol})") + # Show some mismatched locations + mismatch_mask = ~np.isclose(jax_np, torch_np, rtol=rtol, atol=atol) + n_mismatches = np.sum(mismatch_mask) + print(f" Number of mismatches: {n_mismatches} / {jax_np.size} ({100 * n_mismatches / jax_np.size:.2f}%)") + + return close + +if __name__ == "__main__": + # add args for modelscope/huggingface model download + src = sys.argv[1] if 
len(sys.argv) > 1 else "hf" + assert src in ["ms", "hf"], "Invalid source specified. Use 'modelscope' or 'huggingface'." + test_vae_decoder_output(src) \ No newline at end of file diff --git a/bonsai/models/wan2/transformer_wan.py b/bonsai/models/wan2/transformer_wan.py new file mode 100644 index 00000000..8c9e204a --- /dev/null +++ b/bonsai/models/wan2/transformer_wan.py @@ -0,0 +1,461 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wan2.1-T2V-1.3B: Text-to-Video Diffusion Transformer Model. + +This implements the Wan2.1-T2V-1.3B model, a 1.3B parameter diffusion transformer +for text-to-video generation using Flow Matching framework. + +Architecture: +- 30-layer Diffusion Transformer with 1536 hidden dim +- 12 attention heads (128 dim each) +- Vision self-attention + text cross-attention +- AdaLN modulation conditioned on timestep +- UMT5 text encoder for multilingual prompts +- Wan-VAE for video encoding/decoding +""" + +import dataclasses +import math +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import nnx +from jax.lax import Precision +from jaxtyping import Array + +from .unipc_multistep_scheduler import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState + + +@dataclasses.dataclass(frozen=True) +class TransformerWanModelConfig: + """Configuration for Wan2.1-T2V-1.3B Diffusion Transformer.""" + + weights_dtype: jnp.dtype = jnp.bfloat16 + num_layers: int = 30 + hidden_dim: int = 1536 + latent_input_dim: int = 16 + latent_output_dim: int = 16 + ffn_dim: int = 8960 + freq_dim: int = 256 + num_heads: int = 12 + head_dim: int = 128 + text_embed_dim: int = 4096 + max_text_len: int = 512 + num_frames: int = 21 + latent_size: Tuple[int, int] = (30, 30) + num_inference_steps: int = 50 + guidance_scale: float = 5.0 + patch_size: Tuple[int, int, int] = (1, 2, 2) + cross_attn_norm: bool = True + qk_norm: Optional[str] = "rms_norm_across_heads" + eps: float = 1e-6 + added_kv_proj_dim: Optional[int] = None # None for T2V, set for I2V + rope_max_seq_len: int = 1024 + + def __post_init__(self): + assert self.hidden_dim == self.num_heads * self.head_dim, "hidden_dim must equal num_heads * head_dim" + + +class TimestepEmbedding(nnx.Module): + """Timestep embedding: sinusoidal -> MLP -> projection for AdaLN.""" + + def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + self.time_embedding = nnx.Sequential( + nnx.Linear(cfg.freq_dim, cfg.hidden_dim, rngs=rngs), + nnx.silu, + nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs), + ) + self.time_projection = nnx.Sequential( + nnx.silu, + nnx.Linear(cfg.hidden_dim, 6 * cfg.hidden_dim, rngs=rngs), + ) + + def __call__(self, t: Array) -> tuple[Array, Array]: + t_freq = sinusoidal_embedding_1d(t, self.cfg.freq_dim) + time_emb = self.time_embedding(t_freq) + time_proj = self.time_projection(time_emb) + return time_emb, time_proj + + +def sinusoidal_embedding_1d(timesteps: Array, embedding_dim: int, max_period: int = 10000) -> Array: + 
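+    # Standard sinusoidal timestep embedding:
+    #   freqs_i = exp(-ln(max_period) * i / half_dim),  i = 0 .. half_dim - 1
+    #   emb(t)  = concat(cos(t * freqs), sin(t * freqs))  -> [B, embedding_dim]
+    # (cos comes first here; an extra zero column pads an odd embedding_dim.)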
half_dim = embedding_dim // 2 + freqs = jnp.exp(-math.log(max_period) * jnp.arange(0, half_dim) / half_dim) + args = timesteps[:, None] * freqs[None, :] + embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1) + if embedding_dim % 2: + embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1) + return embedding + + +def precompute_freqs_cis_3d(dim: int, theta: float = 10000.0, max_seq_len: int = 1024) -> tuple[Array, Array, Array]: + """Precompute 3D RoPE frequencies split as T: dim-4*(dim//6), H: 2*(dim//6), W: 2*(dim//6).""" + dim_base = dim // 6 + dim_t, dim_h, dim_w = dim - 4 * dim_base, 2 * dim_base, 2 * dim_base + assert dim_t + dim_h + dim_w == dim + return ( + rope_params(max_seq_len, dim_t, theta), + rope_params(max_seq_len, dim_h, theta), + rope_params(max_seq_len, dim_w, theta), + ) + + +def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> Array: + freqs = 1.0 / jnp.power(theta, jnp.arange(0, dim, 2, dtype=jnp.float32) / dim) + positions = jnp.arange(max_seq_len, dtype=jnp.float32) + freqs = jnp.outer(positions, freqs) + return jnp.stack([jnp.cos(freqs), jnp.sin(freqs)], axis=-1) + + +def rope_apply(x: Array, grid_sizes: tuple[int, int, int], freqs: tuple[Array, Array, Array]) -> Array: + """Apply 3D RoPE to input tensor.""" + b, seq_len, num_heads, head_dim = x.shape + f, h, w = grid_sizes + freqs_t, freqs_h, freqs_w = freqs + + dim_base = head_dim // 6 + dim_t, dim_h, dim_w = head_dim - 4 * dim_base, 2 * dim_base, 2 * dim_base + + freqs_grid = jnp.concatenate( + [ + jnp.broadcast_to(freqs_t[:f, None, None, :, :], (f, h, w, dim_t // 2, 2)), + jnp.broadcast_to(freqs_h[None, :h, None, :, :], (f, h, w, dim_h // 2, 2)), + jnp.broadcast_to(freqs_w[None, None, :w, :, :], (f, h, w, dim_w // 2, 2)), + ], + axis=3, + ).reshape(seq_len, head_dim // 2, 2)[None, :, None, :, :] + + x_complex = x.reshape(b, seq_len, num_heads, head_dim // 2, 2) + x_out = jnp.stack( + [ + x_complex[..., 0] * freqs_grid[..., 0] - x_complex[..., 1] * freqs_grid[..., 1], + x_complex[..., 0] * freqs_grid[..., 1] + x_complex[..., 1] * freqs_grid[..., 0], + ], + axis=-1, + ) + + return x_out.reshape(b, seq_len, num_heads, head_dim) + + +class WanLayerNorm(nnx.LayerNorm): + """LayerNorm with float32 conversion for numerical stability.""" + + def __init__(self, dim: int, eps: float = 1e-6, use_scale: bool = False, use_bias: bool = False, *, rngs: nnx.Rngs): + super().__init__(dim, epsilon=eps, use_scale=use_scale, use_bias=use_bias, rngs=rngs) + + def __call__(self, x: Array) -> Array: + dtype = x.dtype + return super().__call__(x.astype(jnp.float32)).astype(dtype) + + +class MultiHeadAttention(nnx.Module): + def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs): + self.num_heads, self.head_dim = cfg.num_heads, cfg.head_dim + self.q_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST) + self.k_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST) + self.v_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST) + self.out_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST) + self.q_norm = nnx.RMSNorm(cfg.hidden_dim, rngs=rngs) + self.k_norm = nnx.RMSNorm(cfg.hidden_dim, rngs=rngs) + + def __call__(self, x: Array, rope_state: tuple | None = None, deterministic: bool = True) -> Array: + b, n = x.shape[:2] + q = self.q_norm(self.q_proj(x)).reshape(b, n, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) 
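+        # Matching the `qk_norm="rms_norm_across_heads"` setting in the config: RMSNorm is applied
+        # over the full hidden dim *before* the reshape/transpose to [B, num_heads, N, head_dim]
+        # consumed by the attention einsum below.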
+ k = self.k_norm(self.k_proj(x)).reshape(b, n, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.v_proj(x).reshape(b, n, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + + if rope_state is not None: + freqs, grid_sizes = rope_state + q, k = jnp.transpose(q, (0, 2, 1, 3)), jnp.transpose(k, (0, 2, 1, 3)) + q, k = rope_apply(q, grid_sizes, freqs), rope_apply(k, grid_sizes, freqs) + q, k = jnp.transpose(q, (0, 2, 1, 3)), jnp.transpose(k, (0, 2, 1, 3)) + + attn = jax.nn.softmax( + jnp.einsum("bhid,bhjd->bhij", q, k, precision=Precision.HIGHEST) / math.sqrt(self.head_dim), axis=-1 + ) + out = ( + jnp.einsum("bhij,bhjd->bhid", attn, v, precision=Precision.HIGHEST).transpose(0, 2, 1, 3).reshape(b, n, -1) + ) + return self.out_proj(out) + + +class CrossAttention(nnx.Module): + def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs): + self.num_heads, self.head_dim = cfg.num_heads, cfg.head_dim + self.q_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST) + self.kv_proj = nnx.Linear(cfg.hidden_dim, 2 * cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST) + self.out_proj = nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST) + self.q_norm = nnx.RMSNorm(cfg.hidden_dim, rngs=rngs) + self.k_norm = nnx.RMSNorm(cfg.hidden_dim, rngs=rngs) + + def __call__(self, x: Array, context: Array, deterministic: bool = True) -> Array: + b, n, m = x.shape[0], x.shape[1], context.shape[1] + q = self.q_norm(self.q_proj(x)).reshape(b, n, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + kv = self.kv_proj(context) + k, v = jnp.split(kv, 2, axis=-1) + k = self.k_norm(k).reshape(b, m, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + v = v.reshape(b, m, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + attn = jax.nn.softmax( + jnp.einsum("bhid,bhjd->bhij", q, k, precision=Precision.HIGHEST) / math.sqrt(self.head_dim), axis=-1 + ) + out = ( + jnp.einsum("bhij,bhjd->bhid", attn, v, precision=Precision.HIGHEST).transpose(0, 2, 1, 3).reshape(b, n, -1) + ) + return self.out_proj(out) + + +def modulate(x: Array, shift: Array, scale: Array) -> Array: + """Apply adaptive layer norm modulation.""" + original_dtype = x.dtype + + return (x.astype(jnp.float32) * (1 + scale) + shift).astype(original_dtype) + + +class WanAttentionBlock(nnx.Module): + """ + Wan Diffusion Transformer Block. 
+ + Includes: + - Vision self-attention with AdaLN modulation + - Text-to-vision cross-attention + - Feed-forward network with AdaLN modulation + """ + + def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + + self.norm1 = WanLayerNorm(cfg.hidden_dim, rngs=rngs) + self.norm2 = WanLayerNorm(cfg.hidden_dim, rngs=rngs, use_scale=True, use_bias=True) + self.norm3 = WanLayerNorm(cfg.hidden_dim, rngs=rngs) + + self.self_attn = MultiHeadAttention(cfg, rngs=rngs) + self.cross_attn = CrossAttention(cfg, rngs=rngs) + + self.mlp = nnx.Sequential( + nnx.Linear(cfg.hidden_dim, cfg.ffn_dim, rngs=rngs, precision=Precision.HIGHEST), + nnx.gelu, + nnx.Linear(cfg.ffn_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST), + ) + + self.scale_shift_table = nnx.Param( + jax.random.normal(rngs.params(), (1, 6, cfg.hidden_dim)) / (cfg.hidden_dim**0.5) + ) + + @jax.named_scope("wan_attention_block") + def __call__( + self, + x: Array, + text_embeds: Array, + time_proj: Array, + rope_state: tuple | None = None, + deterministic: bool = True, + ) -> Array: + """ + Args: + x: [B, N, D] video tokens + text_embeds: [B, M, text_dim] text embeddings + time_proj: [B, 6*D] time embedding + rope_state: Optional tuple of (freqs, grid_sizes) for 3D RoPE + deterministic: Whether to apply dropout + Returns: + [B, N, D] transformed tokens + """ + # Get modulation from time embedding + b = time_proj.shape[0] + d = self.cfg.hidden_dim + reshaped_time_proj = time_proj.reshape(b, 6, d) + modulation = reshaped_time_proj.astype(jnp.float32) + self.scale_shift_table.value + modulation = modulation.reshape(b, -1) # [B, 6*D] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(modulation, 6, axis=-1) + + # Self-attention with AdaLN modulation and RoPE + norm_x = self.norm1(x) + norm_x = modulate(norm_x, shift_msa[:, None, :], scale_msa[:, None, :]) + attn_out = self.self_attn(norm_x, rope_state=rope_state, deterministic=deterministic) + x = (x.astype(jnp.float32) + gate_msa[:, None, :] * attn_out).astype(x.dtype) + + # Cross-attention + norm_x = self.norm2(x) + cross_out = self.cross_attn(norm_x, text_embeds, deterministic=deterministic) + x = x + cross_out + + # MLP with AdaLN modulation + norm_x = self.norm3(x) + norm_x = modulate(norm_x, shift_mlp[:, None, :], scale_mlp[:, None, :]) + mlp_out = self.mlp(norm_x) + x = (x.astype(jnp.float32) + gate_mlp[:, None, :] * mlp_out).astype(jnp.float32) + + return x + + +class FinalLayer(nnx.Module): + """Final layer that predicts noise from DiT output.""" + + def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + self.norm = WanLayerNorm(cfg.hidden_dim, rngs=rngs) + out_dim = math.prod(cfg.patch_size) * cfg.latent_output_dim # expand out_dim here for unpatchify + self.linear = nnx.Linear(cfg.hidden_dim, out_dim, rngs=rngs, precision=Precision.HIGHEST) + + self.scale_shift_table = nnx.Param( + jax.random.normal(rngs.params(), (1, 2, cfg.hidden_dim)) / (cfg.hidden_dim**0.5) + ) + + @jax.named_scope("final_layer") + def __call__(self, x: Array, time_emb: Array) -> Array: + """ + Args: + x: [B, N, D] DiT output + time_emb: [B, D] time embedding from TimestepEmbedding + Returns: + [B, N, latent_output_dim] predicted noise + """ + # [B, D] β†’ [B, 1, D] + [1, 2, D] β†’ [B, 2, D] + e = self.scale_shift_table.value + time_emb[:, None, :] + shift, scale = e[:, 0, :], e[:, 1, :] + + x = modulate(x, shift[:, None, :], scale[:, None, :]) + x = self.linear(x) + return x + + +class Wan2DiT(nnx.Module): + """ 
+ Wan2.1-T2V-1.3B Diffusion Transformer. + """ + + def __init__(self, cfg: TransformerWanModelConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + + # 3D Conv to patchify video latents + # (T, H, W) β†’ (T, H/2, W/2) + self.patch_embed = nnx.Conv( + in_features=cfg.latent_input_dim, + out_features=cfg.hidden_dim, + kernel_size=(1, 2, 2), + strides=(1, 2, 2), + padding="VALID", + use_bias=True, + rngs=rngs, + precision=Precision.HIGHEST, + ) + + # Text embedding projection: UMT5 (4096) β†’ DiT (1536) + self.text_proj = nnx.Sequential( + nnx.Linear(cfg.text_embed_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST), + nnx.gelu, + nnx.Linear(cfg.hidden_dim, cfg.hidden_dim, rngs=rngs, precision=Precision.HIGHEST), + ) + + self.time_embed = TimestepEmbedding(cfg, rngs=rngs) + + self.blocks = nnx.List([WanAttentionBlock(cfg, rngs=rngs) for _ in range(cfg.num_layers)]) + + self.final_layer = FinalLayer(cfg, rngs=rngs) + + @jax.named_scope("wan2_dit") + @jax.jit + def forward(self, latents: Array, text_embeds: Array, timestep: Array, deterministic: bool = True) -> Array: + """ + Forward pass of the Diffusion Transformer. + + Args: + latents: [B, T, H, W, C] noisy video latents from VAE + text_embeds: [B, seq_len, 4096] from UMT5-XXL encoder (before projection) + timestep: [B] diffusion timestep (0 to num_steps) + deterministic: Whether to apply dropout + + Returns: + predicted_noise: [B, T, H, W, C] predicted noise + """ + text_embeds = self.text_proj(text_embeds) + + # Get time embeddings + # time_emb: [B, D] for FinalLayer + # time_proj: [B, 6*D] for AdaLN in blocks + time_emb, time_proj = self.time_embed(timestep) + + x = self.patch_embed(latents) + b, t_out, h_out, w_out, d = x.shape + x = x.reshape(b, t_out * h_out * w_out, d) + + grid_sizes = (t_out, h_out, w_out) + + max_seq = max(grid_sizes) + rope_freqs = tuple( + jax.lax.stop_gradient(arr) for arr in precompute_freqs_cis_3d(dim=self.cfg.head_dim, max_seq_len=max_seq) + ) + + for block in self.blocks: + x = block(x, text_embeds, time_proj, rope_state=(rope_freqs, grid_sizes), deterministic=deterministic) + + # Final projection to noise space + x = self.final_layer(x, time_emb) # [B, T*H*W, latent_output_dim] + + # Reshape back to video format + predicted_noise = self.unpatchify(x, grid_sizes) + + return predicted_noise + + def unpatchify(self, x: Array, grid_sizes: tuple[int, int, int]) -> Array: + """ + Reconstruct video tensors from patch embeddings. 
+ + Args: + x: [B, T*H*W, patch_t*patch_h*patch_w*C] flattened patch embeddings + grid_sizes: (T_patches, H_patches, W_patches) grid dimensions + + Returns: + [B, T, H, W, C] reconstructed video tensor (channel-last) + """ + b, seq_len, feature_dim = x.shape + t_patches, h_patches, w_patches = grid_sizes + c = self.cfg.latent_output_dim + patch_size = self.cfg.patch_size + + assert seq_len == t_patches * h_patches * w_patches, ( + f"expected: seq_len={seq_len} should be {t_patches * h_patches * w_patches}" + ) + assert feature_dim == patch_size[0] * patch_size[1] * patch_size[2] * c, ( + f"expected: feature_dim={feature_dim} should be {patch_size[0] * patch_size[1] * patch_size[2] * c}" + ) + + x = x.reshape( + b, + t_patches, + h_patches, + w_patches, + patch_size[0], + patch_size[1], + patch_size[2], + c, + ) + x = jnp.einsum("bthwpqrc->btphqwrc", x) + x = x.reshape( + b, + t_patches * patch_size[0], + h_patches * patch_size[1], + w_patches * patch_size[2], + c, + ) + + return x + + +__all__ = [ + "TransformerWanModelConfig", + "Wan2DiT", +] diff --git a/bonsai/models/wan2/umt5.py b/bonsai/models/wan2/umt5.py new file mode 100644 index 00000000..97eaf55a --- /dev/null +++ b/bonsai/models/wan2/umt5.py @@ -0,0 +1,380 @@ +import dataclasses +import math +from typing import Optional + +import jax +import jax.numpy as jnp +from flax import nnx +from jax.lax import Precision +from jaxtyping import Array + + +def gelu(x: Array) -> Array: + return 0.5 * x * (1.0 + jnp.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * jnp.pow(x, 3.0)))) + + +class RMSNorm(nnx.Module): + def __init__(self, dim: int, eps: float = 1e-6, *, rngs: nnx.Rngs): + self.dim = dim + self.eps = eps + self.weight = nnx.Param(jnp.ones(dim)) + + def __call__(self, x: Array) -> Array: + # RMS normalization: x / sqrt(mean(x^2)) + variance = jnp.mean(x.astype(jnp.float32) ** 2, axis=-1, keepdims=True) + x = x * jax.lax.rsqrt(variance + self.eps) + return self.weight.value * x + + +class T5Attention(nnx.Module): + """UMT5 Multi-head attention.""" + + def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.1, *, rngs: nnx.Rngs): + assert dim_attn % num_heads == 0 + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # Linear projections + self.q = nnx.Linear(dim, dim_attn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs) + self.k = nnx.Linear(dim, dim_attn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs) + self.v = nnx.Linear(dim, dim_attn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs) + self.o = nnx.Linear(dim_attn, dim, use_bias=False, precision=Precision.HIGHEST, rngs=rngs) + self.dropout = nnx.Dropout(dropout, rngs=rngs) + + def __call__( + self, + x: Array, + context: Optional[Array] = None, + mask: Optional[Array] = None, + pos_bias: Optional[Array] = None, + deterministic: bool = True, + ) -> Array: + """ + Args: + x: [B, L1, C] query + context: [B, L2, C] key/value context, defaults to x for self-attention + mask: [B, L2] or [B, L1, L2] attention mask + pos_bias: [B, num_heads, L1, L2] position bias + """ + context = x if context is None else context + b = x.shape[0] + n = self.num_heads + c = self.head_dim + + q = self.q(x).reshape(b, -1, n, c) + k = self.k(context).reshape(b, -1, n, c) + v = self.v(context).reshape(b, -1, n, c) + + # Attention bias + attn_bias = jnp.zeros((b, n, q.shape[1], k.shape[1])) + if pos_bias is not None: + attn_bias = attn_bias + pos_bias + if mask is not None: + # Expand mask to attention 
shape + if mask.ndim == 2: + mask = mask[:, None, None, :] # [B, 1, 1, L2] + else: + mask = mask[:, None, :, :] # [B, 1, L1, L2] + attn_bias = jnp.where(mask == 0, jnp.finfo(x.dtype).min, attn_bias) + + attn = jnp.einsum("binc,bjnc->bnij", q, k) + attn_bias + attn = jax.nn.softmax(attn, axis=-1) + x = jnp.einsum("bnij,bjnc->binc", attn, v) + + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x, deterministic=deterministic) + return x + + +class T5FeedForward(nnx.Module): + """UMT5 Feed-forward network with gated activation.""" + + def __init__(self, dim: int, dim_ffn: int, dropout: float = 0.1, *, rngs: nnx.Rngs): + self.dim = dim + self.dim_ffn = dim_ffn + + self.gate = nnx.Linear(dim, dim_ffn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs) + self.fc1 = nnx.Linear(dim, dim_ffn, use_bias=False, precision=Precision.HIGHEST, rngs=rngs) + self.fc2 = nnx.Linear(dim_ffn, dim, use_bias=False, precision=Precision.HIGHEST, rngs=rngs) + self.dropout = nnx.Dropout(dropout, rngs=rngs) + + def __call__(self, x: Array, deterministic: bool = True) -> Array: + x = self.fc1(x) * gelu(self.gate(x)) + x = self.dropout(x, deterministic=deterministic) + x = self.fc2(x) + x = self.dropout(x, deterministic=deterministic) + return x + + +class T5RelativeEmbedding(nnx.Module): + def __init__(self, num_buckets: int, num_heads: int, bidirectional: bool, max_dist: int = 128, *, rngs: nnx.Rngs): + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + self.embedding = nnx.Embed(num_buckets, num_heads, rngs=rngs) + + def __call__(self, lq: int, lk: int) -> Array: + """Compute relative position bias. + + Args: + lq: Query sequence length + lk: Key sequence length + + Returns: + [1, num_heads, lq, lk] relative position bias + """ + q_pos = jnp.arange(lq)[:, None] + k_pos = jnp.arange(lk)[None, :] + rel_pos = k_pos - q_pos + + rel_pos_buckets = self._relative_position_bucket(rel_pos) + + # Get embeddings + rel_pos_embeds = self.embedding(rel_pos_buckets) # [lq, lk, num_heads] + rel_pos_embeds = rel_pos_embeds.transpose(2, 0, 1)[None, :, :, :] # [1, num_heads, lq, lk] + return rel_pos_embeds + + def _relative_position_bucket(self, rel_pos: Array) -> Array: + """Convert relative positions to bucket indices.""" + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).astype(jnp.int32) * num_buckets + rel_pos = jnp.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -jnp.minimum(rel_pos, jnp.zeros_like(rel_pos)) + + # Small vs large positions + max_exact = num_buckets // 2 + is_small = rel_pos < max_exact + + # Logarithmic bucketing for large positions + rel_pos_large = max_exact + ( + jnp.log(rel_pos.astype(jnp.float32) / max_exact) + / math.log(self.max_dist / max_exact) + * (num_buckets - max_exact) + ).astype(jnp.int32) + rel_pos_large = jnp.minimum(rel_pos_large, num_buckets - 1) + + rel_buckets = rel_buckets + jnp.where(is_small, rel_pos, rel_pos_large) + return rel_buckets + + +class T5SelfAttention(nnx.Module): + """UMT5 Self-attention block with feed-forward.""" + + def __init__( + self, + dim: int, + dim_attn: int, + dim_ffn: int, + num_heads: int, + num_buckets: int, + shared_pos: bool = True, + dropout: float = 0.1, + *, + rngs: nnx.Rngs, + ): + self.shared_pos = shared_pos + + self.norm1 = RMSNorm(dim, rngs=rngs) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout, rngs=rngs) + self.norm2 = RMSNorm(dim, rngs=rngs) + self.ffn = T5FeedForward(dim, 
dim_ffn, dropout, rngs=rngs) + + if not shared_pos: + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, rngs=rngs) + else: + self.pos_embedding = None + + def __call__( + self, x: Array, mask: Optional[Array] = None, pos_bias: Optional[Array] = None, deterministic: bool = True + ) -> Array: + # Get position bias + if self.shared_pos: + e = pos_bias + else: + e = self.pos_embedding(x.shape[1], x.shape[1]) + + x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e, deterministic=deterministic) + x = x + self.ffn(self.norm2(x), deterministic=deterministic) + return x + + +class T5Encoder(nnx.Module): + def __init__( + self, + vocab_size: int, + dim: int, + dim_attn: int, + dim_ffn: int, + num_heads: int, + num_layers: int, + num_buckets: int, + shared_pos: bool = True, + dropout: float = 0.1, + *, + rngs: nnx.Rngs, + ): + self.dim = dim + self.shared_pos = shared_pos + + self.token_embedding = nnx.Embed(vocab_size, dim, rngs=rngs) + if shared_pos: + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, rngs=rngs) + else: + self.pos_embedding = None + self.dropout = nnx.Dropout(dropout, rngs=rngs) + self.blocks = nnx.List( + [ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, rngs=rngs) + for _ in range(num_layers) + ] + ) + self.norm = RMSNorm(dim, rngs=rngs) + + def __call__(self, ids: Array, mask: Optional[Array] = None, deterministic: bool = True) -> Array: + x = self.token_embedding(ids) + x = self.dropout(x, deterministic=deterministic) + + # Compute shared position bias if needed + e = self.pos_embedding(x.shape[1], x.shape[1]) if self.shared_pos else None + + # Apply transformer blocks + for block in self.blocks: + x = block(x, mask, pos_bias=e, deterministic=deterministic) + + x = self.norm(x) + x = self.dropout(x, deterministic=deterministic) + return x + + +@dataclasses.dataclass(frozen=True) +class T5Config: + """Configuration for UMT5 Encoder.""" + + vocab_size: int = 256384 + dim: int = 4096 + dim_attn: int = 4096 + dim_ffn: int = 10240 + num_heads: int = 64 + num_layers: int = 24 + num_buckets: int = 32 + shared_pos: bool = False # UMT5 uses per-layer position embeddings + dropout: float = 0.1 + + @classmethod + def umt5_xxl(cls) -> "T5Config": + """UMT5-XXL configuration (~5B parameters).""" + return cls( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + num_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1, + ) + + @classmethod + def umt5_base(cls) -> "T5Config": + """UMT5-Base configuration (~580M parameters).""" + return cls( + vocab_size=256384, + dim=768, + dim_attn=768, + dim_ffn=2048, + num_heads=12, + num_layers=12, + num_buckets=32, + shared_pos=False, + dropout=0.1, + ) + + +class T5EncoderModel(nnx.Module): + """UMT5 Encoder-only model for text encoding. + + Supports multiple UMT5 configurations (UMT5-XXL, UMT5-Base). + """ + + def __init__(self, config: T5Config, *, rngs: nnx.Rngs): + """Initialize UMT5 encoder from config. 
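As a quick usage sketch for the encoder classes in this file (illustrative only, not part of the diff; it assumes `T5EncoderModel` is importable from this module and uses dummy token ids instead of a real tokenizer):

```python
import jax.numpy as jnp
from flax import nnx

# Build the ~580M-parameter UMT5-Base variant; umt5_xxl() works the same way, just larger.
model = T5EncoderModel.umt5_base(rngs=nnx.Rngs(0))

input_ids = jnp.ones((2, 16), dtype=jnp.int32)       # [B, L] dummy token ids
attention_mask = jnp.ones((2, 16), dtype=jnp.int32)  # 1 = valid token, 0 = padding

embeddings = model(input_ids, attention_mask=attention_mask, deterministic=True)
print(embeddings.shape)  # (2, 16, 768) for the base config
```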
+ + Args: + config: T5Config specifying model architecture + rngs: Random number generators for initialization + """ + self.config = config + self.encoder = T5Encoder( + vocab_size=config.vocab_size, + dim=config.dim, + dim_attn=config.dim_attn, + dim_ffn=config.dim_ffn, + num_heads=config.num_heads, + num_layers=config.num_layers, + num_buckets=config.num_buckets, + shared_pos=config.shared_pos, + dropout=config.dropout, + rngs=rngs, + ) + + @classmethod + def from_config(cls, config: T5Config, *, rngs: nnx.Rngs) -> "T5EncoderModel": + """Create UMT5 encoder from configuration. + + Args: + config: T5Config instance + rngs: Random number generators + + Returns: + T5EncoderModel instance + """ + return cls(config, rngs=rngs) + + @classmethod + def umt5_xxl(cls, *, rngs: nnx.Rngs) -> "T5EncoderModel": + """Create UMT5-XXL encoder (~5B parameters). + + Args: + rngs: Random number generators + + Returns: + T5EncoderModel configured as UMT5-XXL + """ + return cls(T5Config.umt5_xxl(), rngs=rngs) + + @classmethod + def umt5_base(cls, *, rngs: nnx.Rngs) -> "T5EncoderModel": + """Create UMT5-Base encoder (~580M parameters). + + Args: + rngs: Random number generators + + Returns: + T5EncoderModel configured as UMT5-Base + """ + return cls(T5Config.umt5_base(), rngs=rngs) + + def __call__(self, input_ids: Array, attention_mask: Optional[Array] = None, deterministic: bool = True) -> Array: + """Encode text. + + Args: + input_ids: [B, L] token IDs + attention_mask: [B, L] attention mask (1 for valid tokens, 0 for padding) + deterministic: whether to disable dropout (True for inference) + + Returns: + [B, L, dim] encoded text embeddings (dim depends on config) + """ + return self.encoder(input_ids, mask=attention_mask, deterministic=deterministic) + + +__all__ = ["T5Config", "T5Encoder", "T5EncoderModel"] diff --git a/bonsai/models/wan2/unipc_multistep_scheduler.py b/bonsai/models/wan2/unipc_multistep_scheduler.py new file mode 100644 index 00000000..0a42076a --- /dev/null +++ b/bonsai/models/wan2/unipc_multistep_scheduler.py @@ -0,0 +1,811 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: reference pytorch implementation: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_unipc_multistep.py + +from functools import partial +from types import SimpleNamespace +from typing import List, Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp + +from .scheduling_utils import ( + CommonSchedulerState, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + add_noise_common, +) + + +@flax.struct.dataclass +class UniPCMultistepSchedulerState: + """ + Data class to hold the mutable state of the FlaxUniPCMultistepScheduler. 
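One thing worth keeping in mind while reading the scheduler below: `flax.struct.dataclass` instances are immutable pytrees, so the state is never mutated in place; every update goes through `.replace(...)` and returns a new state object. A generic illustration of that pattern (not taken from the diff):

```python
import flax
import jax.numpy as jnp

@flax.struct.dataclass
class Counter:
    value: jnp.ndarray

c0 = Counter(value=jnp.array(0))
c1 = c0.replace(value=c0.value + 1)  # returns a new instance; c0 is left unchanged
print(c0.value, c1.value)            # 0 1
```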
+    """
+
+    common: CommonSchedulerState
+
+    # Core schedule parameters (derived from CommonSchedulerState in create_state)
+    sigmas: jnp.ndarray
+    alpha_t: jnp.ndarray
+    sigma_t: jnp.ndarray
+    lambda_t: jnp.ndarray
+    init_noise_sigma: float
+
+    # History buffers for multi-step solver
+    # `model_outputs` stores previous converted model outputs (e.g., predicted x0 or epsilon)
+    timesteps: jnp.ndarray = None
+    model_outputs: jnp.ndarray = None
+    timestep_list: jnp.ndarray = None  # Stores corresponding timesteps for `model_outputs`
+
+    # State variables for tracking progress and solver order
+    lower_order_nums: int = 0
+    last_sample: Optional[jnp.ndarray] = None  # Sample from the previous predictor step
+    step_index: Optional[int] = None
+    begin_index: Optional[int] = None  # Used for img2img/inpainting
+    this_order: int = 0  # Current effective order of the UniPC solver for this step
+
+    @classmethod
+    def create(
+        cls,
+        common_state: CommonSchedulerState,
+        alpha_t: jnp.ndarray,
+        sigma_t: jnp.ndarray,
+        lambda_t: jnp.ndarray,
+        sigmas: jnp.ndarray,
+        init_noise_sigma: jnp.ndarray,
+    ):
+        return cls(
+            common=common_state,
+            alpha_t=alpha_t,
+            sigma_t=sigma_t,
+            lambda_t=lambda_t,
+            sigmas=sigmas,
+            init_noise_sigma=init_noise_sigma,
+            lower_order_nums=0,
+            last_sample=None,
+            step_index=None,
+            begin_index=None,
+            this_order=0,
+        )
+
+
+class FlaxUniPCMultistepScheduler(FlaxSchedulerMixin):
+    """
+    `FlaxUniPCMultistepScheduler` is a JAX/Flax training-free framework designed for the fast sampling of diffusion models.
+    It implements the UniPC (Unified Predictor-Corrector) algorithm for efficient diffusion model sampling.
+    """
+
+    dtype: jnp.dtype
+
+    @property
+    def has_state(self) -> bool:
+        return True
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[jnp.ndarray, List[float]]] = None,
+        solver_order: int = 2,
+        prediction_type: str = "epsilon",
+        thresholding: bool = False,
+        dynamic_thresholding_ratio: float = 0.995,
+        sample_max_value: float = 1.0,
+        predict_x0: bool = True,
+        solver_type: str = "bh2",
+        lower_order_final: bool = True,
+        disable_corrector: List[int] = [],
+        solver_p: Optional[FlaxSchedulerMixin] = None,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
+        timestep_spacing: str = "linspace",
+        steps_offset: int = 0,
+        final_sigmas_type: Optional[str] = "zero",
+        rescale_zero_terminal_snr: bool = False,
+        dtype: jnp.dtype = jnp.float32,
+    ):
+        params = locals().copy()
+        params.pop("self")  # drop `self` from the captured arguments
+        self.dtype = params.pop("dtype")  # keep `dtype` out of the config namespace
+
+        self.config = SimpleNamespace(**params)
+
+        # # Validation checks from original __init__
+        # if self.config.use_beta_sigmas and not is_scipy_available():
+        #     raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if (
+            sum(
+                [
+                    self.config.use_beta_sigmas,
+                    self.config.use_exponential_sigmas,
+                    self.config.use_karras_sigmas,
+                ]
+            )
+            > 1
+        ):
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+ ) + if self.config.solver_type not in ["bh1", "bh2"]: + raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}") + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> UniPCMultistepSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + if not self.config.rescale_zero_terminal_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + alphas_cumprod = common.alphas_cumprod + alphas_cumprod = alphas_cumprod.at[-1].set(2**-24) + common = common.replace(alphas_cumprod=alphas_cumprod) + + # Currently we only support VP-type noise schedule + alpha_t = jnp.sqrt(common.alphas_cumprod) + sigma_t = jnp.sqrt(1 - common.alphas_cumprod) + lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) + sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + if self.config.solver_type not in ["bh1", "bh2"]: + if self.config.solver_type in ["midpoint", "heun", "logrho"]: + self.config.solver_type = "bh2" + else: + raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}") + + return UniPCMultistepSchedulerState.create( + common_state=common, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + sigmas=sigmas, + init_noise_sigma=init_noise_sigma, + ) + + def set_begin_index( + self, state: UniPCMultistepSchedulerState, begin_index: int = 0 + ) -> UniPCMultistepSchedulerState: + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + """ + return state.replace(begin_index=begin_index) + + def set_timesteps( + self, + state: UniPCMultistepSchedulerState, + num_inference_steps: int, + shape: Tuple, + ) -> UniPCMultistepSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + """ + #### Copied from scheduling_dmpsolver_multistep_flax + last_timestep = self.config.num_train_timesteps + if self.config.timestep_spacing == "linspace": + timesteps = jnp.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (jnp.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(jnp.int32) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) + + # initial running values + sigmas = state.sigmas + + # TODO + # # Apply Karras/Exponential/Beta/Flow Sigmas if configured + if self.config.use_karras_sigmas: + # sigmas = _convert_to_karras_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError("`use_karras_sigmas` is not implemented in JAX version yet.") + elif self.config.use_exponential_sigmas: + # sigmas = _convert_to_exponential_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError("`use_exponential_sigmas` is not implemented in JAX version yet.") + elif self.config.use_beta_sigmas: + # sigmas = _convert_to_beta_jax(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # timesteps = jnp.array([_sigma_to_t_jax(s, log_sigmas_full) for s in sigmas]).round().astype(jnp.int64) + raise NotImplementedError("`use_beta_sigmas` is not implemented in JAX version yet.") + if self.config.use_flow_sigmas: + alphas = jnp.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) + sigmas = 1.0 - alphas + sigmas = jnp.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() + timesteps = (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int64) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype(jnp.float32) + else: # Default case if none of the specialized sigmas are used + sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - state.common.alphas_cumprod[0]) / state.common.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = jnp.concatenate([sigmas, jnp.array([sigma_last])]).astype(jnp.float32) + + model_outputs = jnp.zeros((self.config.solver_order, *shape), dtype=self.dtype) + timestep_list = jnp.zeros((self.config.solver_order,), dtype=jnp.int32) # Timesteps are integers + # Update the state with the new schedule and re-initialized history + return state.replace( + timesteps=timesteps, + sigmas=sigmas, + model_outputs=model_outputs, + timestep_list=timestep_list, + lower_order_nums=0, # Reset counters for a new inference run + step_index=None, + begin_index=None, + last_sample=None, + this_order=0, + ) + + def convert_model_output( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + Converts the model output based on the prediction type and current state. 
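For orientation, the conversions implemented in this method amount to the following relations, with Οƒ the current entry of `sigmas` and (Ξ±_t, Οƒ_t) produced by `_sigma_to_alpha_sigma_t` further below (this is a restatement of the code, not an additional formula):

```latex
\alpha_t = \frac{1}{\sqrt{\sigma^2 + 1}}, \qquad \sigma_t = \sigma\,\alpha_t
\qquad\bigl(\text{or } \alpha_t = 1 - \sigma,\ \sigma_t = \sigma \text{ when } \texttt{use\_flow\_sigmas}\bigr)

\text{epsilon: } x_0 = \frac{x - \sigma_t\,\varepsilon_\theta}{\alpha_t}, \qquad
\text{sample: } x_0 = m_\theta, \qquad
\text{v\_prediction: } x_0 = \alpha_t\,x - \sigma_t\,v_\theta, \qquad
\text{flow\_prediction: } x_0 = x - \sigma\,u_\theta
```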
+ """ + sigma = state.sigmas[state.step_index] # Current sigma + + # Ensure sigma is a JAX array for _sigma_to_alpha_sigma_t + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.config.predict_x0: + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + elif self.config.prediction_type == "flow_prediction": + # Original code has `sigma_t = self.sigmas[self.step_index]`. + # This implies current sigma `sigma` is used as sigma_t for flow. + x0_pred = sample - sigma * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + raise NotImplementedError("Dynamic thresholding isn't implemented.") + # x0_pred = self._threshold_sample(x0_pred) + return x0_pred + else: # self.config.predict_x0 is False + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + order: int, + ) -> jnp.ndarray: + """ + One step for the UniP (B(h) version) - the Predictor. + """ + if self.config.solver_p: + raise NotImplementedError("Nested `solver_p` is not implemented in JAX version yet.") + + m0 = state.model_outputs[self.config.solver_order - 1] # Most recent stored converted model output + x = sample + + sigma_t_val, sigma_s0_val = ( + state.sigmas[state.step_index + 1], + state.sigmas[state.step_index], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) + + lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) + lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) + + h = lambda_t - lambda_s0 + + def rk_d1_loop_body(i, carry): + # Loop from i = 0 to order-2 + rks, D1s = carry + history_idx = self.config.solver_order - 2 - i + mi = state.model_outputs[history_idx] + si_val = state.timestep_list[history_idx] + + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)]) + lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) + + rk = (lambda_si - lambda_s0) / h + Di = (mi - m0) / rk + + rks = rks.at[i].set(rk) + D1s = D1s.at[i].set(Di) + return rks, D1s + + rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) + if self.config.solver_order == 1: + # Dummy D1s array. 
It will not be used if order == 1 + D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) + rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init)) + rks = rks.at[order - 1].set(1.0) + + hh = -h if self.config.predict_x0 else h + h_phi_1 = jnp.expm1(hh) + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = jnp.expm1(hh) + else: + raise NotImplementedError() + + def rb_loop_body(i, carry): + R, b, current_h_phi_k, factorial_val = carry + R = R.at[i].set(jnp.power(rks, i)) + b = b.at[i].set(current_h_phi_k * factorial_val / B_h) + + def update_fn(vals): + _h_phi_k, _fac = vals + next_fac = _fac * (i + 2) + next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac + return next_h_phi_k, next_fac + + current_h_phi_k, factorial_val = jax.lax.cond( + i < order - 1, + update_fn, + lambda vals: vals, + (current_h_phi_k, factorial_val), + ) + return R, b, current_h_phi_k, factorial_val + + R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype) + b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + init_h_phi_k = h_phi_1 / hh - 1.0 + init_factorial = 1.0 + R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial)) + + if len(D1s) > 0: + D1s = jnp.stack(D1s, axis=1) # Resulting shape (B, K, C, H, W) + + def solve_for_rhos_p(R_mat, b_vec, current_order): + # Create a mask for the top-left (current_order - 1) x (current_order - 1) sub-matrix + mask_size = self.config.solver_order - 1 + mask = jnp.arange(mask_size) < (current_order - 1) + mask_2d = mask[:, None] & mask[None, :] + + # Pad R with identity and b with zeros for a safe solve + R_safe = jnp.where( + mask_2d, + R_mat[:mask_size, :mask_size], + jnp.eye(mask_size, dtype=R_mat.dtype), + ) + b_safe = jnp.where(mask, b_vec[:mask_size], 0.0) + + # Solve the system and mask the result + solved_rhos = jnp.linalg.solve(R_safe, b_safe) + return jnp.where(mask, solved_rhos, 0.0) + + # Handle the special case for order == 2 + if self.config.solver_order == 1: + # Dummy rhos_p_padded for tracing. + rhos_p_order2 = jnp.zeros(1, dtype=x.dtype) + else: + rhos_p_order2 = jnp.zeros(self.config.solver_order - 1, dtype=x.dtype).at[0].set(0.5) + + # Get the result for the general case + rhos_p_general = solve_for_rhos_p(R, b, order) + + # Select the appropriate result based on the order + rhos_p = jnp.where(order == 2, rhos_p_order2, rhos_p_general) + + pred_res = jax.lax.cond( + order > 1, + lambda _: jnp.einsum("k,bkc...->bc...", rhos_p, D1s).astype(x.dtype), + # False branch: return a zero tensor with the correct shape. + lambda _: jnp.zeros_like(x), + operand=None, + ) + + if self.config.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: # Predict epsilon + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + x_t = x_t_ - sigma_t * B_h * pred_res + + return x_t.astype(x.dtype) + + def multistep_uni_c_bh_update( + self, + state: UniPCMultistepSchedulerState, + this_model_output: jnp.ndarray, + last_sample: jnp.ndarray, # Sample after predictor `x_{t-1}` + this_sample: jnp.ndarray, # Sample before corrector `x_t` (after predictor step) + order: int, + ) -> jnp.ndarray: + """ + One step for the UniC (B(h) version) - the Corrector. 
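For orientation: the predictor defined above and this corrector share the same B(h) machinery. With h = Ξ»_t βˆ’ Ξ»_{s0} and hh = βˆ’h (the x0-prediction branch), the predictor update computed above is, restated in equation form:

```latex
x_t \;=\; \frac{\sigma_t}{\sigma_{s_0}}\, x
\;-\; \alpha_t\,\bigl(e^{hh} - 1\bigr)\, m_0
\;-\; \alpha_t\, B(h) \sum_k \rho_k D_k,
\qquad
B(h) = \begin{cases} hh & \text{bh1} \\ e^{hh} - 1 & \text{bh2} \end{cases}
```

The corrector below reuses the same quantities but additionally folds in the fresh model output at the predicted sample through the `final_rho * D1_t` term, where `D1_t = model_t - m0`.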
+ """ + model_output_list = state.model_outputs + m0 = model_output_list[self.config.solver_order - 1] # Most recent model output from history + + if last_sample is not None: + x = last_sample + else: + # If it's None, create dummy data. This is for the tracing purpose + x = jnp.zeros_like(this_sample) + + x_t = this_sample + + model_t = this_model_output + + sigma_t_val = state.sigmas[state.step_index] + sigma_s0_val = state.sigmas[state.step_index - 1] + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val) + + lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10) + lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10) + + h = lambda_t - lambda_s0 + + def rk_d1_loop_body(i, carry): + # Loop from i = 0 to order-1. + rks, D1s = carry + + # Get history from state buffer + history_idx = self.config.solver_order - (i + 2) + mi = state.model_outputs[history_idx] + si_val = state.timestep_list[history_idx] # This is the actual timestep value + + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)]) + lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10) + + rk = (lambda_si - lambda_s0) / h + Di = (mi - m0) / rk + + # Update pre-allocated arrays + rks = rks.at[i].set(rk) + D1s = D1s.at[i].set(Di) + return rks, D1s + + # Pre-allocate arrays to max possible size + rks_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + D1s_init = jnp.zeros((self.config.solver_order - 1, *m0.shape), dtype=m0.dtype) + if self.config.solver_order == 1: + # Dummy D1s array. It will not be used if order == 1. This is for tracing. + D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype) + + # Run the loop up to `order - 1` + rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init)) + + rks = rks.at[order - 1].set(1.0) + + hh = -h if self.config.predict_x0 else h + h_phi_1 = jnp.expm1(hh) + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = jnp.expm1(hh) + else: + raise NotImplementedError() + + def rb_loop_body(i, carry): + # Loop from i = 0 to order-1 + R, b, current_h_phi_k, factorial_val = carry + + R = R.at[i].set(jnp.power(rks, i)) + b = b.at[i].set(current_h_phi_k * factorial_val / B_h) + + # Conditionally update phi_k and factorial for the next iteration + def update_fn(vals): + # This branch is taken if i < order - 1 + _h_phi_k, _fac = vals + next_fac = _fac * (i + 2) + next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac + return next_h_phi_k, next_fac + + current_h_phi_k, factorial_val = jax.lax.cond( + i < order - 1, + update_fn, # If true, update values + lambda vals: vals, # If false, pass through + (current_h_phi_k, factorial_val), + ) + return R, b, current_h_phi_k, factorial_val + + # Pre-allocate R and b to max size + R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype) + b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype) + + # Initialize loop carriers + init_h_phi_k = h_phi_1 / hh - 1.0 + init_factorial = 1.0 + + R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial)) + + if len(D1s) > 0: + D1s = jnp.stack(D1s, axis=1) # (B, K, C, H, W) + + def solve_for_rhos(R_mat, b_vec, current_order): + # Create a mask to select the first `current_order` elements + mask = jnp.arange(self.config.solver_order) < current_order + mask_2d = mask[:, None] & mask[None, :] + + # Pad R with identity and b with 
zeros to create a safe, full-sized system + R_safe = jnp.where(mask_2d, R_mat, jnp.eye(self.config.solver_order, dtype=R_mat.dtype)) + b_safe = jnp.where(mask, b_vec, 0.0) + + # Solve the full-size system and mask the result + solved_rhos = jnp.linalg.solve(R_safe, b_safe) + return jnp.where(mask, solved_rhos, 0.0) + + rhos_c_order1 = jnp.zeros(self.config.solver_order, dtype=x_t.dtype).at[0].set(0.5) + rhos_c_general = solve_for_rhos(R, b, order) + rhos_c = jnp.where(order == 1, rhos_c_order1, rhos_c_general) + + D1_t = model_t - m0 + + corr_res = jax.lax.cond( + order > 1, + lambda _: (jnp.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)), + lambda _: jnp.zeros_like(D1_t), + operand=None, + ) + + final_rho = jnp.dot( + rhos_c, + jax.nn.one_hot(order - 1, self.config.solver_order, dtype=rhos_c.dtype), + ) + + if self.config.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + final_rho * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + final_rho * D1_t) + + return x_t.astype(x.dtype) + + def index_for_timestep( + self, + state: UniPCMultistepSchedulerState, + timestep: Union[int, jnp.ndarray], + schedule_timesteps: Optional[jnp.ndarray] = None, + ) -> int: + """ "Gets the step_index for timestep.""" + if schedule_timesteps is None: + schedule_timesteps = state.timesteps + + # QUINN!! + # timestep_val = ( + # timestep.item() + # if isinstance(timestep, jnp.ndarray) and timestep.ndim == 0 + # else timestep + # ) + timestep_val = timestep + + index_candidates = jnp.where(schedule_timesteps == timestep_val, size=1, fill_value=-1)[0] + + step_index = jnp.where( + index_candidates[0] == -1, # No match found + len(schedule_timesteps) - 1, # Default to last index + index_candidates[0], + ) + return step_index + + def _init_step_index( + self, state: UniPCMultistepSchedulerState, timestep: Union[int, jnp.ndarray] + ) -> UniPCMultistepSchedulerState: + """Initializes the step_index counter for the scheduler.""" + if state.begin_index is None: + step_index_val = self.index_for_timestep(state, timestep) + return state.replace(step_index=step_index_val) + else: + return state.replace(step_index=state.begin_index) + + @partial(jax.jit, static_argnums=(0, 5)) # self is static_argnum=0 + def step( + self, + state: UniPCMultistepSchedulerState, + model_output: jnp.ndarray, # This is the direct output from the diffusion model (e.g., noise prediction) + timestep: Union[int, jnp.ndarray], # Current discrete timestep from the scheduler's sequence + sample: jnp.ndarray, # Current noisy sample (latent) + return_dict: bool = True, + generator: Optional[jax.random.PRNGKey] = None, # JAX random key + ): + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. 
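A minimal sketch of how a pipeline is expected to drive `step` (illustrative only; `denoise_fn`, the latent shape, and the constructor arguments shown are placeholders, not values taken from this diff):

```python
import jax.numpy as jnp

scheduler = FlaxUniPCMultistepScheduler(
    solver_order=2,
    prediction_type="flow_prediction",  # placeholder choice; "epsilon" is the default
    use_flow_sigmas=True,
    flow_shift=5.0,                     # placeholder value
)
state = scheduler.create_state()

latents = jnp.zeros((1, 21, 104, 60, 16))  # placeholder latent shape
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=latents.shape)

denoise_fn = lambda x, t: jnp.zeros_like(x)  # stand-in for the real diffusion model

for t in state.timesteps:
    model_output = denoise_fn(latents, t)
    latents, state = scheduler.step(state, model_output, t, latents)
```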
+ """ + + sample = sample.astype(jnp.float32) + + if state.timesteps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep_scalar = jnp.array(timestep) + + # Initialize step_index if it's the first step + if state.step_index is None: + state = self._init_step_index(state, timestep_scalar) + + # Determine if corrector should be used + use_corrector = ( + (state.step_index > 0) + & (~jnp.isin(state.step_index - 1, jnp.array(self.config.disable_corrector))) + & (state.last_sample is not None) + ) + + # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type + model_output_for_history = self.convert_model_output(state, model_output, sample) + + # Apply corrector if applicable + sample = jax.lax.cond( + use_corrector, + lambda: self.multistep_uni_c_bh_update( + state=state, + this_model_output=model_output_for_history, + last_sample=state.last_sample, + this_sample=sample, + order=state.this_order, + ), + lambda: sample, + ) + + # Update history buffers (model_outputs and timestep_list) + # Shift existing elements to the left and add new one at the end. + # `state.model_outputs` and `state.timestep_list` are fixed-size arrays. + # Example: + # t0:[None,...,model_output0] + # t1:[None,..model_output0,model_output1] + # ... + # tn:[model_output0,model_output1,...,model_output_n] + def step_idx0_branch(): + updated_model_outputs_history = state.model_outputs.at[-1].set(model_output_for_history) + updated_timestep_list_history = state.timestep_list.at[-1].set(timestep_scalar) + return updated_model_outputs_history, updated_timestep_list_history + + def non_step_idx0_branch(): + updated_model_outputs_history = jnp.roll(state.model_outputs, shift=-1, axis=0) + updated_model_outputs_history = updated_model_outputs_history.at[-1].set(model_output_for_history) + + updated_timestep_list_history = jnp.roll(state.timestep_list, shift=-1) + updated_timestep_list_history = updated_timestep_list_history.at[-1].set(timestep_scalar) + return updated_model_outputs_history, updated_timestep_list_history + + updated_model_outputs_history, updated_timestep_list_history = jax.lax.cond( + state.step_index == 0, step_idx0_branch, non_step_idx0_branch + ) + state = state.replace( + model_outputs=updated_model_outputs_history, + timestep_list=updated_timestep_list_history, + ) + + # Determine the order for the current step (warmup phase logic) + this_order = jnp.where( + self.config.lower_order_final, + jnp.minimum(self.config.solver_order, len(state.timesteps) - state.step_index), + self.config.solver_order, + ) + + # Warmup for multistep: `this_order` can't exceed `lower_order_nums + 1` + new_this_order = jnp.minimum(this_order, state.lower_order_nums + 1) + state = state.replace(this_order=new_this_order) + + # Store current sample as `last_sample` for the *next* step's corrector + state = state.replace(last_sample=sample) + + # UniP predictor step + prev_sample = self.multistep_uni_p_bh_update( + state=state, + model_output=model_output, + sample=sample, + order=state.this_order, + ) + + # Update lower_order_nums for warmup + new_lower_order_nums = jnp.where( + state.lower_order_nums < self.config.solver_order, + state.lower_order_nums + 1, + state.lower_order_nums, + ) + state = state.replace(lower_order_nums=new_lower_order_nums) + # Upon completion, increase step index by one + state = state.replace(step_index=state.step_index + 1) + + # Return the updated sample and state + if not 
return_dict: + return (prev_sample, state) + + return prev_sample, state + + def scale_model_input( + self, state: UniPCMultistepSchedulerState, sample: jnp.ndarray, *args, **kwargs + ) -> jnp.ndarray: + """ + UniPC does not scale model input, so it returns the sample unchanged. + """ + return sample + + def add_noise( + self, + state: UniPCMultistepSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + return add_noise_common(state.common, original_samples, noise, timesteps) + + def _sigma_to_alpha_sigma_t(self, sigma): + if self.config.use_flow_sigmas: + alpha_t = 1 - sigma + sigma_t = sigma + else: + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + def __len__(self) -> int: + return self.config.num_train_timesteps diff --git a/bonsai/models/wan2/vae_wan.py b/bonsai/models/wan2/vae_wan.py new file mode 100644 index 00000000..d1d9ba8b --- /dev/null +++ b/bonsai/models/wan2/vae_wan.py @@ -0,0 +1,738 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wan-VAE: Video Variational Autoencoder for Wan2.1-T2V-1.3B. + +This module provides a JAX/Flax implementation of the Wan-VAE decoder, +which converts latent representations to RGB video frames. + +Architecture (based on reference implementation): +- Denormalization with learned mean/std +- Frame-by-frame decoding with temporal upsampling +- CausalConv3d for temporal coherence +- Spatial upsampling from 60x60 to 832x480 +- Temporal upsampling from 21 to 81 frames +""" + +from dataclasses import dataclass +from typing import Tuple + +import imageio +import jax +import jax.numpy as jnp +from flax import nnx +from jax.lax import Precision +from jaxtyping import Array, Union + +CACHE_T = 2 + + +@dataclass +class VAEConfig: + """Configuration for Wan-VAE decoder. + + Latent denormalization constants from reference implementation. + These are fixed constants computed during VAE training. + """ + + latent_mean: Tuple[float, ...] = ( + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ) + + latent_std: Tuple[float, ...] 
= ( + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ) + + +class CausalConv3d(nnx.Module): + """Causal 3D convolution that doesn't look into the future.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int, int] = (3, 3, 3), + *, + rngs: nnx.Rngs, + padding: Tuple[int, int, int] = (0, 0, 0), + ): + self.kernel_size = kernel_size + self.temporal_padding = padding[0] # Save for cache size calculation + self.conv = nnx.Conv( + in_features=in_channels, + out_features=out_channels, + kernel_size=kernel_size, + padding="VALID", # We'll handle padding manually + rngs=rngs, + precision=Precision.HIGHEST, + ) + self.padding = ( + (0, 0), + (2 * padding[0], 0), + (padding[1], padding[1]), + (padding[2], padding[2]), + (0, 0), + ) + + def __call__(self, x: Array, cache: Array | None = None) -> tuple[Array, Array | None]: + """Forward pass with optional caching. + + Args: + x: [B, T, H, W, C] input (JAX channel-last format) + cache: [B, CACHE_T, H, W, C] cached frames from previous call, or None + + Returns: + out: [B, T_out, H_out, W_out, C_out] output + new_cache: [B, CACHE_T, H, W, C] cache for next call, or None + """ + # Cache size is 2*padding because we pad left by (2*padding, 0) for causality + cache_t = 2 * self.temporal_padding + if cache is not None and cache_t > 0: + x = jnp.concatenate([cache, x], axis=1) # [B, T+CACHE_T, H, W, C] + if self.conv.in_features != self.conv.out_features: + jax.debug.print("CausalConv3d with cache input has nan:{}", jnp.isnan(x).any(), ordered=True) + jax.debug.print("feat cache in causalconv3d:{},{}", cache.shape, x.shape, ordered=True) + padding = list(self.padding) + padding[1] = (max(0, self.padding[1][0] - cache.shape[1]), 0) # Reduce left padding + padding = tuple(padding) + else: + padding = self.padding + + x_padded = jnp.pad(x, padding, mode="constant") + if self.conv.in_features != self.conv.out_features: + jax.debug.print("CausalConv3d input has nan:{}", jnp.isnan(x_padded).any(), ordered=True) + out = self.conv(x_padded) + if self.conv.in_features != self.conv.out_features: + jax.debug.print("CausalConv3d output has nan:{}", jnp.isnan(out).any(), ordered=True) + + # Extract cache for next iteration: last cache_t frames of INPUT (before conv) + # Always create cache if we have temporal padding (even on first frame) + if cache_t > 0: + new_cache = x[:, -cache_t:, :, :, :] # [B, <=CACHE_T, H, W, C] + # Pad on the left if we do not yet have cache_t frames (e.g., first call with T=1). + if new_cache.shape[1] < cache_t: + pad_t = cache_t - new_cache.shape[1] + new_cache = jnp.pad(new_cache, ((0, 0), (pad_t, 0), (0, 0), (0, 0), (0, 0)), mode="constant") + else: + new_cache = None + + return out, new_cache + + +class RMSNorm(nnx.Module): + """RMS Normalization with L2 normalize and learned scale. + + Based on F.normalize approach: normalize to unit norm, then scale. + For videos (images=False), uses 3D spatial+temporal normalization. + """ + + def __init__(self, dim: int, *, rngs: nnx.Rngs): + self.scale_factor = dim**0.5 + # gamma shape: (dim,) will broadcast to [B, T, H, W, C] or [B, H, W, C] + self.scale = nnx.Param(jnp.ones(dim)) + self.eps = 1e-12 + + def __call__(self, x: Array) -> Array: + # x: [B, T, H, W, C] for 3D or [B, H, W, C] for 2D + # Normalize to unit RMS along the channel dimension manually since jax.nn.normalize is unavailable. 
+ rms = jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True) + self.eps) + # if jnp.isnan(x).any(): + # nan_mask_x = jnp.isnan(x) + # nan_indices_x = jnp.argwhere(nan_mask_x, size=5, fill_value=-1) + x_normalized = x / rms + # nan_mask = jnp.isnan(x_normalized) + # nan_indices = jnp.argwhere(nan_mask, size=5, fill_value=-1) + # jax.debug.print( + # "x_normalized NaN at indices?:{}, nan indices:{}, rum values:{}, x NaN?: {}, x NaN indices:{}, ", + # nan_mask.any(), + # nan_indices[:5], + # rms.mean(), + # nan_mask_x.any(), + # nan_indices_x[:5], + # ) + # jax.debug.print("x_normalized has nan: {}", jnp.isnan(x_normalized).any()) + # jax.debug.print("scale values: {} {}", self.scale_factor, self.scale.value.mean()) + return x_normalized * self.scale_factor * self.scale.value + + +class ResidualBlock(nnx.Module): + """Residual block with RMSNorm and SiLU activation.""" + + def __init__(self, in_channels: int, out_channels: int, *, rngs: nnx.Rngs): + self.norm1 = RMSNorm(in_channels, rngs=rngs) + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1), rngs=rngs) + self.norm2 = RMSNorm(out_channels, rngs=rngs) + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1), rngs=rngs) + + if in_channels != out_channels: + self.skip_conv = CausalConv3d(in_channels, out_channels, kernel_size=(1, 1, 1), rngs=rngs) + else: + self.skip_conv = None + + def __call__( + self, x: Array, cache_list: tuple[Array | None, ...] | None = None, cache_idx: list[int] | None = None + ) -> tuple[Array, tuple[Array | None, ...] | None]: + residual = x + x = self.norm1(x) + # if self.skip_conv is not None: + # jax.debug.print("Residual block norm1 output has nan:{}", jnp.isnan(x).any(), ordered=True) + x = nnx.silu(x) + # if self.skip_conv is not None: + # jax.debug.print("Residual block activation output has nan:{}", jnp.isnan(x).any(), ordered=True) + + if cache_list is not None: + idx = cache_idx[0] + x, new_cache = self.conv1(x, cache_list[idx]) + cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :]) + cache_idx[0] += 1 + # if self.skip_conv is not None: + # jax.debug.print("Residual block conv1 output has nan:{}", jnp.isnan(x).any(), ordered=True) + else: + x, _ = self.conv1(x, None) + # if self.skip_conv is not None: + # jax.debug.print("no cache: Residual block conv1 output has nan:{}", jnp.isnan(x).any(), ordered=True) + + x = self.norm2(x) + # if self.skip_conv is not None: + # jax.debug.print("Residual block norm2 output has nan:{}", jnp.isnan(x).any(), ordered=True) + x = nnx.silu(x) + # if self.skip_conv is not None: + # jax.debug.print("Residual block activation2 output has nan:{}", jnp.isnan(x).any(), ordered=True) + + if cache_list is not None: + idx = cache_idx[0] + x, new_cache = self.conv2(x, cache_list[idx]) + cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :]) + cache_idx[0] += 1 + # if self.skip_conv is not None: + # jax.debug.print("Residual block conv2 output has nan:{}", jnp.isnan(x).any(), ordered=True) + else: + x, _ = self.conv2(x, None) + # if self.skip_conv is not None: + # jax.debug.print("no cache: Residual block conv2 output has nan:{}", jnp.isnan(x).any(), ordered=True) + + if self.skip_conv is not None: + residual, _ = self.skip_conv(residual, None) + # jax.debug.print("Residual conv output has nan:{}", jnp.isnan(residual).any(), ordered=True) + + # jax.debug.print("Residual block output has nan:{}", jnp.isnan(x).any(), ordered=True) + + return x + residual, cache_list + + +class 
AttentionBlock(nnx.Module): + """Spatial attention block with batched frame processing.""" + + def __init__(self, channels: int, *, rngs: nnx.Rngs): + self.norm = RMSNorm(channels, rngs=rngs) + self.qkv = nnx.Conv( + in_features=channels, + out_features=channels * 3, + kernel_size=(1, 1), + use_bias=True, + rngs=rngs, + precision=Precision.HIGHEST, + ) + self.proj = nnx.Conv( + in_features=channels, + out_features=channels, + kernel_size=(1, 1), + use_bias=True, + rngs=rngs, + precision=Precision.HIGHEST, + ) + + def __call__(self, x: Array) -> Array: + # x: [B, T, H, W, C] + b, t, h, w, c = x.shape + residual = x + + x = x.reshape(b * t, h, w, c) + x = self.norm(x) + # QKV projection: [B*T, H, W, C] -> [B*T, H, W, 3*C] + qkv = self.qkv(x) + + # Reshape for attention: [B*T, H, W, 3*C] -> [B*T, H*W, 3*C] -> split to Q, K, V + qkv = qkv.reshape(b * t, h * w, 3 * c) + q, k, v = jnp.split(qkv, 3, axis=-1) # Each: [B*T, H*W, C] + + # Scaled dot-product attention + scale = c**-0.5 + attn = jax.nn.softmax(jnp.einsum("bic,bjc->bij", q, k) * scale, axis=-1) # [B*T, H*W, H*W] + out = jnp.einsum("bij,bjc->bic", attn, v) # [B*T, H*W, C] + + # Reshape back to spatial: [B*T, H*W, C] -> [B*T, H, W, C] + out = out.reshape(b * t, h, w, c) + + # Output projection + out = self.proj(out) + + # Reshape back to video: [B*T, H, W, C] -> [B, T, H, W, C] + out = out.reshape(b, t, h, w, c) + + return out + residual + + +class Upsample2D(nnx.Module): + """Spatial 2x upsample that also halves channels, mirroring torch Resample.""" + + def __init__(self, in_channels: int, out_channels: int, *, rngs: nnx.Rngs): + self.in_channels = in_channels + self.out_channels = out_channels + self.conv = nnx.Conv( + in_features=in_channels, + out_features=out_channels, + kernel_size=(3, 3), + padding=1, + rngs=rngs, + precision=Precision.HIGHEST, + ) + + def __call__(self, x: Array) -> Array: + # x: [B, T, H, W, Cin] + b, t, h, w, _ = x.shape + orig_dtype = x.dtype + x = x.reshape(b * t, h, w, self.in_channels) + x = jax.image.resize(x.astype(jnp.float32), (b * t, h * 2, w * 2, self.in_channels), method="nearest").astype( + orig_dtype + ) + x = self.conv(x) + return x.reshape(b, t, h * 2, w * 2, self.out_channels) + + +class Upsample3D(nnx.Module): + """Temporal+spatial 2x upsample with channel reduction (like torch Resample).""" + + def __init__(self, in_channels: int, out_channels: int, *, rngs: nnx.Rngs): + self.in_channels = in_channels + self.out_channels = out_channels + self.time_conv = CausalConv3d(in_channels, in_channels * 2, kernel_size=(3, 1, 1), padding=(1, 0, 0), rngs=rngs) + self.spatial_conv = nnx.Conv( + in_features=in_channels, + out_features=out_channels, + kernel_size=(3, 3), + padding=1, + rngs=rngs, + precision=Precision.HIGHEST, + ) + + def __call__( + self, x: Array, cache_list: tuple[Array | None, ...] | None = None, cache_idx: list[int] | None = None + ) -> tuple[Array, tuple[Array | None, ...] 
| None]: + b, t, h, w, _ = x.shape + orig_dtype = x.dtype + + if cache_list is not None: + idx = cache_idx[0] + + # First frame: skip time_conv, only do spatial upsampling + if cache_list[idx] is None: + # Use zero array as sentinel with SAME shape as real cache + # This ensures consistent pytree structure for JIT + # We use zeros with shape [B, 2, H, W, C] where 2 = cache size for 3x1x1 conv + sentinel = jnp.zeros((b, 2, h, w, self.in_channels), dtype=orig_dtype) + cache_list = (*cache_list[:idx], sentinel, *cache_list[idx + 1 :]) + cache_idx[0] += 1 + t_out = t + else: + # Always pass the cached features (including the zero sentinel) so the + # time_conv sees a length-2 cache and returns a length-2 cache, matching + # the torch behavior where the sentinel seeds the cache. + x, new_cache = self.time_conv(x, cache_list[idx]) + + cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :]) + cache_idx[0] += 1 + + x = x.reshape(b, t, h, w, 2, self.in_channels) + x = jnp.moveaxis(x, 4, 2) # [B, T, 2, H, W, Cin] -> [B, 2, T, H, W, Cin] + t_out = t * 2 + x = x.reshape(b, t_out, h, w, self.in_channels) + + # Spatial upsampling (always applied) + bt = b * t_out + x = x.reshape(bt, h, w, self.in_channels) + x = jax.image.resize(x.astype(jnp.float32), (bt, h * 2, w * 2, self.in_channels), method="nearest").astype( + orig_dtype + ) + x = self.spatial_conv(x) + return x.reshape(b, t_out, h * 2, w * 2, self.out_channels), cache_list + + +class Decoder3D(nnx.Module): + """ + 3D Decoder matching reference implementation. + Upsamples from [B, 1, 104, 60, 16] -> [B, 4, 832, 480, 3] (JAX format) + """ + + def __init__(self, *, rngs: nnx.Rngs): + # Initial convolution: 16 -> 384 + self.conv_in = CausalConv3d(16, 384, kernel_size=(3, 3, 3), rngs=rngs, padding=(1, 1, 1)) + + # Middle blocks (at lowest resolution) + self.mid_block1 = ResidualBlock(384, 384, rngs=rngs) + self.mid_attn = AttentionBlock(384, rngs=rngs) + self.mid_block2 = ResidualBlock(384, 384, rngs=rngs) + + # Upsample stages (match torch checkpoint shapes) + # Stage 0: stay at 384, then upsample to 192 channels + self.up_blocks_0 = nnx.List( + [ + ResidualBlock(384, 384, rngs=rngs), + ResidualBlock(384, 384, rngs=rngs), + ResidualBlock(384, 384, rngs=rngs), + ] + ) + self.up_sample_0 = Upsample3D(384, 192, rngs=rngs) + + # Stage 1: 192 -> 384 (first block), remain at 384, then upsample to 192 + self.up_blocks_1 = nnx.List( + [ + ResidualBlock(192, 384, rngs=rngs), + ResidualBlock(384, 384, rngs=rngs), + ResidualBlock(384, 384, rngs=rngs), + ] + ) + self.up_sample_1 = Upsample3D(384, 192, rngs=rngs) + + # Stage 2: stay at 192, then spatial-only upsample to 96 + self.up_blocks_2 = nnx.List( + [ + ResidualBlock(192, 192, rngs=rngs), + ResidualBlock(192, 192, rngs=rngs), + ResidualBlock(192, 192, rngs=rngs), + ] + ) + self.up_sample_2 = Upsample2D(192, 96, rngs=rngs) + + # Stage 3: 96 -> 96, no upsample + self.up_blocks_3 = nnx.List( + [ + ResidualBlock(96, 96, rngs=rngs), + ResidualBlock(96, 96, rngs=rngs), + ResidualBlock(96, 96, rngs=rngs), + ] + ) + + # Output head: 96 -> 3 + self.norm_out = RMSNorm(96, rngs=rngs) + self.conv_out = CausalConv3d(96, 3, kernel_size=(3, 3, 3), padding=(1, 1, 1), rngs=rngs) + + def __call__( + self, z: Array, cache_list: tuple[Array | None, ...] | None = None, cache_idx: list[int] | None = None + ) -> tuple[Array, tuple[Array | None, ...] | None]: + """Forward pass with optional caching. 
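Because every convolution in this decoder runs through `CausalConv3d`'s frame cache, a useful consistency check (illustrative; it assumes `CausalConv3d` is importable from this module) is that streaming a clip one frame at a time should reproduce the full-clip result:

```python
import jax.numpy as jnp
from flax import nnx

conv = CausalConv3d(16, 16, kernel_size=(3, 3, 3), padding=(1, 1, 1), rngs=nnx.Rngs(0))
clip = jnp.linspace(0.0, 1.0, 4 * 8 * 8 * 16).reshape(1, 4, 8, 8, 16)  # [B, T, H, W, C]

full, _ = conv(clip, None)                 # whole clip at once

cache, frames = None, []
for t in range(clip.shape[1]):             # one frame at a time, threading the cache
    out, cache = conv(clip[:, t : t + 1], cache)
    frames.append(out)
streamed = jnp.concatenate(frames, axis=1)

print(jnp.allclose(full, streamed, atol=1e-5))  # expected: True
```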
+ + Args: + z: [B, T, H, W, C] latent (e.g., [1, 1, 104, 60, 16]) + cache_list: Tuple of cached features for all conv layers, or None + cache_idx: List containing current index in cache_list (mutable), or None + + Returns: + x: [B, T_out, H_out, W_out, 3] RGB video (e.g., [1, 4, 832, 480, 3]) + cache_list: Updated cache tuple + """ + # Initial convolution + if cache_list is not None: + idx = cache_idx[0] + x, new_cache = self.conv_in(z, cache_list[idx]) + cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :]) + cache_idx[0] += 1 + else: + x, _ = self.conv_in(z, None) + + # jax.debug.print("Decoder3D conv_in output has nan:{}", jnp.isnan(x).any(), ordered=True) + + # Middle blocks + x, cache_list = self.mid_block1(x, cache_list, cache_idx) + x = self.mid_attn(x) # Attention doesn't use cache + x, cache_list = self.mid_block2(x, cache_list, cache_idx) + + # jax.debug.print("Decoder3D mid output has nan:{}", jnp.isnan(x).any(), ordered=True) + + # Upsample stage 0 + for block in self.up_blocks_0: + x, cache_list = block(x, cache_list, cache_idx) + x, cache_list = self.up_sample_0(x, cache_list, cache_idx) + + # jax.debug.print("Decoder3D upsample0 output has nan:{}", jnp.isnan(x).any(), ordered=True) + + # Upsample stage 1 + for block in self.up_blocks_1: + x, cache_list = block(x, cache_list, cache_idx) + x, cache_list = self.up_sample_1(x, cache_list, cache_idx) + + # jax.debug.print("Decoder3D upsample1 output has nan:{}", jnp.isnan(x).any(), ordered=True) + + # Upsample stage 2 + for block in self.up_blocks_2: + x, cache_list = block(x, cache_list, cache_idx) + x = self.up_sample_2(x) # Spatial-only upsample, no cache + + # jax.debug.print("Decoder3D upsample2 output has nan:{}", jnp.isnan(x).any(), ordered=True) + + # Upsample stage 3 (no spatial upsample) + for block in self.up_blocks_3: + x, cache_list = block(x, cache_list, cache_idx) + + # jax.debug.print("Decoder3D upsample3 output has nan:{}", jnp.isnan(x).any(), ordered=True) + + # Output + x = self.norm_out(x) + x = nnx.silu(x) + + # jax.debug.print("Decoder3D norm_out output has nan:{}", jnp.isnan(x).any(), ordered=True) + + if cache_list is not None: + idx = cache_idx[0] + x, new_cache = self.conv_out(x, cache_list[idx]) + cache_list = (*cache_list[:idx], new_cache, *cache_list[idx + 1 :]) + cache_idx[0] += 1 + else: + x, _ = self.conv_out(x, None) + + # jax.debug.print("Decoder3D conv_out output has nan:{}", jnp.isnan(x).any(), ordered=True) + + return x, cache_list + + +class WanVAEDecoder(nnx.Module): + """ + Wan-VAE Decoder: Converts video latents to RGB frames. + + Architecture matches reference (wan/modules/vae.py:544-568): + 1. Denormalize latents with learned mean/std + 2. Conv 1x1 projection (16 -> 16 channels) + 3. Frame-by-frame decode with Decoder3D + 4. Concatenate and clamp output + + Input: [B, T, H, W, C] = [1, 21, 104, 60, 16] + Output: [B, T_out, H_out, W_out, 3] = [1, 81, 832, 480, 3] + """ + + def __init__(self, cfg: VAEConfig = VAEConfig(), *, rngs: nnx.Rngs): + # Store config tuples as Python values (not JAX arrays!) + # They'll be converted to JAX arrays at runtime in decode() + # This avoids ShapeDtypeStruct issues during nnx.eval_shape() + self.latent_mean_tuple = cfg.latent_mean + self.latent_std_tuple = cfg.latent_std + + # 1x1 conv projection + self.conv2 = CausalConv3d(16, 16, kernel_size=(1, 1, 1), rngs=rngs) + + # 3D decoder + self.decoder = Decoder3D(rngs=rngs) + + def decode(self, latents: Array) -> Array: + """ + Decode latents to RGB video with feature caching. 
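The shape bookkeeping behind the numbers in this docstring, spelled out (illustrative arithmetic only):

```python
# Latents [1, 21, 104, 60, 16]  ->  video [1, 81, 832, 480, 3]
latent_t, latent_h, latent_w = 21, 104, 60

# Three 2x spatial upsamples (up_sample_0, up_sample_1, up_sample_2) give 8x spatially.
out_h, out_w = latent_h * 2 * 2 * 2, latent_w * 2 * 2 * 2  # 832, 480

# Temporally, the first latent frame decodes to a single frame (Upsample3D skips time_conv
# while its cache slot is still empty); every later latent frame is doubled by each of the
# two Upsample3D stages and decodes to 4 frames.
out_t = 1 + (latent_t - 1) * 4                             # 81

print(out_t, out_h, out_w)                                 # 81 832 480
```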
+ + Args: + latents: [B, T, H, W, C] latent representation (JAX format) + e.g., [1, 21, 104, 60, 16] + + Returns: + video: [B, T_out, H_out, W_out, 3] RGB video (values in [-1, 1]) + e.g., [1, 81, 832, 480, 3] + """ + # Step 1: Denormalize + # Convert Python tuples to JAX arrays at runtime (JIT treats them as static constants) + latent_mean = jnp.array(self.latent_mean_tuple).reshape(1, 1, 1, 1, 16) + latent_std = jnp.array(self.latent_std_tuple).reshape(1, 1, 1, 1, 16) + z = latents * latent_std + latent_mean + + z, _ = self.conv2(z, None) + + # Scan over time dimension: z is [B, T, H, W, C], transpose to [T, B, H, W, C] + z_frames = jnp.moveaxis(z, 1, 0) # [T, B, H, W, C] + # Add singleton time dimension for each frame: [T, B, 1, H, W, C] + z_frames = z_frames[:, :, None, :, :, :] + + # jax.debug.print("z_frames has nan:{}", jnp.isnan(z_frames).any()) + + # Warm-up pass: process first frame to initialize cache with correct shapes + # This ensures consistent pytree structure for jax.lax.scan + cache_idx = [0] + cache_tuple = (None,) * 50 + first_frame_out, cache_tuple = self.decoder(z_frames[0], cache_tuple, cache_idx) + num_arrays = sum(isinstance(x, jnp.ndarray) for x in cache_tuple) + num_nones = sum(x is None for x in cache_tuple) + print(f"cache Arrays: {num_arrays},cache Nones: {num_nones}") + + # JIT-compiled scan function for remaining frames (now cache has concrete shapes) + @jax.jit + def scan_frames(cache_tuple, frame_latent): + """Process single frame with caching (JIT-compiled).""" + cache_idx = [0] + frame_out, new_cache_tuple = self.decoder(frame_latent, cache_tuple, cache_idx) + # num_arrays = sum(isinstance(x, jnp.ndarray) for x in new_cache_tuple) + # num_nones = sum(x is None for x in new_cache_tuple) + # jax.debug.print("new cache Arrays: {},cache Nones: {}", num_arrays, num_nones) + # jax.debug.print("frame_out shape:{}", frame_out.shape) + # right_part_frame = frame_out[:, :, :, 235:, :] + # jax.debug.print("frame_out Has NaN: {}", jnp.isnan(right_part_frame).any()) + return new_cache_tuple, frame_out + + # Process remaining frames with JIT + if z_frames.shape[0] > 1: + with jax.disable_jit(): + _final_cache, remaining_outputs = jax.lax.scan(scan_frames, cache_tuple, z_frames[1:]) + + print(f"remaining output shape: {remaining_outputs.shape}") + right_part_remaining = remaining_outputs[:, :, :, :, 235:, :] + print(f"Has NaN: {jnp.isnan(right_part_remaining).any()}") + print(f"Has Inf: {jnp.isinf(right_part_remaining).any()}") + print(f"remaining output mean:{right_part_remaining.mean()} ") + # Frame 0 outputs 1 frame: [B, 1, H, W, 3] + # Frames 1+ each output 4 frames: [T-1, B, 4, H, W, 3] + # Flatten temporal dimensions before concatenating + b, h_out, w_out, c = ( + first_frame_out.shape[0], + first_frame_out.shape[2], + first_frame_out.shape[3], + first_frame_out.shape[4], + ) + + # Flatten first frame: [B, 1, H, W, 3] -> [1, B, H, W, 3] + first_flat = first_frame_out.transpose(1, 0, 2, 3, 4) # [1, B, H, W, 3] + + # Flatten remaining frames: [T-1, B, 4, H, W, 3] -> [T-1*4, B, H, W, 3] + t_minus_1 = remaining_outputs.shape[0] + t_out_per_frame = remaining_outputs.shape[2] + remaining_flat = remaining_outputs.transpose(0, 2, 1, 3, 4, 5).reshape( + t_minus_1 * t_out_per_frame, b, h_out, w_out, c + ) + print(f"remaining flat shape:{remaining_flat.shape}") + print(f"remaining flat mean:{remaining_flat[:, 0, :, 235:, :].mean()} ") + + # Concatenate along time dimension: [1+T-1*4, B, H, W, 3] + # Concatenate first frame with remaining frames + x = 
jnp.concatenate([first_flat, remaining_flat], axis=0).transpose(1, 0, 2, 3, 4) + else: + x = first_frame_out + + # Clamp to [-1, 1] + x = jnp.clip(x, -1.0, 1.0) + + return x + + +def load_vae_from_checkpoint(checkpoint_path: str, rngs: nnx.Rngs) -> WanVAEDecoder: + """ + Load Wan-VAE decoder from checkpoint. + + Args: + checkpoint_path: Path to the VAE checkpoint directory + rngs: Random number generators for initialization + + Returns: + WanVAEDecoder with loaded weights + """ + # Create VAE decoder structure + vae_decoder = WanVAEDecoder(rngs=rngs) + + # TODO: Implement checkpoint loading + # This will require mapping PyTorch VAE weights to JAX + # Similar to what's done in params.py for the DiT model + + print("Warning: VAE checkpoint loading not yet implemented") + print("Returning VAE with random weights") + + return vae_decoder + + +def decode_latents_to_video(vae_decoder: WanVAEDecoder, latents: Array, normalize: bool = True) -> Array: + """ + Helper function to decode latents and post-process to video. + + Args: + vae_decoder: WanVAEDecoder instance + latents: [B, T, H, W, C] latent representation + normalize: If True, normalize output from [-1, 1] to [0, 255] uint8 + + Returns: + video: [B, T, H_out, W_out, 3] video frames + """ + # Decode + video = vae_decoder.decode(latents) + print(f"video shape:{video.shape}") + print(f"video mean:{video[0, 1:, :, 235:, :].mean()}") + + if normalize: + video = (video + 1.0) / 2.0 + video = jnp.clip(video, 0.0, 1.0) + + video = jnp.round(video * 255.0) + video = jnp.clip(video, 0, 255).astype(jnp.uint8) + + return video + + +def save_video( + video: Array, + save_path: str, + fps: int = 30, + codec: str = "libx264", + quality: int = 8, +) -> str | None: + try: + # Handle batch dimension: take first video if batched + assert video.ndim == 5 + video = video[0] # [T, H, W, C] + + video_np = jax.device_get(video) + + # Write video + writer = imageio.get_writer(save_path, fps=fps, codec=codec, quality=quality) + for frame in video_np: + writer.append_data(frame) + writer.close() + + return save_path + + except Exception as e: + print(f"Failed to save video: {e}") + return None + + +__all__ = ["VAEConfig", "WanVAEDecoder", "decode_latents_to_video", "load_vae_from_checkpoint", "save_video"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..9f89b3a2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,133 @@ +absl-py==2.3.1 +accelerate==1.12.0 +aiofiles==24.1.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.13.2 +aiosignal==1.4.0 +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.11.0 +attrs==25.4.0 +brotli==1.2.0 +certifi==2025.11.12 +cffi==2.0.0 +charset-normalizer==3.4.4 +chex==0.1.91 +click==8.3.1 +cryptography==46.0.3 +dashscope==1.25.2 +diffusers==0.35.2 +easydict==1.13 +einops==0.8.1 +etils==1.13.0 +fastapi==0.122.0 +ffmpy==1.0.0 +filelock==3.20.0 +flax==0.12.1 +frozenlist==1.8.0 +fsspec==2025.10.0 +ftfy==6.3.1 +gradio==6.0.1 +gradio-client==2.0.0 +groovy==0.1.2 +h11==0.16.0 +hf-xet==1.2.0 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.36.0 +humanize==4.14.0 +idna==3.11 +imageio==2.37.2 +imageio-ffmpeg==0.6.0 +importlib-metadata==8.7.0 +importlib-resources==6.5.2 +inquirerpy==0.3.4 +jax==0.8.1 +-e file:///home/gcpuser/sky_workdir/bonsai +jaxlib==0.8.1 +jaxtyping==0.3.3 +jinja2==3.1.6 +libtpu==0.0.30 +markdown-it-py==4.0.0 +markupsafe==3.0.3 +mdurl==0.1.2 +ml-dtypes==0.5.4 +modelscope==1.32.0 +mpmath==1.3.0 +msgpack==1.1.2 +multidict==6.7.0 +nest-asyncio==1.6.0 +networkx==3.6 +numpy==2.3.5 
+nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-nccl-cu12==2.27.5 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvshmem-cu12==3.3.20 +nvidia-nvtx-cu12==12.8.90 +opencv-python==4.11.0.86 +opt-einsum==3.4.0 +optax==0.2.6 +orbax-checkpoint==0.11.28 +orjson==3.11.4 +packaging==25.0 +pandas==2.3.3 +pfzy==0.3.4 +pillow==12.0.0 +prompt-toolkit==3.0.52 +propcache==0.4.1 +protobuf==6.33.1 +psutil==7.1.3 +pycparser==2.23 +pydantic==2.12.4 +pydantic-core==2.41.5 +pydub==0.25.1 +pygments==2.19.2 +python-dateutil==2.9.0.post0 +python-multipart==0.0.20 +pytz==2025.2 +pyyaml==6.0.3 +regex==2025.11.3 +requests==2.32.5 +rich==14.2.0 +safehttpx==0.1.7 +safetensors==0.7.0 +scipy==1.16.3 +semantic-version==2.10.0 +setuptools==80.9.0 +shellingham==1.5.4 +simplejson==3.20.2 +six==1.17.0 +sniffio==1.3.1 +starlette==0.50.0 +sympy==1.14.0 +tensorstore==0.1.79 +tokenizers==0.22.1 +tomlkit==0.13.3 +toolz==1.1.0 +torch==2.9.1 +torchvision==0.24.1 +tqdm==4.67.1 +transformers==4.57.3 +treescope==0.1.10 +triton==3.5.1 +typer==0.20.0 +typing-extensions==4.15.0 +typing-inspection==0.4.2 +tzdata==2025.2 +urllib3==2.5.0 +uvicorn==0.38.0 +wadler-lindig==0.1.7 +wcwidth==0.2.14 +websocket-client==1.9.0 +yarl==1.22.0 +zipp==3.23.0 \ No newline at end of file
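Tying the pieces in `vae_wan.py` together, a minimal end-to-end sketch (illustrative only; the checkpoint loader above is still a stub, so this runs with random weights, the latents are just noise, and the path and fps are placeholders):

```python
import jax
import jax.numpy as jnp
from flax import nnx

vae_decoder = load_vae_from_checkpoint("/path/to/vae", rngs=nnx.Rngs(0))  # currently random weights

latents = jax.random.normal(jax.random.PRNGKey(0), (1, 21, 104, 60, 16))  # [B, T, H, W, C]
video = decode_latents_to_video(vae_decoder, latents, normalize=True)     # [1, 81, 832, 480, 3], uint8

save_video(video, "sample.mp4", fps=16)
```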