Draft
Changes from all commits
Commits
30 commits
54fa243
wan2.1-t2v-1.3b forward
labyrinth-ssr Nov 21, 2025
57de054
fix
labyrinth-ssr Nov 21, 2025
281ebfa
fix
labyrinth-ssr Nov 21, 2025
ec215b4
compact code
labyrinth-ssr Nov 21, 2025
9d5d3a5
feat: vae decoder
labyrinth-ssr Nov 24, 2025
3a260a8
fix
labyrinth-ssr Nov 24, 2025
08e562f
fix
labyrinth-ssr Nov 24, 2025
884173a
fix
labyrinth-ssr Nov 24, 2025
25806fd
add t5
labyrinth-ssr Nov 25, 2025
48d185e
add params
labyrinth-ssr Nov 25, 2025
3cdb157
fix params
labyrinth-ssr Nov 25, 2025
6f012cd
fix params
labyrinth-ssr Nov 25, 2025
4e0e519
revert qwen3
labyrinth-ssr Nov 25, 2025
d4e642c
Merge branch 'jax-ml:main' into dev-wan
Iamleos Nov 25, 2025
678134c
fix
labyrinth-ssr Nov 25, 2025
a00da1b
Merge remote-tracking branch 'pri/dev-wan' into dev-wan
labyrinth-ssr Nov 25, 2025
ecbe804
fix: vae use rmsnorm
labyrinth-ssr Nov 26, 2025
07d099a
fix vae param
labyrinth-ssr Nov 26, 2025
aec55de
fix vae
labyrinth-ssr Nov 26, 2025
648f94a
fix upblocks
labyrinth-ssr Nov 26, 2025
b4cc285
fix upsample
labyrinth-ssr Nov 26, 2025
3677774
fix param
labyrinth-ssr Nov 26, 2025
249bc12
fix noise latent input dim
labyrinth-ssr Nov 26, 2025
c970b2e
fix: lazy compute rope freqs for jit
labyrinth-ssr Nov 26, 2025
ee60cbe
fix: rmsnorm implementation
labyrinth-ssr Nov 26, 2025
7d933fe
fix: upsample resample param mapping
labyrinth-ssr Nov 26, 2025
39ede1a
fix: static latent mean and std
labyrinth-ssr Nov 26, 2025
27456dd
fix: static latent mean and std in decoder
labyrinth-ssr Nov 26, 2025
57cedb9
add save video
labyrinth-ssr Nov 26, 2025
85cc4c4
add compare test
labyrinth-ssr Nov 26, 2025
Empty file added bonsai/models/wan2/__init__.py
Empty file.
553 changes: 553 additions & 0 deletions bonsai/models/wan2/modeling.py

Large diffs are not rendered by default.

648 changes: 648 additions & 0 deletions bonsai/models/wan2/params.py

Large diffs are not rendered by default.

339 changes: 339 additions & 0 deletions bonsai/models/wan2/t5.py
@@ -0,0 +1,339 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX/Flax implementation of T5 encoder for Wan2.1-T2V-1.3B.

Modified from transformers.models.t5.modeling_t5
Converted from PyTorch to JAX/Flax NNX.
"""

import math
from typing import Optional

import jax
import jax.numpy as jnp
from flax import nnx
from jaxtyping import Array


def gelu(x: Array) -> Array:
    """GELU activation (tanh approximation, matching T5's "gelu_new")."""
    return 0.5 * x * (1.0 + jnp.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * jnp.pow(x, 3.0))))


class T5LayerNorm(nnx.Module):
"""T5 Layer Normalization (RMS normalization without centering)."""

def __init__(self, dim: int, eps: float = 1e-6, *, rngs: nnx.Rngs):
self.dim = dim
self.eps = eps
self.weight = nnx.Param(jnp.ones(dim))

def __call__(self, x: Array) -> Array:
# RMS normalization: x / sqrt(mean(x^2))
variance = jnp.mean(x.astype(jnp.float32) ** 2, axis=-1, keepdims=True)
x = x * jax.lax.rsqrt(variance + self.eps)
return self.weight.value * x


class T5Attention(nnx.Module):
"""T5 Multi-head attention."""

def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.1, *, rngs: nnx.Rngs):
assert dim_attn % num_heads == 0
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads

# Linear projections
self.q = nnx.Linear(dim, dim_attn, use_bias=False, rngs=rngs)
self.k = nnx.Linear(dim, dim_attn, use_bias=False, rngs=rngs)
self.v = nnx.Linear(dim, dim_attn, use_bias=False, rngs=rngs)
self.o = nnx.Linear(dim_attn, dim, use_bias=False, rngs=rngs)
self.dropout = nnx.Dropout(dropout, rngs=rngs)

def __call__(
self,
x: Array,
context: Optional[Array] = None,
mask: Optional[Array] = None,
pos_bias: Optional[Array] = None,
deterministic: bool = False,
) -> Array:
"""
Args:
x: [B, L1, C] query
context: [B, L2, C] key/value context, defaults to x for self-attention
mask: [B, L2] or [B, L1, L2] attention mask
pos_bias: [B, num_heads, L1, L2] position bias
"""
context = x if context is None else context
b = x.shape[0]
n = self.num_heads
c = self.head_dim

# Compute Q, K, V
q = self.q(x).reshape(b, -1, n, c) # [B, L1, num_heads, head_dim]
k = self.k(context).reshape(b, -1, n, c) # [B, L2, num_heads, head_dim]
v = self.v(context).reshape(b, -1, n, c) # [B, L2, num_heads, head_dim]

# Attention bias
attn_bias = jnp.zeros((b, n, q.shape[1], k.shape[1]))
if pos_bias is not None:
attn_bias = attn_bias + pos_bias
if mask is not None:
# Expand mask to attention shape
if mask.ndim == 2:
mask = mask[:, None, None, :] # [B, 1, 1, L2]
else:
mask = mask[:, None, :, :] # [B, 1, L1, L2]
attn_bias = jnp.where(mask == 0, jnp.finfo(x.dtype).min, attn_bias)

# Compute attention (T5 does not use scaling)
attn = jnp.einsum("binc,bjnc->bnij", q, k) + attn_bias
attn = jax.nn.softmax(attn, axis=-1)
x = jnp.einsum("bnij,bjnc->binc", attn, v)

# Output projection
x = x.reshape(b, -1, n * c)
x = self.o(x)
x = self.dropout(x, deterministic=deterministic)
return x


class T5FeedForward(nnx.Module):
"""T5 Feed-forward network with gated activation."""

def __init__(self, dim: int, dim_ffn: int, dropout: float = 0.1, *, rngs: nnx.Rngs):
self.dim = dim
self.dim_ffn = dim_ffn

# Gate and projection layers
self.gate = nnx.Linear(dim, dim_ffn, use_bias=False, rngs=rngs)
self.fc1 = nnx.Linear(dim, dim_ffn, use_bias=False, rngs=rngs)
self.fc2 = nnx.Linear(dim_ffn, dim, use_bias=False, rngs=rngs)
self.dropout = nnx.Dropout(dropout, rngs=rngs)

def __call__(self, x: Array, deterministic: bool = False) -> Array:
# Gated activation
x = self.fc1(x) * gelu(self.gate(x))
x = self.dropout(x, deterministic=deterministic)
x = self.fc2(x)
x = self.dropout(x, deterministic=deterministic)
return x


class T5RelativeEmbedding(nnx.Module):
"""T5 Relative position embeddings."""

def __init__(self, num_buckets: int, num_heads: int, bidirectional: bool, max_dist: int = 128, *, rngs: nnx.Rngs):
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
self.embedding = nnx.Embed(num_buckets, num_heads, rngs=rngs)

def __call__(self, lq: int, lk: int) -> Array:
"""Compute relative position bias.

Args:
lq: Query sequence length
lk: Key sequence length

Returns:
[1, num_heads, lq, lk] relative position bias
"""
# Compute relative positions
q_pos = jnp.arange(lq)[:, None]
k_pos = jnp.arange(lk)[None, :]
rel_pos = k_pos - q_pos # [lq, lk]

# Convert to buckets
rel_pos_buckets = self._relative_position_bucket(rel_pos)

# Get embeddings
rel_pos_embeds = self.embedding(rel_pos_buckets) # [lq, lk, num_heads]
rel_pos_embeds = rel_pos_embeds.transpose(2, 0, 1)[None, :, :, :] # [1, num_heads, lq, lk]
return rel_pos_embeds

def _relative_position_bucket(self, rel_pos: Array) -> Array:
"""Convert relative positions to bucket indices."""
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).astype(jnp.int32) * num_buckets
rel_pos = jnp.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = 0
rel_pos = -jnp.minimum(rel_pos, jnp.zeros_like(rel_pos))

# Small vs large positions
max_exact = num_buckets // 2
is_small = rel_pos < max_exact

# Logarithmic bucketing for large positions
rel_pos_large = max_exact + (
jnp.log(rel_pos.astype(jnp.float32) / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(jnp.int32)
rel_pos_large = jnp.minimum(rel_pos_large, num_buckets - 1)
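        # Worked example (bidirectional, num_buckets=32 -> 16 per direction, max_dist=128):
        # distances 0..7 keep exact buckets 0..7, distances 8..127 are spread
        # logarithmically over buckets 8..15, and distances >= 128 saturate at bucket 15;
        # positive (forward) offsets were already shifted by +16 via rel_buckets above.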

rel_buckets = rel_buckets + jnp.where(is_small, rel_pos, rel_pos_large)
return rel_buckets


class T5SelfAttention(nnx.Module):
"""T5 Self-attention block with feed-forward."""

def __init__(
self,
dim: int,
dim_attn: int,
dim_ffn: int,
num_heads: int,
num_buckets: int,
shared_pos: bool = True,
dropout: float = 0.1,
*,
rngs: nnx.Rngs,
):
self.shared_pos = shared_pos

# Layers
self.norm1 = T5LayerNorm(dim, rngs=rngs)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout, rngs=rngs)
self.norm2 = T5LayerNorm(dim, rngs=rngs)
self.ffn = T5FeedForward(dim, dim_ffn, dropout, rngs=rngs)

if not shared_pos:
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, rngs=rngs)
else:
self.pos_embedding = None

def __call__(
self, x: Array, mask: Optional[Array] = None, pos_bias: Optional[Array] = None, deterministic: bool = False
) -> Array:
# Get position bias
if self.shared_pos:
e = pos_bias
else:
e = self.pos_embedding(x.shape[1], x.shape[1])

# Self-attention
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e, deterministic=deterministic)
# Feed-forward
x = x + self.ffn(self.norm2(x), deterministic=deterministic)
return x


class T5Encoder(nnx.Module):
"""T5 Encoder."""

def __init__(
self,
vocab_size: int,
dim: int,
dim_attn: int,
dim_ffn: int,
num_heads: int,
num_layers: int,
num_buckets: int,
shared_pos: bool = True,
dropout: float = 0.1,
*,
rngs: nnx.Rngs,
):
self.dim = dim
self.shared_pos = shared_pos

# Layers
self.token_embedding = nnx.Embed(vocab_size, dim, rngs=rngs)
if shared_pos:
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, rngs=rngs)
else:
self.pos_embedding = None
self.dropout = nnx.Dropout(dropout, rngs=rngs)
self.blocks = nnx.List(
[
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, rngs=rngs)
for _ in range(num_layers)
]
)
self.norm = T5LayerNorm(dim, rngs=rngs)

def __call__(self, ids: Array, mask: Optional[Array] = None, deterministic: bool = True) -> Array:
"""Encode input tokens.

Args:
ids: [B, L] input token IDs
mask: [B, L] attention mask (1 for valid tokens)
deterministic: whether to disable dropout

Returns:
[B, L, dim] encoded representations
"""
x = self.token_embedding(ids)
x = self.dropout(x, deterministic=deterministic)

# Compute shared position bias if needed
e = self.pos_embedding(x.shape[1], x.shape[1]) if self.shared_pos else None

# Apply transformer blocks
for block in self.blocks:
x = block(x, mask, pos_bias=e, deterministic=deterministic)

x = self.norm(x)
x = self.dropout(x, deterministic=deterministic)
return x


class T5EncoderModel(nnx.Module):
"""T5 Encoder-only model for text encoding.

This is a wrapper for the T5 encoder configured for UMT5-XXL.
"""

def __init__(self, *, rngs: nnx.Rngs):
"""Initialize UMT5-XXL encoder."""
# UMT5-XXL configuration
self.encoder = T5Encoder(
vocab_size=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
num_layers=24,
num_buckets=32,
shared_pos=False, # UMT5 uses per-layer position embeddings
dropout=0.1,
rngs=rngs,
)

def __call__(self, input_ids: Array, attention_mask: Optional[Array] = None, deterministic: bool = True) -> Array:
"""Encode text.

Args:
input_ids: [B, L] token IDs
attention_mask: [B, L] attention mask (1 for valid tokens, 0 for padding)
deterministic: whether to disable dropout (True for inference)

Returns:
[B, L, 4096] encoded text embeddings
"""
return self.encoder(input_ids, mask=attention_mask, deterministic=deterministic)


__all__ = ["T5Encoder", "T5EncoderModel"]