diff --git a/bonsai/models/gemma3/README.md b/bonsai/models/gemma3/README.md
new file mode 100644
index 0000000..6f1e668
--- /dev/null
+++ b/bonsai/models/gemma3/README.md
@@ -0,0 +1,28 @@
+# Gemma3 in JAX
+
+This directory contains a pure JAX implementation of the [Gemma3 model](https://deepmind.google/models/gemma/gemma-3/), using the [Flax NNX](https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html) API. Note that you need a Hugging Face access token (read from the `HF_TOKEN` environment variable) to download the model weights.
+
+**This implementation is a work in progress but passes numerics checks against the reference implementation. Code cleanup and optimization are still underway before the official PR.**
+
+## Model Configuration Support Status
+
+Currently only the Gemma3 4B configuration (`gemma3_4b`) is implemented; other model sizes are tracked under Remaining Tasks.
+
+### Running this model
+
+```sh
+export HF_TOKEN=<your_huggingface_access_token>
+python3 -m bonsai.models.gemma3.tests.run_model
+```
+
+## How to contribute to this model
+
+### Remaining Tasks
+
+1. Properly implement sharding (vision first, then text).
+2. Implement batching; this is required for FSDP.
+3. Optimize based on profiling results.
+4. Clean up code (variable names, etc.) and simplify unused configs (marked with TODO).
+5. Add support for other model sizes.
diff --git a/bonsai/models/gemma3/modeling.py b/bonsai/models/gemma3/modeling.py
new file mode 100644
index 0000000..d3680a3
--- /dev/null
+++ b/bonsai/models/gemma3/modeling.py
@@ -0,0 +1,765 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
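+
+"""Pure JAX / Flax NNX implementation of Gemma3: SigLIP vision tower, multimodal
+projector, and text decoder with a simple KV cache.
+
+Rough usage sketch (illustrative only; see tests/run_model.py for a complete example):
+
+    cfg = ModelConfig.gemma3_4b()
+    model = Gemma3Model(cfg, rngs=nnx.Rngs(0))  # or load weights via params.create_gemma3_from_pretrained
+    cache = init_cache(cfg, input_ids.shape[0], input_ids.shape[1], generate_steps=16)
+    next_token_logits, cache = forward(model, cache, input_ids, pixel_values, segment_ids, token_type_ids)
+"""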
+ +import math +from dataclasses import dataclass, field +from enum import Enum +from functools import partial +from typing import Any, Optional, Tuple, TypeAlias + +import flax +import jax +import jax.numpy as jnp +from flax import nnx +from jax import P +from jax.interpreters import pxla +from jax.sharding import PartitionSpec, get_abstract_mesh, reshard +from jaxtyping import Array, Float + + +class AttentionMode(Enum): + FULL = "full_attention" + SLIDE = "sliding_attention" + + +class ShardMode(Enum): + FSDP = "fsdp" + TP = "tp" + + +def _make_attn_types(full_freq: int): + # Fix this (5x slide 1x full) x 5 + (4x slide) + return [AttentionMode.FULL if i % full_freq == full_freq - 1 else AttentionMode.SLIDE for i in range(34)] + + +# TODO: Need to get these implemented and follow the activations +@dataclass(slots=True, frozen=True) +class VisionShardingCfg: + qk: PartitionSpec + qb: PartitionSpec + kk: PartitionSpec + kb: PartitionSpec + vk: PartitionSpec + vb: PartitionSpec + ok: PartitionSpec + ob: PartitionSpec + fc1k: PartitionSpec + fc1b: PartitionSpec + fc2k: PartitionSpec + fc2b: PartitionSpec + + @staticmethod + def no_sharding(): + return VisionShardingCfg.default(False, False) + + @staticmethod + def default(use_fsdp: bool, use_tp: bool): + fsdp = ShardMode.FSDP.value if use_fsdp else None + tp = ShardMode.TP.value if use_tp else None + return VisionShardingCfg( + qk=P(tp, fsdp), + qb=P(fsdp), + kk=P(tp, fsdp), + kb=P(), # TODO + vk=P(tp, fsdp), + vb=P(), # TODO + ok=P(tp, fsdp), + ob=P(), # TODO + fc1k=P(fsdp, tp), + fc1b=P(), # TODO + fc2k=P(tp, fsdp), + fc2b=P(), # TODO + ) + + +@dataclass(slots=True, frozen=True) +class TextShardingCfg: + # attn + qk: PartitionSpec + qb: PartitionSpec + kk: PartitionSpec + kb: PartitionSpec + vk: PartitionSpec + vb: PartitionSpec + ok: PartitionSpec + ob: PartitionSpec + # mlp + dpk: PartitionSpec + dpb: PartitionSpec + gpk: PartitionSpec + gpb: PartitionSpec + upk: PartitionSpec + upb: PartitionSpec + # cache + # TODO + + @staticmethod + def no_sharding(): + return VisionShardingCfg.default(False, False) + + @staticmethod + def default(use_fsdp: bool, use_tp: bool): + fsdp = ShardMode.FSDP.value if use_fsdp else None + tp = ShardMode.TP.value if use_tp else None + return TextShardingCfg( + qk=P(tp, fsdp), + qb=P(fsdp), + kk=P(tp, fsdp), + kb=P(), # TODO + vk=P(tp, fsdp), + vb=P(), # TODO + ok=P(tp, fsdp), + ob=P(), # TODO + dpk=P(tp, fsdp), + dpb=P(), + gpk=P(tp, fsdp), + gpb=P(), + upk=P(tp, fsdp), + upb=P(), + ) + + +@dataclass(slots=True, frozen=True) +class ShardingCfg: + pass + + +@dataclass(frozen=True) +class VisionConfig: + attention_dropout: float # TODO: unused + hidden_size: int + image_size: int + intermediate_size: int + layer_norm_eps: float + num_attention_heads: int + num_channels: int + num_hidden_layers: int + patch_size: int + vision_use_head: bool + shd_cfg: VisionShardingCfg + + @classmethod + def gemma3_4b(cls, use_fsdp: bool, use_tp: bool): + return cls( + attention_dropout=0.0, + hidden_size=1152, + image_size=896, + intermediate_size=4304, + layer_norm_eps=1e-6, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=27, + patch_size=14, + vision_use_head=False, + shd_cfg=VisionShardingCfg.default(use_fsdp, use_tp), + ) + + +@dataclass(frozen=True) +class TextConfig: + attention_bias: bool + attention_dropout: float # TODO: unused + attn_logit_softcapping: Optional[float] # TODO: unused + final_logit_softcapping: Optional[float] # TODO: unused + head_dim: int + hidden_size: int + initializer_range: float # 
TODO: unused + intermediate_size: int + layer_types: list[AttentionMode] + max_position_embeddings: int # TODO: unused + num_attention_heads: int + num_hidden_layers: int + num_key_value_heads: int + query_pre_attn_scalar: int # TODO: unused + rms_norm_eps: float + rope_full_factor: float + rope_full_theta: float + rope_slide_factor: float + rope_slide_theta: float + sliding_window: int + use_cache: bool # TODO: unused + vocab_size: int + shd_cfg: TextShardingCfg + + @classmethod + def gemma3_4b(cls, use_fsdp: bool, use_tp: bool): + return cls( + attention_bias=False, + attention_dropout=0.0, # TODO: unused + attn_logit_softcapping=None, # TODO: unused + final_logit_softcapping=None, # TODO: unused + head_dim=256, + hidden_size=2560, + initializer_range=0.02, # TODO: unused + intermediate_size=10240, + layer_types=_make_attn_types(6), + max_position_embeddings=131072, # TODO: unused + num_attention_heads=8, + num_hidden_layers=34, + num_key_value_heads=4, + query_pre_attn_scalar=256, # TODO: unused + rms_norm_eps=1e-6, + rope_full_factor=8.0, + rope_full_theta=1000000.0, + rope_slide_factor=1.0, + rope_slide_theta=10000.0, + sliding_window=1024, + use_cache=True, # TODO: unused + vocab_size=262208, + shd_cfg=TextShardingCfg.default(use_fsdp, use_tp), + ) + + +@dataclass(frozen=True) +class ModelConfig: + vision_config: VisionConfig + text_config: TextConfig + mm_tokens_per_image: int + boi_token_index: int # TODO: unused + dtype: str # TODO: unused + eoi_token_index: int # TODO: unused + eos_token_id: list[int] # TODO: unused + image_token_index: int # TODO: unused + initializer_range: float # TODO: unused + + @classmethod + def gemma3_4b(cls, use_fsdp: bool = False, use_tp: bool = False): + return cls( + vision_config=VisionConfig.gemma3_4b(use_fsdp, use_tp), + text_config=TextConfig.gemma3_4b(use_fsdp, use_tp), + mm_tokens_per_image=256, + boi_token_index=255999, # TODO: unused + dtype="bfloat16", # TODO: unused + eoi_token_index=256000, # TODO: unused + eos_token_id=field(default_factory=lambda: [1, 106]), # TODO: unused + image_token_index=262144, # TODO: unused + initializer_range=0.02, # TODO: unused + ) + + +# General Components + + +class SHLinear(nnx.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + *, + use_bias: bool = True, + kernel_sharding=P(), + bias_sharding=P(), + dtype=None, # TODO: Use this + rngs, + ): + kernel_initializer = jax.nn.initializers.lecun_normal() + self.kernel = nnx.Param( + kernel_initializer(rngs.params(), (in_dim, out_dim), dtype=dtype, out_sharding=kernel_sharding) + ) + if use_bias: + self.bias = nnx.Param(jnp.zeros((out_dim,), dtype=dtype, out_sharding=bias_sharding)) + else: + self.bias = nnx.data(jnp.zeros((out_dim,), dtype=dtype, out_sharding=bias_sharding)) + + def __call__(self, x, *, out_sharding=P()): + return jnp.matmul(x, self.kernel, out_sharding=out_sharding) + self.bias + + +# Vision Components + + +# TODO: update to include interpolate_pos_encoding +class SiglipVisionEmbeddings(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.num_patches = (config.image_size // config.patch_size) ** 2 + self.patch_embedding = nnx.Conv( + config.num_channels, + config.hidden_size, + kernel_size=(config.patch_size,) * 2, + strides=(config.patch_size,) * 2, + padding="valid", + rngs=rngs, + ) + # TODO: shard this + self.position_embedding = nnx.Embed(self.num_patches, config.hidden_size, rngs=rngs) + self.position_ids = jnp.expand_dims(jnp.arange(self.num_patches), 0) + + def 
__call__(self, pixel_values: Array): + patch_embeds = self.patch_embedding(pixel_values) + b, h, w, c = patch_embeds.shape + embeddings = patch_embeds.reshape((b, h * w, c)) + return embeddings + self.position_embedding(self.position_ids) + + +class SiglipAttention(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + hs = config.hidden_size + shd = config.shd_cfg + self.k_proj = SHLinear(hs, hs, kernel_sharding=shd.kk, bias_sharding=shd.kb, rngs=rngs) + self.v_proj = SHLinear(hs, hs, kernel_sharding=shd.vk, bias_sharding=shd.vb, rngs=rngs) + self.q_proj = SHLinear(hs, hs, kernel_sharding=shd.qk, bias_sharding=shd.qb, rngs=rngs) + self.out_proj = SHLinear(hs, hs, kernel_sharding=shd.ok, bias_sharding=shd.ob, rngs=rngs) + + def __call__(self, x: Array, attn_mask: Array | None): + batch_size, seq_length, _ = x.shape + shape = (batch_size, seq_length, self.num_heads, self.head_dim) + q = self.q_proj(x).reshape(shape) + k = self.k_proj(x).reshape(shape) + v = self.v_proj(x).reshape(shape) + + attn = jax.nn.dot_product_attention(q, k, v, mask=attn_mask).reshape(x.shape) + return self.out_proj(attn) + + +class SiglipMLP(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.fc1 = SHLinear(config.hidden_size, config.intermediate_size, rngs=rngs) + self.fc2 = SHLinear(config.intermediate_size, config.hidden_size, rngs=rngs) + + def __call__(self, x: Array): + return self.fc2(jax.nn.gelu(self.fc1(x))) + + +class SiglipEncoderLayer(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.layer_norm1 = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs) + self.self_attn = SiglipAttention(config, rngs=rngs) + self.layer_norm2 = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs) + self.mlp = SiglipMLP(config, rngs=rngs) + + def __call__(self, x: Array, attn_mask: Array | None): + hidden = self.layer_norm1(x) + hidden = self.self_attn(hidden, attn_mask) + hidden = x + hidden + x = hidden + hidden = self.layer_norm2(hidden) + hidden = self.mlp(hidden) + return hidden + x + + +class SiglipEncoder(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.layers = nnx.List([SiglipEncoderLayer(config, rngs=rngs) for _ in range(config.num_hidden_layers)]) + + def __call__(self, x: Array, attn_mask: Array | None): + for l in self.layers: + x = l(x, attn_mask) + return x + + +# TODO: Skip for now since not in 4b, but test later +class SiglipMultiheadAttentionPoolingHead(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.probe = nnx.Param(nnx.initializers.normal(stddev=0.02)(rngs.params(), (1, 1, config.hidden_size))) + self.attention = nnx.MultiHeadAttention(config.num_attention_heads, config.hidden_size, rngs=rngs) + self.layernorm = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs) + self.mlp = SiglipMLP(config, rngs=rngs) + + def __call__(self, *args, **kwargs): + raise NotImplementedError("Not yet implemented") + + +class SiglipVisionTransformer(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.embeddings = SiglipVisionEmbeddings(config, rngs=rngs) + self.encoder = SiglipEncoder(config, rngs=rngs) + self.post_layernorm = 
nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead(config) + + def __call__(self, pixel_values: Array): + x = self.embeddings(pixel_values) + x = self.encoder(x, attn_mask=None) + x = self.post_layernorm(x) + if self.use_head: + x = self.head(x) + return x + + +# Language components + + +# from qwen3 +class LayerCache(nnx.Module): + def __init__(self, cfg: ModelConfig, batch_size: int, cache_size: int, dtype: jnp.dtype): + cache_shape = (batch_size, cache_size, cfg.num_key_value_heads, cfg.head_dim) + self.k_cache = nnx.Cache(jnp.zeros(cache_shape, dtype=dtype)) + self.v_cache = nnx.Cache(jnp.zeros(cache_shape, dtype=dtype)) + self.size = self.k_cache.shape[1] + self.start_ind = nnx.Variable(-1 * jnp.ones((batch_size,), dtype=jnp.int32)) + self.cur_ind = nnx.Variable(jnp.zeros((), dtype=jnp.int32)) # scalar for compute efficiency. + + +Cache: TypeAlias = list[LayerCache] + + +def init_cache( + cfg: ModelConfig, batch_size: int, token_len: int, generate_steps: int, dtype: jnp.dtype = jnp.bfloat16 +) -> Cache: + cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1))) # Pad for a sharding-friendly size. + return [ + LayerCache(cfg.text_config, batch_size, cache_size, dtype) for _ in range(cfg.text_config.num_hidden_layers) + ] + + +class Gemma3RMSNorm(nnx.Module): + def __init__(self, dim: int, eps: float, *, rngs: nnx.Rngs): + self.scale = nnx.Param(nnx.initializers.zeros_init()(rngs.params(), dim)) + self.eps = eps + + @jax.named_scope("rms_norm") + def __call__(self, x: Array) -> Array: + dtype = x.dtype + xf32 = x.astype(jnp.float32) + out = xf32 * jax.lax.rsqrt(jnp.square(xf32).mean(-1, keepdims=True) + self.eps) + out = out * (1.0 + self.scale.value.astype(jnp.float32)) + return out.astype(dtype) + + +class Gemma3TextScaledWordEmbedding(nnx.Module): + def __init__(self, cfg: TextConfig, *, rngs: nnx.Rngs): + self.weight = nnx.Embed(cfg.vocab_size, cfg.hidden_size, rngs=rngs) + self.embed_scale = jnp.array(cfg.hidden_size**0.5, dtype=jnp.bfloat16).astype(jnp.float32) + + def __call__(self, input_ids: Array): + return self.weight(input_ids) * self.embed_scale + + +# below is from qwen3 + + +def _generate_pos_embeddings( + positions: jax.Array, + head_dim: int, + rope_theta: int = 1_000_000, + factor: float = 1.0, +) -> tuple[jax.Array, jax.Array]: + # Forked from: jax-llm-examples/qwen3/qwen3_jax/model.py;l=571 + fraction = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim + timescale = rope_theta**fraction + rotational_frequency = 1.0 / timescale + rotational_frequency /= factor + # Use high-precision einsum to prevent catastrophic bfloat16 rounding (ex: 257→256), as sin(257) differs from sin(256). 
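+    # With these defaults this is standard RoPE: angle[b, t, k] = positions[b, t] / (rope_theta**(2k / head_dim) * factor);
+    # the returned sin/cos are later broadcast over heads in apply_rope.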
+ sinusoid_inp = jnp.einsum("BT,k->BTk", positions, rotational_frequency, precision=jax.lax.Precision.HIGHEST) + return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp) + + +def apply_rope(x: jax.Array, sin: jax.Array, cos: jax.Array) -> jax.Array: + assert x.ndim == 4 and sin.ndim == 3 and cos.ndim == 3 + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + # [B, T, head_dim] -> [B, h, T, head_dim] + sin, cos = sin[:, :, None, :], cos[:, :, None, :] + return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1).astype(x.dtype) + + +def count_left_pads(x: jax.Array) -> int: + """Count left padding tokens.""" + return jnp.sum(jnp.cumsum(x != 0, axis=-1) == 0, -1) + + +def count_right_pads(x: jax.Array, pad_id) -> int: + result = jnp.where( + jnp.all(x == pad_id, axis=1), x.shape[1], jnp.argmin(jnp.flip(x == pad_id, axis=1).astype(jnp.int32), axis=1) + ) + return jnp.max(result) + + +def compute_positions_from_segment_ids(seg_ids: Array): + return jax.vmap(lambda row: jnp.where(row != 0, jnp.arange(seg_ids.shape[1]) - jnp.argmax(row), 2**30))(seg_ids) + + +## Above is from qwen3 + + +def repeat_kv(hidden_states: Array, n_rep: int): + b, t, kv_heads, head_dim = hidden_states.shape + hidden_states = jnp.expand_dims(hidden_states, axis=3) + hidden_states = jnp.repeat(hidden_states, repeats=n_rep, axis=3) + return hidden_states.reshape(b, t, kv_heads * n_rep, head_dim) + + +class Gemma3Attention(nnx.Module): + def __init__(self, config: TextConfig, layer_idx: int, *, rngs: nnx.Rngs): + self.config = config + self.layer_idx = layer_idx + self.use_sliding = config.layer_types[layer_idx] == AttentionMode.SLIDE + self.num_kv_heads = self.config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + shd = config.shd_cfg + self.q_proj = SHLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + use_bias=config.attention_bias, + kernel_sharding=shd.qk, + bias_sharding=shd.qb, + rngs=rngs, + ) + self.k_proj = SHLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + use_bias=config.attention_bias, + kernel_sharding=shd.kk, + bias_sharding=shd.kb, + rngs=rngs, + ) + self.v_proj = SHLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + use_bias=config.attention_bias, + kernel_sharding=shd.vk, + bias_sharding=shd.vb, + rngs=rngs, + ) + self.o_proj = SHLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + use_bias=config.attention_bias, + kernel_sharding=shd.ok, + bias_sharding=shd.ob, + rngs=rngs, + ) + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, rngs=rngs) + self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, rngs=rngs) + + self.rope_theta = config.rope_slide_theta if self.use_sliding else config.rope_full_theta + self.factor = config.rope_slide_factor if self.use_sliding else config.rope_full_factor + + self.n_rep = config.num_attention_heads // config.num_key_value_heads + self.scale = config.head_dim**-0.5 + + def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, mask: Array | None) -> Array: + # get projections + new_shape = (*x.shape[:-1], -1, self.head_dim) + q = self.q_norm(self.q_proj(x).reshape(new_shape)) + k = self.k_norm(self.k_proj(x).reshape(new_shape)) + v = self.v_proj(x).reshape(new_shape) + + # Apply rope + left_pads = count_left_pads(segment_ids) + cache.start_ind.value = jnp.where(cache.start_ind.value < 0, left_pads, 
cache.start_ind.value) + position_ids = compute_positions_from_segment_ids(segment_ids) + cache.cur_ind.value + sin, cos = _generate_pos_embeddings(position_ids, self.head_dim, self.rope_theta, factor=self.factor) + q = apply_rope(q, sin, cos) + k = apply_rope(k, sin, cos) + + # Update cache + slice_indices = (0, cache.cur_ind.value, 0, 0) + cache.k_cache.value = jax.lax.dynamic_update_slice(cache.k_cache.value, k, slice_indices) + cache.v_cache.value = jax.lax.dynamic_update_slice(cache.v_cache.value, v, slice_indices) + + k, v = repeat_kv(cache.k_cache.value, self.n_rep), repeat_kv(cache.v_cache.value, self.n_rep) + qkv = jax.nn.dot_product_attention(q, k, v, is_causal=False, mask=mask, scale=self.scale) + + t = x.shape[1] + cache.cur_ind.value = cache.cur_ind.value + t + return self.o_proj(qkv.reshape(*x.shape[:-1], -1)) + + +class Gemma3MLP(nnx.Module): + def __init__(self, config: TextConfig, *, rngs: nnx.Rngs): + self.config = config + shd = config.shd_cfg + self.gate_proj = SHLinear( + config.hidden_size, + config.intermediate_size, + use_bias=False, + kernel_sharding=shd.gpk, + bias_sharding=shd.gpb, + rngs=rngs, + ) + self.up_proj = SHLinear( + config.hidden_size, + config.intermediate_size, + use_bias=False, + kernel_sharding=shd.upk, + bias_sharding=shd.upb, + rngs=rngs, + ) + self.down_proj = SHLinear( + config.intermediate_size, + config.hidden_size, + use_bias=False, + kernel_sharding=shd.dpk, + bias_sharding=shd.dpb, + rngs=rngs, + ) + + def __call__(self, x: Array): + return self.down_proj(jax.nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + + +class Gemma3DecoderLayer(nnx.Module): + def __init__(self, config: TextConfig, layer_idx: int, *, rngs: nnx.Rngs): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx, rngs=rngs) + self.mlp = Gemma3MLP(config, rngs=rngs) + self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + + def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, mask: Array | None) -> Array: + res = x + x = self.input_layernorm(x) + x = self.self_attn(x, cache, segment_ids, mask=mask) + x = self.post_attention_layernorm(x) + x = res + x + res = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + return x + res + + @property + def head_dim(self): + return self.o_proj.shape[1] + + +class Gemma3TextModel(nnx.Module): + def __init__(self, config: TextConfig, *, rngs: nnx.Rngs): + self.config = config + self.layers = nnx.List( + [Gemma3DecoderLayer(config, layer_idx, rngs=rngs) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + + def __call__(self, x, cache: Cache, segment_ids: Array, sliding_mask: Array | None, causal_mask: Array | None): + for lt, c, layer in zip(self.config.layer_types, cache, self.layers): + mask = sliding_mask if lt == AttentionMode.SLIDE else causal_mask + x = layer(x, c, segment_ids, mask) + return self.norm(x) + + +class Gemma3MultiModalProjector(nnx.Module): + def 
__init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + vhs = config.vision_config.hidden_size + ths = config.text_config.hidden_size + eps = config.vision_config.layer_norm_eps + self.mm_input_projection_weight = nnx.Param(jnp.zeros((vhs, ths)), rngs=rngs) + self.mm_soft_emb_norm = Gemma3RMSNorm(vhs, eps=eps, rngs=rngs) + self.patches_per_img = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_img // self.tokens_per_side + + def __call__(self, vision_outputs: Array) -> Array: + b, _, t = vision_outputs.shape + vision_outputs = vision_outputs.swapaxes(1, 2).reshape(b, t, self.patches_per_img, self.patches_per_img) + # TODO: update this to get rid of the None and 0. + # Might have to write my own avg pool. + x = flax.linen.avg_pool( + vision_outputs[:, :, :, :, None], + window_shape=(1, 1, self.kernel_size, self.kernel_size), + strides=(1, 1, self.kernel_size, self.kernel_size), + )[:, :, :, :, 0] + x = x.reshape(b, t, -1) + x = x.swapaxes(1, 2) + x = self.mm_soft_emb_norm(x) + x = jnp.matmul(x, self.mm_input_projection_weight.value) + return x.astype(vision_outputs.dtype) + + +def make_causal_mask(layer_cache: LayerCache, token_type_ids: Array): + b, t = token_type_ids.shape + c = layer_cache.size + seq_arange = jnp.arange(t) + cache_arange = jnp.arange(c) + causal_mask = seq_arange[:, None] - cache_arange[None, :] >= -layer_cache.cur_ind + tti = token_type_ids.astype(jnp.bool_) + cache_padded_tti = jnp.concat([tti, jnp.zeros((b, c - t), dtype=jnp.bool_)], axis=-1) + image_or_mask = tti[:, None, :, None] & cache_padded_tti[:, None, None, :] + causal_mask = causal_mask.astype(jnp.bool_) | image_or_mask + return causal_mask + + +def make_window_mask(layer_cache, token_type_ids, slide_size: int): + causal_mask = make_causal_mask(layer_cache, token_type_ids) + *_, t, c = causal_mask.shape + seq_arange = jnp.arange(t) + cache_arange = jnp.arange(c) + slide = seq_arange[:, None] - cache_arange[None, :] < slide_size + return causal_mask & slide + + +def merge_modalities(img_emb: Array, text_emb: Array, token_mask: Array) -> Array: + # This function fills the image tokens into the text_emb sequence + # The token_mask tells us where the image tokens are (0 for text, 1 for image) + # image_emb is (Li, D) + # text_emb is (Lt, D) + # token_mask is (Lt) + # We have Li < Lt + img_indices = jnp.cumsum(token_mask) - 1 + safe_indices = jnp.clip(img_indices, 0, img_emb.shape[0] - 1) + aligned_images = img_emb[safe_indices] + return jnp.where(token_mask[:, None], aligned_images, text_emb) + + +def batched_merge_modalities(img_emb: Array, text_emb: Array, token_mask: Array) -> Array: + # image_emb is (B, Li, D) + # text_emb is (B, Lt, D) + # token_mask is (B, Lt) + # We have Li < Lt + return jax.vmap(merge_modalities)(img_emb, text_emb, token_mask) + + +class Gemma3Model(nnx.Module): + def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs): + self.sliding_window_size = cfg.text_config.sliding_window + self.embed_tokens = Gemma3TextScaledWordEmbedding(cfg.text_config, rngs=rngs) + self.vision_tower = SiglipVisionTransformer(cfg.vision_config, rngs=rngs) + self.multi_modal_projector = Gemma3MultiModalProjector(cfg, rngs=rngs) + self.language_model = Gemma3TextModel(cfg.text_config, rngs=rngs) + + def __call__( + self, input_ids: Array, pixel_values: Array, cache: Cache, segment_ids: Array, token_type_ids: Array + ) -> Array: + assert input_ids.shape == 
token_type_ids.shape + causal_mask = make_causal_mask(cache[0], token_type_ids) + sliding_mask = make_window_mask(cache[0], token_type_ids, slide_size=self.sliding_window_size) + + inputs_embeds = self.embed_tokens(input_ids) + + # Merge text and images + if pixel_values is not None: + vision_outputs = self.vision_tower(pixel_values) + image_features = self.multi_modal_projector(vision_outputs) + + image_features = image_features.astype(inputs_embeds.dtype) + inputs_embeds = batched_merge_modalities(image_features, inputs_embeds, token_type_ids) + + out = self.language_model(inputs_embeds, cache, segment_ids, sliding_mask, causal_mask) + out = self.embed_tokens.weight.attend(out) + return out + + +@jax.jit +def forward( + model: nnx.Module, cache: Cache, input_ids: Array, pixel_values: Array, segment_ids: Array, token_type_ids +) -> tuple[Array, nnx.Cache]: + logits = model(input_ids, pixel_values, cache, segment_ids, token_type_ids) + return logits[:, -1, :], cache diff --git a/bonsai/models/gemma3/params.py b/bonsai/models/gemma3/params.py new file mode 100644 index 0000000..8b37fdc --- /dev/null +++ b/bonsai/models/gemma3/params.py @@ -0,0 +1,266 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Parameter helpers for bonsai.models.gemma3. + +Add functions to load or convert pretrained checkpoints and to return +default configuration values used by the model implementation. +""" + +import logging +import re +from enum import Enum + +import jax +import safetensors.flax as safetensors +from etils import epath +from flax import nnx + +from bonsai.models.gemma3 import modeling as model_lib + + +class Transform(Enum): + """ + Specifies default transformation types for model parameter names. + """ + + DEFAULT = None + BIAS = None + LINEAR = ((1, 0), None) + CONV2D = ((2, 3, 1, 0), None) + EMBED = None + + +def _get_key_and_transform_mapping(): + # Mapping st_keys -> (nnx_keys, (permute_rule, reshape_rule)). 
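+    # For example, "language_model.model.layers.0.mlp.up_proj.weight" maps to
+    # "language_model.layers.0.mlp.up_proj.kernel" with Transform.LINEAR, i.e. a (1, 0) transpose and no reshape.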
+ return { + r"^language_model\.model\.embed_tokens\.weight$": ( + r"embed_tokens\.weight\.embedding", + Transform.EMBED, + ), + r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight$": ( + r"language_model\.layers\.\1\.input_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight$": ( + r"language_model\.layers\.\1\.mlp\.down_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight$": ( + r"language_model\.layers\.\1\.mlp\.gate_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight$": ( + r"language_model\.layers\.\1\.mlp\.up_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight$": ( + r"language_model\.layers\.\1\.post_attention_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.post_feedforward_layernorm\.weight$": ( + r"language_model\.layers\.\1\.post_feedforward_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.pre_feedforward_layernorm\.weight$": ( + r"language_model\.layers\.\1\.pre_feedforward_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.k_norm\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.k_norm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.k_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.o_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.q_norm\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.q_norm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.q_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.v_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.norm\.weight$": (r"language_model\.norm\.scale", Transform.DEFAULT), + r"^multi_modal_projector\.mm_input_projection_weight$": ( + r"multi_modal_projector\.mm_input_projection_weight", + Transform.DEFAULT, + ), + r"^multi_modal_projector\.mm_soft_emb_norm\.weight$": ( + r"multi_modal_projector\.mm_soft_emb_norm\.scale", + Transform.DEFAULT, + ), + r"^vision_tower\.vision_model\.embeddings\.patch_embedding\.bias$": ( + r"vision_tower\.embeddings\.patch_embedding\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.embeddings\.patch_embedding\.weight$": ( + r"vision_tower\.embeddings\.patch_embedding\.kernel", + Transform.CONV2D, + ), + r"^vision_tower\.vision_model\.embeddings\.position_embedding\.weight$": ( + r"vision_tower\.embeddings\.position_embedding\.embedding", + Transform.EMBED, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.layer_norm(\d+)\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.layer_norm\2\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.layer_norm(\d+)\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.layer_norm\2\.scale", + Transform.DEFAULT, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc(\d+)\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.mlp\.fc\2\.bias", + Transform.BIAS, 
+ ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc(\d+)\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.mlp\.fc\2\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.k_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.k_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.out_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.out_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.q_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.q_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.v_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.v_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.post_layernorm\.bias$": ( + r"vision_tower\.post_layernorm\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.post_layernorm\.weight$": ( + r"vision_tower\.post_layernorm\.scale", + Transform.DEFAULT, + ), + } + + +def _st_key_to_jax_key(mapping, source_key): + """Map a safetensors key to exactly one JAX key & transform, else warn/error.""" + subs = [ + (re.sub(pat, repl, source_key), transform) + for pat, (repl, transform) in mapping.items() + if re.match(pat, source_key) + ] + if not subs: + logging.warning(f"No mapping found for key: {source_key!r}") + return None, None + if len(subs) > 1: + keys = [s for s, _ in subs] + raise ValueError(f"Multiple mappings found for {source_key!r}: {keys}") + return subs[0] + + +def _assign_weights(keys, tensor, state_dict, st_key, transform): + """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor.""" + key, *rest = keys + if not rest: + if transform is not None: + permute, reshape = transform + if permute: + tensor = tensor.transpose(permute) + if reshape: + tensor = tensor.reshape(reshape) + if tensor.shape != state_dict[key].shape: + raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}") + state_dict[key] = tensor + else: + _assign_weights(rest, tensor, state_dict[key], st_key, transform) + + +def _stoi(s): + try: + return int(s) + except ValueError: + return s + + +# TODO: Update to include sharding +def create_gemma3_from_pretrained( + file_dir: str, + cfg: model_lib.ModelConfig, + *, + mesh: jax.sharding.Mesh | None = None, +): + """ + Load safetensor weights from a file, then convert & merge into a flax.nnx ViT model. + + Returns: + A flax.nnx.Model instance with loaded parameters. 
+ """ + files = list(epath.Path(file_dir).expanduser().glob("*.safetensors")) + if not files: + raise ValueError(f"No safetensors found in {file_dir}") + + tensor_dict = {} + for f in files: + tensor_dict |= safetensors.load_file(f) + + gemma3 = model_lib.Gemma3Model(cfg, rngs=nnx.Rngs(0)) + graph_def, abs_state = nnx.split(gemma3) + jax_state = abs_state.to_pure_dict() + + mapping = _get_key_and_transform_mapping() + for st_key, tensor in tensor_dict.items(): + jax_key, transform = _st_key_to_jax_key(mapping, st_key) + if jax_key is None: + continue + keys = [_stoi(k) for k in jax_key.split(r"\.")] + try: + _assign_weights(keys, tensor, jax_state, st_key, transform.value) + except KeyError as e: + print(f"Key error: {keys} at {e}") + except ValueError as e: + print(e) + except Exception as e: + print(keys) + raise e + + return nnx.merge(graph_def, jax_state) diff --git a/bonsai/models/gemma3/tests/run_model.py b/bonsai/models/gemma3/tests/run_model.py new file mode 100644 index 0000000..8646db4 --- /dev/null +++ b/bonsai/models/gemma3/tests/run_model.py @@ -0,0 +1,115 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run a small inference example for Gemma3.""" + +import os +import time + +import jax +import jax.numpy as jnp +import numpy as np +import torch +import tqdm +from huggingface_hub import snapshot_download +from transformers import Gemma3Processor + +from bonsai.models.gemma3 import modeling, params +from bonsai.utils import Sampler + + +def make_input(processor, dtype=torch.float32, msg1=False): + if msg1: + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", + }, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + else: + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ] + + t_inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + t_inputs["pixel_values"] = t_inputs["pixel_values"].to(dtype=dtype) + + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + + return n_text, n_img, n_tti + + +def run_model(): + model_name: str = "google/gemma-3-4b-it" + cfg = modeling.ModelConfig() + access_token = os.environ["HF_TOKEN"] + processor = Gemma3Processor.from_pretrained(model_name, token=access_token, use_fast=False) + model_ckpt_path = 
snapshot_download(model_name) + bonsai_config = modeling.ModelConfig.gemma3_4b() + bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path, bonsai_config) + eot_token_id = processor.tokenizer.convert_tokens_to_ids("") + + # Make inputs + n_text, n_img, n_tti = make_input(processor) + + gen_steps = 500 + batch_size, num_tokens = n_text.shape + cache = modeling.init_cache(cfg, batch_size, num_tokens, gen_steps, jnp.float32) + + source_key = jax.random.key(0) + sampler = jax.jit(Sampler(temperature=1.0, top_p=0.8, top_k=10)) + + all_tokens = [n_text] + for _ in tqdm.trange(gen_steps): + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens)) + out, cache = modeling.forward(bonsai_model, cache, n_text, n_img, segment_ids, n_tti) + + source_key, key = jax.random.split(source_key) + n_text = sampler(out, key=key) + if jnp.all(n_text == eot_token_id): + print("Hit end of token.") + break + all_tokens.append(n_text) + n_tti = jnp.zeros((batch_size, 1), dtype=jnp.int32) + n_img = None + + full_tokens = torch.tensor(jnp.concat(all_tokens, axis=1)) + out_tokens = processor.decode(full_tokens[0], skip_special_tokens=True) + print(out_tokens) + + +if __name__ == "__main__": + run_model() diff --git a/bonsai/models/gemma3/tests/test_outputs_gemma3.py b/bonsai/models/gemma3/tests/test_outputs_gemma3.py new file mode 100644 index 0000000..12bb573 --- /dev/null +++ b/bonsai/models/gemma3/tests/test_outputs_gemma3.py @@ -0,0 +1,720 @@ +import os +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +import torch +from absl.testing import absltest +from huggingface_hub import snapshot_download +from jax.sharding import AxisType +from jax.typing import DTypeLike +from tqdm import trange +from transformers import AutoModel, AutoProcessor, AutoTokenizer, Gemma3Model, SiglipModel +from transformers.cache_utils import DynamicCache +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask +from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3ForConditionalGeneration +from transformers.models.gemma3.modeling_gemma3 import token_type_ids_mask_function + +from bonsai.models.gemma3 import modeling, params + +SKIP_DONE: bool = True + + +class TestModuleForwardPasses(absltest.TestCase): + # using this for faster testing. This way we can avoid reloading the model. + # Make sure not to modify the Gemma3 model in inconsistent ways between tests. 
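+    # setUpClass loads the Hugging Face reference model and the Bonsai/JAX model once; each test then
+    # compares the outputs of matching submodules numerically via np.testing / torch.testing assertions.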
+ @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_name: str = "google/gemma-3-4b-it" + # self.model_name: str = "google/gemma-3-270m" # This is text only + access_token = os.environ["HF_TOKEN"] + cls.processor = AutoProcessor.from_pretrained(cls.model_name, token=access_token, use_fast=False) + cls.torch_device = "cpu" + cls.mesh = jax.make_mesh(((1, 1)), ("fsdp", "tp"), axis_types=(AxisType.Explicit, AxisType.Explicit)) + jax.set_mesh(cls.mesh) + + ## models + cls.torch_model = ( + Gemma3ForConditionalGeneration.from_pretrained(cls.model_name, dtype="auto") + .to(device=cls.torch_device, dtype=torch.float32) + .eval() + ) + cls.torch_config = cls.torch_model.config + + cls.bonsai_config = modeling.ModelConfig.gemma3_4b() + model_ckpt_path = snapshot_download(cls.model_name) + cls.bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path, cls.bonsai_config) + + cls.batch_size = 1 + cls.cache_size, cls.gen_steps = 512, 10 + + def _upgrade_dtypes(self): + self.bonsai_model.embed_tokens.weight.embedding.value = ( + self.bonsai_model.embed_tokens.weight.embedding.value.astype(jnp.float32) + ) + return + + def _make_torch_input(self): + # returns model inputs: + # KEY SHAPE DTYPE + # input_ids torch.Size([1, 281]) int64 + # attention_mask torch.Size([1, 281]) int64 + # token_type_ids torch.Size([1, 281]) int64 + # pixel_values torch.Size([1, 3, 896, 896]) bfloat16 -> float32 + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ] + + out = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + out["pixel_values"] = out["pixel_values"].to(dtype=torch.float32) + + return {k: v.to(device=self.torch_device) for k, v in out.items()} + + def _make_bonsai_input(self, torch_inputs): + out = dict() + for k, v in torch_inputs.items(): + tmp = v.detach().cpu().numpy() + if k == "pixel_values": + tmp = np.permute_dims(tmp, (0, 2, 3, 1)) + out[k] = tmp + return out + + # This should be correct for unbatched inputs + def _process_torch_inputs( + self, + input_ids=None, + pixel_values=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + token_type_ids=None, + cache_position=None, + inputs_embeds=None, + labels=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + **lm_kwargs, + ): + # Replace image id with PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.torch_config.image_token_id >= self.torch_config.text_config.vocab_size: + special_image_mask = input_ids == self.torch_config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.torch_model.model.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # Merge text and images + if pixel_values is not None: + image_features = self.torch_model.model.get_image_features(pixel_values) + 
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.torch_model.model.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.torch_config.get_text_config(), + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized + # (e.g. compiled prefill) AND `pixel_values` are not provided. Determining prefill in that case requires + # checking data values, which is not compile-compatible. + is_prefill = ( + not use_cache + or past_key_values is None + or not past_key_values.is_initialized + or pixel_values is not None + ) + if token_type_ids is not None and is_prefill: + # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + + # First find where a new image block starts: 1 if image and previous not image + # The images cannot attend to future images, but can attend to all prev images and to itself + # bidirectionally + is_image = (token_type_ids == 1).to(cache_position.device) + new_image_start = is_image & ~torch.nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] + image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 + image_group_ids = torch.where( + is_image, image_group_ids, torch.full_like(token_type_ids, -1, device=is_image.device) + ) + mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + token_type_ids.to(cache_position.device), image_group_ids, self.torch_config.mm_tokens_per_image + ) + + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + return dict( + attention_mask=causal_mask_mapping, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + # This should be correct for unbatched inputs + def _process_torch_inputs_for_decoder_text_model( + self, + attn_type, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + **kwargs, + ): + training = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None and not training: + past_key_values = DynamicCache(config=self.torch_model.model.config.text_config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + sliding_mask_kwargs = mask_kwargs.copy() + + # if self.config.use_bidirectional_attention: + # mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool) + # sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window) + + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + position_embeddings_global = self.torch_model.model.language_model.rotary_emb(inputs_embeds, position_ids) + position_embeddings_local = self.torch_model.model.language_model.rotary_emb_local(inputs_embeds, position_ids) + return dict( + hidden_states=inputs_embeds, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask_mapping[attn_type], + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + def _init_nnx_cache(self, batch_size: int, token_len: int, generate_steps: int, dtype): + return modeling.init_cache( + cfg=self.bonsai_config, + batch_size=batch_size, + token_len=token_len, + generate_steps=generate_steps, + dtype=dtype, + ) + + # Vision tests + @unittest.skipIf(SKIP_DONE, "Done") + def test_image_emb(self): + tm = self.torch_model.model.vision_tower.vision_model.embeddings + nm = self.bonsai_model.vision_tower.embeddings + + t_inputs = self._make_torch_input() + n_inputs = self._make_bonsai_input(t_inputs) + tx = t_inputs["pixel_values"] + nx = n_inputs["pixel_values"] + + with torch.no_grad(): + ty = tm(tx) + ny = nm(nx) + + # (1, 4096, 1152) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5) + + @unittest.skipIf(SKIP_DONE, "Done") + def test_siglip_encoder_layer(self): + tm = self.torch_model.model.vision_tower.vision_model.encoder.layers[0] + nm = self.bonsai_model.vision_tower.encoder.layers[0] + + tx = torch.randn((1, 4096, 1152), device=self.torch_device) + nx = tx.detach().cpu().numpy() + + with torch.no_grad(): + ty = tm(tx, None) + ny = nm(nx, None) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4) + + @unittest.skipIf(SKIP_DONE, "Done") + def test_vision_model(self): + # only have deviations on .0567% of the entries and on order 7e-3 + tm = self.torch_model.model.vision_tower + nm = self.bonsai_model.vision_tower + + t_inputs = self._make_torch_input() + n_inputs = self._make_bonsai_input(t_inputs) + tx = t_inputs["pixel_values"] + nx = n_inputs["pixel_values"] + + with torch.no_grad(): + ty = tm(tx).last_hidden_state + ny = nm(nx) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-2, atol=1e-2) + + # Language tests + @unittest.skipIf(SKIP_DONE, "Done") + def test_text_embedding(self): + self._upgrade_dtypes() + tm = self.torch_model.model.language_model.embed_tokens + nm = self.bonsai_model.embed_tokens + + torch.testing.assert_close(torch.tensor(nm.weight.embedding.value), tm.weight.cpu()) + torch.testing.assert_close(torch.tensor(nm.embed_scale), tm.embed_scale.cpu()) + + t_inputs = 
self._make_torch_input() + n_inputs = self._make_bonsai_input(t_inputs) + tx = t_inputs["input_ids"] + nx = n_inputs["input_ids"] + + with torch.no_grad(): + ty = tm(tx) + ny = nm(nx) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy()) + + @unittest.skipIf(SKIP_DONE, "Done") + def test_attn_projs(self): + tm = self.torch_model.model.language_model.layers[0].self_attn + nm = self.bonsai_model.language_model.layers[0].self_attn + + tx = torch.randn((1, 281, 2560), device=self.torch_device) + nx = tx.detach().cpu().numpy() + + ty = tm.q_proj(tx) + ny = nm.q_proj(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="q") + + ty = tm.k_proj(tx) + ny = nm.k_proj(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="k") + + ty = tm.v_proj(tx) + ny = nm.v_proj(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="v") + + tx = torch.randn((1, 281, 2048), device=self.torch_device) + nx = tx.detach().cpu().numpy() + ty = tm.o_proj(tx) + ny = nm.o_proj(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="o") + + @unittest.skipIf(SKIP_DONE, "Done") + def test_attn_norms(self): + tm = self.torch_model.model.language_model.layers[0].self_attn + nm = self.bonsai_model.language_model.layers[0].self_attn + + tx = torch.randn((1, 281, 2048), device=self.torch_device).reshape(1, 281, -1, 256) + nx = tx.detach().cpu().numpy() + + np.testing.assert_allclose( + nm.q_norm.scale.value, tm.q_norm.weight.detach().cpu().numpy(), err_msg="q_norm weights" + ) + + ty = tm.q_norm(tx) + ny = nm.q_norm(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5, err_msg="q") + + tx = torch.randn((1, 281, 1024), device=self.torch_device).reshape(1, 281, -1, 256) + nx = tx.detach().cpu().numpy() + + ty = tm.k_norm(tx) + ny = nm.k_norm(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5, err_msg="k") + + @unittest.skipIf(SKIP_DONE, "Done") + def test_sin_cos(self): + batch_size, seq_len, dim = 2, 10, 256 + hidden_states = torch.ones((batch_size, seq_len, dim)) + jp = jnp.stack([jnp.arange(seq_len), jnp.arange(seq_len)]) + + # local uses default + rt = self.bonsai_config.text_config.rope_slide_theta + js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=1.0) + rot_emb = self.torch_model.model.language_model.rotary_emb_local + tc, ts = rot_emb(hidden_states, torch.tensor(jp)) + tc, ts = tc[:, :, : dim // 2], ts[:, :, : dim // 2] + torch.testing.assert_close(torch.tensor(js), ts) + torch.testing.assert_close(torch.tensor(jc), tc) + + # global uses linear + rt = self.bonsai_config.text_config.rope_full_theta + js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=8.0) + rot_emb = self.torch_model.model.language_model.rotary_emb + tc, ts = rot_emb(hidden_states, torch.tensor(jp)) + tc, ts = tc[:, :, : dim // 2], ts[:, :, : dim // 2] + torch.testing.assert_close(torch.tensor(js), ts) + torch.testing.assert_close(torch.tensor(jc), tc) + + @unittest.skipIf(SKIP_DONE, "Done") + def test_text_decoder_layer(self): + first_t_inputs = self._make_torch_input() + start_t_inputs = self._process_torch_inputs(**first_t_inputs) + + for test_layer in trange(34): + # Models + tm = self.torch_model.model.language_model.layers[test_layer] + nm = self.bonsai_model.language_model.layers[test_layer] + attn_type = tm.attention_type + + # Inputs + t_inputs = 
self._process_torch_inputs_for_decoder_text_model(attn_type, **start_t_inputs) + nx = t_inputs["hidden_states"].detach().cpu().numpy() + batch_size, num_tokens, _ = nx.shape + nnx_cache = self._init_nnx_cache( + batch_size=batch_size, token_len=num_tokens, generate_steps=1, dtype=jnp.float32 + ) + n_tti = first_t_inputs["token_type_ids"].detach().cpu().numpy() + + if attn_type == "full_attention": + mask = modeling.make_causal_mask(nnx_cache[test_layer], n_tti) + else: + mask = modeling.make_window_mask(nnx_cache[test_layer], n_tti, 1024) + + # run models + ty = tm(**t_inputs) + ny = nm(nx, nnx_cache[test_layer], jnp.ones((batch_size, num_tokens)), mask=mask) + + t_inputs["hidden_states"] = ty[0] + + found_exception = False + try: + np.testing.assert_allclose( + ny, ty[0].detach().cpu().numpy(), rtol=5e-3, atol=5e-3, err_msg=f"{test_layer}" + ) + except Exception as e: + print(e) + found_exception = True + assert not found_exception, "FOUND EXCEPTION" + + # multi modal tests + + @unittest.skipIf(SKIP_DONE, "Done") + def test_multi_modal_projector(self): + t_inputs = self._make_torch_input() + tm = self.torch_model.model + nm = self.bonsai_model.multi_modal_projector + + tx = tm.vision_tower(t_inputs["pixel_values"]).last_hidden_state + nx = tx.detach().cpu().numpy() + + ty = tm.multi_modal_projector(tx) + ny = nm(nx) + + torch.testing.assert_close(torch.tensor(ny), ty, rtol=1e-4, atol=1e-4) + + @unittest.skipIf(SKIP_DONE, "Done") + def test_text_image_merge(self): + nm = self.bonsai_model + t_inputs = self._make_torch_input() + t_out = self._process_torch_inputs(**t_inputs) + + # answer is input_embeds + t_ans = t_out["inputs_embeds"] + + tmp = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_text = nm.embed_tokens(tmp) + + # return + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_img = nm.vision_tower(n_img) + n_img = nm.multi_modal_projector(n_img) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + + n_ans = modeling.batched_merge_modalities(n_img, n_text, n_tti) + + np.testing.assert_allclose(n_ans, t_ans.detach().cpu().numpy(), rtol=1e-3, atol=1e-3) + + @unittest.skipIf(SKIP_DONE, "Done") + def test_text_layers_in_order(self): + first_t_inputs = self._make_torch_input() + start_t_inputs = self._process_torch_inputs(**first_t_inputs) + ny, ty = None, None + + for test_layer in trange(34): + # Models + tm = self.torch_model.model.language_model.layers[test_layer] + nm = self.bonsai_model.language_model.layers[test_layer] + attn_type = tm.attention_type + + # Inputs + t_inputs = self._process_torch_inputs_for_decoder_text_model(attn_type, **start_t_inputs) + if ty is not None: + t_inputs["hidden_states"] = ty[0] + if ny is None: + nx = t_inputs["hidden_states"].detach().cpu().numpy() + else: + nx = ny + batch_size, num_tokens, _ = nx.shape + nnx_cache = self._init_nnx_cache( + batch_size=batch_size, token_len=num_tokens, generate_steps=1, dtype=jnp.float32 + ) + n_tti = first_t_inputs["token_type_ids"].detach().cpu().numpy() + + if attn_type == "full_attention": + mask = modeling.make_causal_mask(nnx_cache[test_layer], n_tti) + else: + mask = modeling.make_window_mask(nnx_cache[test_layer], n_tti, 1024) + + # run models + ty = tm(**t_inputs) + ny = nm(nx, nnx_cache[test_layer], jnp.ones((batch_size, num_tokens)), mask=mask) + + found_exception = False + try: + np.testing.assert_allclose(ny, ty[0].detach().cpu().numpy(), rtol=1, atol=1, err_msg=f"{test_layer}") + except Exception as e: + print(e) + 
found_exception = True + assert not found_exception, "FOUND EXCEPTION" + + @unittest.skipIf(SKIP_DONE, "Done") + def test_masks(self): + # Make a really long input so we can test the sliding window + # This only tests for the pre-fill stage + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail." + "hello " * 1500}, + ], + }, + ] + + t_inputs = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + t_inputs["pixel_values"] = t_inputs["pixel_values"].to(dtype=torch.float32) + + batch_size, num_tokens = t_inputs["input_ids"].shape + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + gen_steps = 10 + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, gen_steps) + n_mask = modeling.make_causal_mask(cache[0], n_tti) + + # Full attention + t_inputs = self._process_torch_inputs(**t_inputs) + t_mask = t_inputs["attention_mask"]["full_attention"] + size_for_comp = t_mask.shape[-1] + + np.testing.assert_allclose(n_mask[:, :, :, :size_for_comp], t_mask.detach().cpu().numpy()) + + # Sliding attention + t_mask = t_inputs["attention_mask"]["sliding_attention"] + n_mask = modeling.make_window_mask(cache[0], n_tti, self.bonsai_config.text_config.sliding_window) + + np.testing.assert_allclose(n_mask[:, :, :, :size_for_comp], t_mask.detach().cpu().numpy()) + + @unittest.skip("Skipping - this test is just to observe errors over full model evaluation") + def test_full_in_order(self): + tm = self.torch_model.model + nm = self.bonsai_model + + # Torch inputs + t_inputs = self._make_torch_input() + + # NNX inputs + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens)) + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32) + + # Get masks + n_causal_mask = modeling.make_causal_mask(cache[0], n_tti) + n_sliding_mask = modeling.make_window_mask(cache[0], n_tti, self.bonsai_config.text_config.sliding_window) + + # text embeds + t_inputs_embeds = tm.language_model.embed_tokens(t_inputs["input_ids"]) + n_inputs_embeds = nm.embed_tokens(n_text) + np.testing.assert_allclose(n_inputs_embeds, t_inputs_embeds.detach().cpu().numpy(), err_msg="text emb") + + # Vision part + t_vis = tm.vision_tower(t_inputs["pixel_values"]).last_hidden_state + n_vis = nm.vision_tower(n_img) + # Mismatched elements: 4608354 / 4718592 (97.7%) + # Max absolute difference among violations: 0.00756264 + # Max relative difference among violations: 15.521739 + np.testing.assert_allclose(n_vis, t_vis.detach().cpu().numpy(), rtol=1e-3, atol=1e-3, err_msg="vis tower") + + # MM Proj part + t_img_feat = tm.multi_modal_projector(t_vis) + n_img_feat = nm.multi_modal_projector(n_vis) + # Mismatched elements: 648574 / 655360 (99%) + # Max absolute difference among violations: 0.00063944 + # Max relative difference among violations: 20.392141 + np.testing.assert_allclose( + n_img_feat, t_img_feat.detach().cpu().numpy(), rtol=1e-3, atol=1e-3, err_msg="mm proj" + ) + + 
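# Merge check: the projected image features are scattered into the text embeddings at the image placeholder positions; the JAX merge locates them via token_type_ids, while torch derives the same mask with get_placeholder_mask and masked_scatter. + 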
# Merging part + special_image_mask = tm.get_placeholder_mask( + t_inputs["input_ids"], inputs_embeds=t_inputs_embeds, image_features=t_img_feat + ) + t_inputs_embeds = t_inputs_embeds.masked_scatter(special_image_mask, t_img_feat) + n_inputs_embeds = modeling.batched_merge_modalities(n_img_feat, n_inputs_embeds, n_tti) + # Mismatched elements: 648574 / 719360 (90.2%) + # Max absolute difference among violations: 0.00063944 + # Max relative difference among violations: 20.392141 + np.testing.assert_allclose( + n_inputs_embeds, t_inputs_embeds.detach().cpu().numpy(), rtol=1e-3, atol=1e-3, err_msg="merge" + ) + + # NOTE: Text decoder, layer by layer + t_inputs["output_hidden_states"] = True + t_text_inputs = self._process_torch_inputs(**t_inputs) + t_hidden_states = tm.language_model(**t_text_inputs).hidden_states + assert len(t_hidden_states) - 1 == len(nm.language_model.layers), ( + f"{len(t_hidden_states)} vs {len(nm.language_model.layers)}" + ) + + # Start from the merged embeddings + nx = n_inputs_embeds + + n_hidden_states = [] + for i, layer in enumerate(nm.language_model.layers): + attn_type = tm.language_model.layers[i].attention_type + n_mask = n_causal_mask if attn_type == "full_attention" else n_sliding_mask + n_hidden_states.append(nx) + nx = layer(nx, cache[i], segment_ids, n_mask) + nx = nm.language_model.norm(nx) + n_hidden_states.append(nx) + + found_error = False + for i, (nval, tval) in enumerate(zip(n_hidden_states, t_hidden_states)): + try: + np.testing.assert_allclose(nval, tval.detach().cpu().numpy(), err_msg=f"text {i}") + except Exception as e: + print(e) + found_error = True + assert not found_error, "Found errors in text decoder layers" + # NOTE: some mismatches are expected here since numerical error accumulates across layers + + # @unittest.skipIf(SKIP_DONE, "Done") + def test_full(self): + tm = self.torch_model + nm = self.bonsai_model + + t_inputs = self._make_torch_input() + + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens)) + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32) + + ny = nm(n_text, n_img, cache, segment_ids, n_tti) + ty = tm(**t_inputs) + + torch.testing.assert_close(torch.tensor(ny), ty.logits, rtol=5e-2, atol=5e-2) + + @unittest.skip("TODO") + def test_full_batched(self): + tm = self.torch_model + nm = self.bonsai_model + + t_inputs = self._make_torch_input() + + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + + # Test simple batching + n_img = jnp.concat([n_img, n_img]) + n_text = jnp.concat([n_text, n_text]) + n_tti = jnp.concat([n_tti, n_tti]) + + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens)) + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32) + + ny = nm(n_text, n_img, cache, segment_ids, n_tti) + ty = tm(**t_inputs) + + # Each batch entry should reproduce the single-prompt torch logits + torch.testing.assert_close(torch.tensor(ny)[0:1], ty.logits, rtol=5e-2, atol=5e-2) + torch.testing.assert_close(torch.tensor(ny)[1:2], ty.logits, rtol=5e-2, atol=5e-2) + + raise NotImplementedError("Need to get more complex batched inputs working") + # When batching, each prompt can contain a different number of images (>= 0), so batched_merge_modalities must change. + 
# For this we may also need to track which prompt each image came from. + # We also need to update the left padding to handle a different padding length per prompt. + + + if __name__ == "__main__": + absltest.main() diff --git a/bonsai/models/gemma3/tests/test_sharding_gemma3.py b/bonsai/models/gemma3/tests/test_sharding_gemma3.py new file mode 100644 index 0000000..e9405bb --- /dev/null +++ b/bonsai/models/gemma3/tests/test_sharding_gemma3.py @@ -0,0 +1,96 @@ +import os +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +import torch +from absl.testing import absltest +from huggingface_hub import snapshot_download +from jax import P, make_mesh, set_mesh +from jax.sharding import AxisType +from jax.typing import DTypeLike +from tqdm import trange +from transformers import AutoProcessor + +from bonsai.models.gemma3 import modeling, params + +# Use two artificial CPU devices so the sharded model can run without accelerators +jax.config.update("jax_num_cpu_devices", 2) + + +class TestSharding(absltest.TestCase): + # Load the model once in setUpClass so it is not reloaded for every test. + # Make sure tests do not modify the Gemma3 model in inconsistent ways. + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_name: str = "google/gemma-3-4b-it" + # cls.model_name: str = "google/gemma-3-270m" # This model is text only + access_token = os.environ["HF_TOKEN"] + cls.processor = AutoProcessor.from_pretrained(cls.model_name, token=access_token, use_fast=False) + cls.torch_device = "cpu" + + fsdp, tp = modeling.ShardMode.FSDP.value, modeling.ShardMode.TP.value + + cls.mesh = jax.make_mesh((2, 1), (fsdp, tp), axis_types=(AxisType.Explicit, AxisType.Explicit)) + jax.set_mesh(cls.mesh) + + cls.bonsai_config = modeling.ModelConfig.gemma3_4b() + model_ckpt_path = snapshot_download(cls.model_name) + cls.bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path, cls.bonsai_config) + + cls.batch_size = 1 + cls.cache_size, cls.gen_steps = 512, 10 + + def _make_torch_input(self): + # returns model inputs: + # KEY SHAPE DTYPE + # input_ids torch.Size([1, 281]) int64 + # attention_mask torch.Size([1, 281]) int64 + # token_type_ids torch.Size([1, 281]) int64 + # pixel_values torch.Size([1, 3, 896, 896]) bfloat16 -> float32 + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ] + + out = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + out["pixel_values"] = out["pixel_values"].to(dtype=torch.float32) + + return {k: v.to(device=self.torch_device) for k, v in out.items()} + + def test_full(self): + nm = self.bonsai_model + + t_inputs = self._make_torch_input() + + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + + # Test simple batching + n_img = jnp.concat([n_img, n_img]) + n_text = jnp.concat([n_text, n_text]) + n_tti = jnp.concat([n_tti, n_tti]) + + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens)) + cache = modeling.init_cache(self.bonsai_config, 
batch_size, num_tokens, 1, jnp.float32) + + nm(n_text, n_img, cache, segment_ids, n_tti) + + +if __name__ == "__main__": + absltest.main() diff --git a/pyproject.toml b/pyproject.toml index ec325a3..029cd39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ qwen3 = [] resnet50 = [] sam2 = ["pillow>=11.3.0"] vgg19 = ["h5py", "keras_hub", "tensorflow"] +gemma3 = ["sentencepiece"] dev = [ "xprof",