From 5f7d603b64f83bbf0576364216f05479c15b136e Mon Sep 17 00:00:00 2001
From: James Chapman
Date: Tue, 9 Dec 2025 17:10:18 -0800
Subject: [PATCH 1/3] Gemma3 initial commit

---
 bonsai/models/gemma3/README.md            |  37 +
 bonsai/models/gemma3/modeling.py          | 552 ++++++++++++++
 bonsai/models/gemma3/params.py            | 267 +++++++
 bonsai/models/gemma3/tests/run_model.py   | 121 ++++
 .../gemma3/tests/test_outputs_gemma3.py   | 679 ++++++++++++++++++
 pyproject.toml                            |   1 +
 6 files changed, 1657 insertions(+)
 create mode 100644 bonsai/models/gemma3/README.md
 create mode 100644 bonsai/models/gemma3/modeling.py
 create mode 100644 bonsai/models/gemma3/params.py
 create mode 100644 bonsai/models/gemma3/tests/run_model.py
 create mode 100644 bonsai/models/gemma3/tests/test_outputs_gemma3.py

diff --git a/bonsai/models/gemma3/README.md b/bonsai/models/gemma3/README.md
new file mode 100644
index 00000000..8b5cbd04
--- /dev/null
+++ b/bonsai/models/gemma3/README.md
@@ -0,0 +1,37 @@
+# Gemma3 in JAX
+
+This directory contains a pure JAX implementation of the [Gemma3 model](https://deepmind.google/models/gemma/gemma-3/), using the [Flax NNX](https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html) API. Note that you need a Hugging Face access token to download the model weights.
+
+
+**This is currently in progress but passing numerics checks. Working on cleaning up the code and optimizing before the official PR.**
+
+
+## Model Configuration Support Status
+
+
+### Running this model
+
+
+```sh
+python3 -m bonsai.models.gemma3.tests.run_model
+```
+
+
+## How to contribute to this model
+
+### Remaining Tasks
+
+1. Implement KV caching to speed up inference
+2. JIT Compile forward pass
+3. Finish the `run_model.py` example. Add timing and profiling.
+4. Optimize based on the profiling.
+5. Implement sharding.
+6. Get the `lm_head` from the weights.
+7. Update to include other model sizes
+
+
+### Implementation Notes
+The implementation matches the HF one pretty well. To get KV caching working, we have to pad things. We also have to pad the token_type_ids on the right with 0's.
+
+
+
diff --git a/bonsai/models/gemma3/modeling.py b/bonsai/models/gemma3/modeling.py
new file mode 100644
index 00000000..8001dd5f
--- /dev/null
+++ b/bonsai/models/gemma3/modeling.py
@@ -0,0 +1,552 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
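+
+"""Pure JAX / Flax NNX implementation of Gemma3: SigLIP vision tower, multimodal projector, and text decoder."""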
+ +import math +from dataclasses import dataclass, field +from enum import Enum +from functools import partial +from typing import Any, Optional, Tuple, TypeAlias + +import flax +import jax +import jax.numpy as jnp +import jax.sharding as shd +from flax import nnx +from jax.interpreters import pxla +from jaxtyping import Array, Float + +_K_MASK = jax._src.nn.functions._get_large_negative(jax.numpy.float32).item() + + +class AttentionType(Enum): + FULL = "full_attention" + SLIDE = "sliding_attention" + + +def _make_attn_types(): + # Fix this (5x slide 1x full) x 5 + (4x slide) + return [AttentionType.FULL if i % 6 == 5 else AttentionType.SLIDE for i in range(34)] + + +@dataclass +class VisionConfig: + attention_dropout: float = 0.0 + hidden_act: str = "gelu_pytorch_tanh" + hidden_size: int = 1152 + image_size: int = 896 + intermediate_size: int = 4304 + layer_norm_eps: float = 1e-6 + num_attention_heads: int = 16 + num_channels: int = 3 + num_hidden_layers: int = 27 + patch_size: int = 14 + vision_use_head: bool = False + + +@dataclass +class TextConfig: + _sliding_window_pattern: int = 6 + attention_bias: bool = False + attention_dropout: float = 0.0 + attn_logit_softcapping: Optional[float] = None + final_logit_softcapping: Optional[float] = None + head_dim: int = 256 + hidden_activation: str = "gelu_pytorch_tanh" + hidden_size: int = 2560 + initializer_range: float = 0.02 + intermediate_size: int = 10240 + layer_types: list[AttentionType] = field(default_factory=lambda: _make_attn_types()) + max_position_embeddings: int = 131072 + num_attention_heads: int = 8 + num_hidden_layers: int = 34 + num_key_value_heads: int = 4 + query_pre_attn_scalar: int = 256 + rms_norm_eps: float = 1e-6 + rope_local_base_freq: float = 10000.0 + rope_scaling: dict[str, Any] = field(default_factory=lambda: {"factor": 8.0, "rope_type": "linear"}) + rope_theta: float = 1000000.0 + sliding_window: int = 1024 + use_cache: bool = True + vocab_size: int = 262208 + + +@dataclass +class ModelConfig: + vision_config: VisionConfig = field(default_factory=lambda: VisionConfig()) + text_config: TextConfig = field(default_factory=lambda: TextConfig()) + mm_tokens_per_image: int = 256 + boi_token_index: int = 255999 + dtype: str = "bfloat16" + eoi_token_index: int = 256000 + eos_token_id: list[int] = field(default_factory=lambda: [1, 106]) + image_token_index: int = 262144 + initializer_range: float = 0.02 + mm_tokens_per_image: int = 256 + + +## GENERAL + + +## VISION + + +# TODO: update to include interpolate_pos_encoding +class SiglipVisionEmbeddings(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.num_patches = (config.image_size // config.patch_size) ** 2 + self.patch_embedding = nnx.Conv( + config.num_channels, + config.hidden_size, + kernel_size=(config.patch_size,) * 2, + strides=(config.patch_size,) * 2, + padding="valid", + rngs=rngs, + ) + self.position_embedding = nnx.Embed(self.num_patches, config.hidden_size, rngs=rngs) + self.position_ids = jnp.expand_dims(jnp.arange(self.num_patches), 0) + + def __call__(self, pixel_values: Array): + patch_embeds = self.patch_embedding(pixel_values) + b, h, w, c = patch_embeds.shape + embeddings = patch_embeds.reshape((b, h * w, c)) + return embeddings + self.position_embedding(self.position_ids) + + +class SiglipAttention(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // 
config.num_attention_heads + self.k_proj = nnx.Linear(config.hidden_size, config.hidden_size, rngs=rngs) + self.v_proj = nnx.Linear(config.hidden_size, config.hidden_size, rngs=rngs) + self.q_proj = nnx.Linear(config.hidden_size, config.hidden_size, rngs=rngs) + self.out_proj = nnx.Linear(config.hidden_size, config.hidden_size, rngs=rngs) + + def __call__(self, x: Array, attn_mask: Array | None): + batch_size, seq_length, _ = x.shape + shape = (batch_size, seq_length, self.num_heads, self.head_dim) + q = self.q_proj(x).reshape(shape) + k = self.k_proj(x).reshape(shape) + v = self.v_proj(x).reshape(shape) + + attn = jax.nn.dot_product_attention(q, k, v, mask=attn_mask).reshape(x.shape) + return self.out_proj(attn) + + +class SiglipMLP(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.act = jax.nn.gelu + self.fc1 = nnx.Linear(config.hidden_size, config.intermediate_size, rngs=rngs) + self.fc2 = nnx.Linear(config.intermediate_size, config.hidden_size, rngs=rngs) + + def __call__(self, x: Array): + return self.fc2(self.act(self.fc1(x))) + + +class SiglipEncoderLayer(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.layer_norm1 = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs) + self.self_attn = SiglipAttention(config, rngs=rngs) + self.layer_norm2 = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs) + self.mlp = SiglipMLP(config, rngs=rngs) + + def __call__(self, x: Array, attn_mask: Array | None): + hidden = self.layer_norm1(x) + hidden = self.self_attn(hidden, attn_mask) + hidden = x + hidden + x = hidden + hidden = self.layer_norm2(hidden) + hidden = self.mlp(hidden) + return hidden + x + + +class SiglipEncoder(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.layers = nnx.List([SiglipEncoderLayer(config, rngs=rngs) for _ in range(config.num_hidden_layers)]) + + def __call__(self, x: Array, attn_mask: Array | None): + for l in self.layers: + x = l(x, attn_mask) + return x + + +# TODO: Skip for now since not in 4b, but test later +class SiglipMultiheadAttentionPoolingHead(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.probe = nnx.Param(nnx.initializers.normal(stddev=0.02)(rngs.params(), (1, 1, config.hidden_size))) + self.attention = nnx.MultiHeadAttention(config.num_attention_heads, config.hidden_size, rngs=rngs) + self.layernorm = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs) + self.mlp = SiglipMLP(config, rngs=rngs) + + def __call__(self, *args, **kwargs): + raise NotImplementedError("Not yet implemented") + + +class SiglipVisionTransformer(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + self.embeddings = SiglipVisionEmbeddings(config, rngs=rngs) + self.encoder = SiglipEncoder(config, rngs=rngs) + self.post_layernorm = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead(config) + + def __call__(self, pixel_values: Array): + x = self.embeddings(pixel_values) + x = self.encoder(x, attn_mask=None) + x = self.post_layernorm(x) + if self.use_head: + x = self.head(x) + return x + + +## LANGUAGE + + +# from qwen3 +class LayerCache(nnx.Module): + def 
__init__(self, cfg: ModelConfig, batch_size: int, cache_size: int, dtype: jnp.dtype): + cache_shape = (batch_size, cache_size, cfg.num_key_value_heads, cfg.head_dim) + self.k_cache = nnx.Cache(jnp.zeros(cache_shape, dtype=dtype)) + self.v_cache = nnx.Cache(jnp.zeros(cache_shape, dtype=dtype)) + self.size = self.k_cache.shape[1] + self.start_ind = nnx.Variable(-1 * jnp.ones((batch_size,), dtype=jnp.int32)) + self.cur_ind = nnx.Variable(jnp.zeros((), dtype=jnp.int32)) # scalar for compute efficiency. + + +def shard(x: jnp.ndarray, s: Tuple[str, ...]): + mesh = pxla.thread_resources.env.physical_mesh + if mesh.empty or jax.devices()[0].platform == "cpu": + return x + return jax.lax.with_sharding_constraint(x, shd.NamedSharding(mesh, shd.PartitionSpec(*s))) + + +Cache: TypeAlias = list[LayerCache] + + +def init_cache( + cfg: ModelConfig, batch_size: int, token_len: int, generate_steps: int, dtype: jnp.dtype = jnp.bfloat16 +) -> Cache: + cache_size = 2 ** math.ceil(math.log2(max(token_len + generate_steps, 1))) # Pad for a sharding-friendly size. + return [ + LayerCache(cfg.text_config, batch_size, cache_size, dtype) for _ in range(cfg.text_config.num_hidden_layers) + ] + + +class Gemma3RMSNorm(nnx.Module): + def __init__(self, dim: int, eps: float, *, rngs: nnx.Rngs): + self.scale = nnx.Param(nnx.initializers.zeros_init()(rngs.params(), dim)) + self.eps = eps + + @jax.named_scope("rms_norm") + def __call__(self, x: Array) -> Array: + dtype = x.dtype + xf32 = x.astype(jnp.float32) + out = xf32 * jax.lax.rsqrt(jnp.square(xf32).mean(-1, keepdims=True) + self.eps) + out = out * (1.0 + self.scale.value.astype(jnp.float32)) + return out.astype(dtype) + + +class Gemma3TextScaledWordEmbedding(nnx.Module): + def __init__( + self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0, *, rngs: nnx.Rngs + ): + self.weight = nnx.Embed(num_embeddings, embedding_dim, rngs=rngs) + self.embed_scale = jnp.array(embed_scale, dtype=jnp.bfloat16).astype(jnp.float32) + + def __call__(self, input_ids: Array): + return self.weight(input_ids) * self.embed_scale + + +# below is from qwen3 + + +def _generate_pos_embeddings( + positions: jax.Array, + head_dim: int, + rope_theta: int = 1_000_000, + factor: float = 1.0, +) -> tuple[jax.Array, jax.Array]: + # Forked from: jax-llm-examples/qwen3/qwen3_jax/model.py;l=571 + fraction = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim + timescale = rope_theta**fraction + rotational_frequency = 1.0 / timescale + rotational_frequency /= factor + # Use high-precision einsum to prevent catastrophic bfloat16 rounding (ex: 257→256), as sin(257) differs from sin(256). 
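+    # At this point rotational_frequency[k] = rope_theta**(-2k / head_dim) / factor, i.e. linear
+    # ("position interpolation") RoPE scaling: dividing the frequencies by `factor` stretches the
+    # usable position range by that factor (Gemma3Attention passes factor=8.0 for global-attention
+    # layers and 1.0 for sliding-window layers, matching TextConfig.rope_scaling).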
+ sinusoid_inp = jnp.einsum("BT,k->BTk", positions, rotational_frequency, precision=jax.lax.Precision.HIGHEST) + return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp) + + +def apply_rope(x: jax.Array, sin: jax.Array, cos: jax.Array) -> jax.Array: + assert x.ndim == 4 and sin.ndim == 3 and cos.ndim == 3 + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + # [B, T, head_dim] -> [B, h, T, head_dim] + sin, cos = sin[:, :, None, :], cos[:, :, None, :] + return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1).astype(x.dtype) + + +def count_left_pads(x: jax.Array) -> int: + """Count left padding tokens.""" + return jnp.sum(jnp.cumsum(x != 0, axis=-1) == 0, -1) + + +def count_right_pads(x: jax.Array, pad_id) -> int: + result = jnp.where( + jnp.all(x == pad_id, axis=1), x.shape[1], jnp.argmin(jnp.flip(x == pad_id, axis=1).astype(jnp.int32), axis=1) + ) + return jnp.max(result) + + +def compute_positions_from_segment_ids(seg_ids: Array): + return jax.vmap(lambda row: jnp.where(row != 0, jnp.arange(seg_ids.shape[1]) - jnp.argmax(row), 2**30))(seg_ids) + + +## Above is from qwen3 + + +def repeat_kv(hidden_states: Array, n_rep: int): + b, t, kv_heads, head_dim = hidden_states.shape + hidden_states = jnp.expand_dims(hidden_states, axis=3) + hidden_states = jnp.repeat(hidden_states, repeats=n_rep, axis=3) + return hidden_states.reshape(b, t, kv_heads * n_rep, head_dim) + + +class Gemma3Attention(nnx.Module): + def __init__(self, config: TextConfig, layer_idx: int, *, rngs: nnx.Rngs): + self.config = config + self.layer_idx = layer_idx + self.use_sliding = config.layer_types[layer_idx] == AttentionType.SLIDE + self.num_kv_heads = self.config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.q_proj = nnx.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, use_bias=config.attention_bias, rngs=rngs + ) + self.k_proj = nnx.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, use_bias=config.attention_bias, rngs=rngs + ) + self.v_proj = nnx.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, use_bias=config.attention_bias, rngs=rngs + ) + self.o_proj = nnx.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, use_bias=config.attention_bias, rngs=rngs + ) + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, rngs=rngs) + self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, rngs=rngs) + + self.rope_theta = config.rope_local_base_freq if self.use_sliding else config.rope_theta + self.factor = 1.0 if self.use_sliding else 8.0 + + self.n_rep = config.num_attention_heads // config.num_key_value_heads + self.scale = config.head_dim**-0.5 + + def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, mask: Array | None) -> Array: + # get projections + new_shape = (*x.shape[:-1], -1, self.head_dim) + q = self.q_norm(self.q_proj(x).reshape(new_shape)) + k = self.k_norm(self.k_proj(x).reshape(new_shape)) + v = self.v_proj(x).reshape(new_shape) + + # Apply rope + left_pads = count_left_pads(segment_ids) + cache.start_ind.value = jnp.where(cache.start_ind.value < 0, left_pads, cache.start_ind.value) + position_ids = compute_positions_from_segment_ids(segment_ids) + cache.cur_ind.value + sin, cos = _generate_pos_embeddings(position_ids, self.head_dim, self.rope_theta, factor=self.factor) + q = apply_rope(q, sin, cos) + k = apply_rope(k, sin, cos) + + # Update cache + 
slice_indices = (0, cache.cur_ind.value, 0, 0) + cache.v_cache.value = jax.lax.dynamic_update_slice(cache.v_cache.value, v, slice_indices) + cache.k_cache.value = jax.lax.dynamic_update_slice(cache.k_cache.value, k, slice_indices) + t = q.shape[1] + cache.cur_ind.value += x.shape[1] + + # TODO: Need to do this with the kv cache next + k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) + qkv = jax.nn.dot_product_attention(q, k, v, is_causal=False, mask=mask[:, :, :, :t], scale=self.scale) + # k, v = repeat_kv(cache.k_cache.value, self.n_rep), repeat_kv(cache.v_cache.value, self.n_rep) + # qkv = jax.nn.dot_product_attention(q, k, v, is_causal=False, mask=mask[:, :, :, :t], scale=self.scale) + + cache.cur_ind.value = cache.cur_ind.value + t + return self.o_proj(qkv.reshape(*x.shape[:-1], -1)) + + +class Gemma3MLP(nnx.Module): + def __init__(self, config: TextConfig, *, rngs: nnx.Rngs): + self.config = config + self.gate_proj = nnx.Linear(config.hidden_size, config.intermediate_size, use_bias=False, rngs=rngs) + self.up_proj = nnx.Linear(config.hidden_size, config.intermediate_size, use_bias=False, rngs=rngs) + self.down_proj = nnx.Linear(config.intermediate_size, config.hidden_size, use_bias=False, rngs=rngs) + + def __call__(self, x: Array): + return self.down_proj(jax.nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + + +class Gemma3DecoderLayer(nnx.Module): + def __init__(self, config: TextConfig, layer_idx: int, *, rngs: nnx.Rngs): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx, rngs=rngs) + self.mlp = Gemma3MLP(config, rngs=rngs) + self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + + def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, mask: Array | None) -> Array: + res = x + x = self.input_layernorm(x) + x = self.self_attn(x, cache, segment_ids, mask=mask) + x = self.post_attention_layernorm(x) + x = res + x + res = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + return x + res + + @property + def head_dim(self): + return self.o_proj.shape[1] + + +class Gemma3TextModel(nnx.Module): + def __init__(self, config: TextConfig, *, rngs: nnx.Rngs): + self.config = config + # TODO: Move this out of this class into the larger class + self.embed_tokens = Gemma3TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + "self.padding_idx", + embed_scale=self.config.hidden_size**0.5, + rngs=rngs, + ) + self.layers = nnx.List( + [Gemma3DecoderLayer(config, layer_idx, rngs=rngs) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps, rngs=rngs) + + def __call__(self, x, cache: Cache, segment_ids: Array, sliding_mask: Array | None, causal_mask: Array | None): + # x = self.embed_tokens(x) # done in higher layer now + for lt, c, layer in zip(self.config.layer_types, cache, self.layers): + mask = sliding_mask if lt == AttentionType.SLIDE else causal_mask + x = layer(x, c, segment_ids, mask) + return self.norm(x) + 
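+
+# Usage sketch for the text stack above (illustrative only; `input_ids` and `token_type_ids` are
+# assumed to be int32 arrays of shape [B, T], with token_type_ids all zeros for a text-only prompt):
+#
+#   cfg = ModelConfig()
+#   lm = Gemma3TextModel(cfg.text_config, rngs=nnx.Rngs(0))
+#   cache = init_cache(cfg, batch_size=input_ids.shape[0], token_len=input_ids.shape[1],
+#                      generate_steps=16, dtype=jnp.float32)
+#   x = lm.embed_tokens(input_ids)  # embeddings are computed by the caller (see Gemma3Model below)
+#   segment_ids = jnp.ones(input_ids.shape)
+#   causal = make_causal_mask(input_ids.shape[0], input_ids.shape[1], token_type_ids)   # helpers defined below
+#   sliding = make_window_mask(input_ids.shape[0], input_ids.shape[1], token_type_ids,
+#                              cfg.text_config.sliding_window)
+#   hidden = lm(x, cache, segment_ids, sliding, causal)  # [B, T, hidden_size]; lm_head is applied externally
+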
+ +class Gemma3MultiModalProjector(nnx.Module): + def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): + self.config = config + vhs = config.vision_config.hidden_size + ths = config.text_config.hidden_size + eps = config.vision_config.layer_norm_eps + self.mm_input_projection_weight = nnx.Param(jnp.zeros((vhs, ths)), rngs=rngs) + self.mm_soft_emb_norm = Gemma3RMSNorm(vhs, eps=eps, rngs=rngs) + self.patches_per_img = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_img // self.tokens_per_side + + def __call__(self, vision_outputs: Array) -> Array: + b, _, t = vision_outputs.shape + vision_outputs = vision_outputs.swapaxes(1, 2).reshape(b, t, self.patches_per_img, self.patches_per_img) + # TODO: update this to get rid of the None and 0. + # Might have to write my own avg pool. + x = flax.linen.avg_pool( + vision_outputs[:, :, :, :, None], + window_shape=(1, 1, self.kernel_size, self.kernel_size), + strides=(1, 1, self.kernel_size, self.kernel_size), + )[:, :, :, :, 0] + x = x.reshape(b, t, -1) + x = x.swapaxes(1, 2) + x = self.mm_soft_emb_norm(x) + x = jnp.matmul(x, self.mm_input_projection_weight.value) + return x.astype(vision_outputs.dtype) + + +# def make_causal_mask(cache_layer: LayerCache, token_type_ids): +# pass + + +def make_causal_mask(b: int, t: int, token_type_ids: Array): + my_mask = nnx.make_causal_mask(jnp.ones((b, t))) + tti = token_type_ids.astype(jnp.bool_) + or_mask = tti[:, None, None, :] & tti[:, None, :, None] + my_mask = my_mask.astype(jnp.bool_) | or_mask + return my_mask + + +def make_window_mask(b: int, t: int, token_type_ids: Array, slide_size: int): + my_mask = make_causal_mask(b, t, token_type_ids) + tmp = jnp.arange(my_mask.shape[-1]) + slide = tmp[:, None] - tmp[None, :] < slide_size + return my_mask & slide + + +def merge_modalities(img_emb: Array, text_emb: Array, token_mask: Array) -> Array: + # This function fills the image tokens into the text_emb sequence + # The token_mask tells us where the image tokens are (0 for text, 1 for image) + # image_emb is (Li, D) + # text_emb is (Lt, D) + # token_mask is (Lt) + # We have Li < Lt + img_indices = jnp.cumsum(token_mask) - 1 + safe_indices = jnp.clip(img_indices, 0, img_emb.shape[0] - 1) + aligned_images = img_emb[safe_indices] + return jnp.where(token_mask[:, None], aligned_images, text_emb) + + +def batched_merge_modalities(img_emb: Array, text_emb: Array, token_mask: Array) -> Array: + # image_emb is (B, Li, D) + # text_emb is (B, Lt, D) + # token_mask is (B, Lt) + # We have Li < Lt + return jax.vmap(merge_modalities)(img_emb, text_emb, token_mask) + + +class Gemma3Model(nnx.Module): + def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs): + self.vision_tower = SiglipVisionTransformer(cfg.vision_config, rngs=rngs) + self.multi_modal_projector = Gemma3MultiModalProjector(cfg, rngs=rngs) + self.language_model = Gemma3TextModel(cfg.text_config, rngs=rngs) + + def __call__( + self, input_ids: Array, pixel_values: Array, cache: Cache, segment_ids: Array, token_type_ids: Array + ) -> Array: + causal_mask = make_causal_mask(input_ids.shape[0], input_ids.shape[1], token_type_ids) + sliding_mask = make_causal_mask(input_ids.shape[0], input_ids.shape[1], token_type_ids) + + inputs_embeds = self.language_model.embed_tokens(input_ids) + + # Merge text and images + if pixel_values is not None: + vision_outputs = self.vision_tower(pixel_values) + image_features = 
self.multi_modal_projector(vision_outputs) + + image_features = image_features.astype(inputs_embeds.dtype) + inputs_embeds = batched_merge_modalities(image_features, inputs_embeds, token_type_ids) + + out = self.language_model(inputs_embeds, cache, segment_ids, sliding_mask, causal_mask) + return out + + +# TODO: Implement a jitted forward method diff --git a/bonsai/models/gemma3/params.py b/bonsai/models/gemma3/params.py new file mode 100644 index 00000000..29ca0f51 --- /dev/null +++ b/bonsai/models/gemma3/params.py @@ -0,0 +1,267 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Parameter helpers for bonsai.models.gemma3. + +Add functions to load or convert pretrained checkpoints and to return +default configuration values used by the model implementation. +""" + +import logging +import re +from enum import Enum + +import jax +import safetensors.flax as safetensors +from etils import epath +from flax import nnx + +from bonsai.models.gemma3 import modeling as model_lib + + +class Transform(Enum): + """ + Specifies default transformation types for model parameter names. + """ + + DEFAULT = None + BIAS = None + LINEAR = ((1, 0), None) + CONV2D = ((2, 3, 1, 0), None) + EMBED = None + + +# TODO: Need to get lm_head. It currently isn't used. +def _get_key_and_transform_mapping(): + # Mapping st_keys -> (nnx_keys, (permute_rule, reshape_rule)). 
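+    # For example, a torch nn.Linear stores its weight as (out_features, in_features); the LINEAR
+    # permute rule (1, 0) transposes it into the (in_features, out_features) kernel layout used by
+    # nnx.Linear, and CONV2D's (2, 3, 1, 0) turns torch's (O, I, H, W) conv weight into the
+    # (H, W, I, O) kernel expected by nnx.Conv.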
+ return { + r"^language_model\.model\.embed_tokens\.weight$": ( + r"language_model\.embed_tokens\.weight\.embedding", + Transform.EMBED, + ), + r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight$": ( + r"language_model\.layers\.\1\.input_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight$": ( + r"language_model\.layers\.\1\.mlp\.down_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight$": ( + r"language_model\.layers\.\1\.mlp\.gate_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight$": ( + r"language_model\.layers\.\1\.mlp\.up_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight$": ( + r"language_model\.layers\.\1\.post_attention_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.post_feedforward_layernorm\.weight$": ( + r"language_model\.layers\.\1\.post_feedforward_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.pre_feedforward_layernorm\.weight$": ( + r"language_model\.layers\.\1\.pre_feedforward_layernorm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.k_norm\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.k_norm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.k_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.o_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.q_norm\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.q_norm\.scale", + Transform.DEFAULT, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.q_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.weight$": ( + r"language_model\.layers\.\1\.self_attn\.v_proj\.kernel", + Transform.LINEAR, + ), + r"^language_model\.model\.norm\.weight$": (r"language_model\.norm\.scale", Transform.DEFAULT), + r"^multi_modal_projector\.mm_input_projection_weight$": ( + r"multi_modal_projector\.mm_input_projection_weight", + Transform.DEFAULT, + ), + r"^multi_modal_projector\.mm_soft_emb_norm\.weight$": ( + r"multi_modal_projector\.mm_soft_emb_norm\.scale", + Transform.DEFAULT, + ), + r"^vision_tower\.vision_model\.embeddings\.patch_embedding\.bias$": ( + r"vision_tower\.embeddings\.patch_embedding\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.embeddings\.patch_embedding\.weight$": ( + r"vision_tower\.embeddings\.patch_embedding\.kernel", + Transform.CONV2D, + ), + r"^vision_tower\.vision_model\.embeddings\.position_embedding\.weight$": ( + r"vision_tower\.embeddings\.position_embedding\.embedding", + Transform.EMBED, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.layer_norm(\d+)\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.layer_norm\2\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.layer_norm(\d+)\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.layer_norm\2\.scale", + Transform.DEFAULT, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc(\d+)\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.mlp\.fc\2\.bias", + 
Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc(\d+)\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.mlp\.fc\2\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.k_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.k_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.out_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.out_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.q_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.q_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.bias$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.v_proj\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.weight$": ( + r"vision_tower\.encoder\.layers\.\1\.self_attn\.v_proj\.kernel", + Transform.LINEAR, + ), + r"^vision_tower\.vision_model\.post_layernorm\.bias$": ( + r"vision_tower\.post_layernorm\.bias", + Transform.BIAS, + ), + r"^vision_tower\.vision_model\.post_layernorm\.weight$": ( + r"vision_tower\.post_layernorm\.scale", + Transform.DEFAULT, + ), + r"lm_head\.weight": ("lm_head.kernel", Transform.LINEAR), + } + + +def _st_key_to_jax_key(mapping, source_key): + """Map a safetensors key to exactly one JAX key & transform, else warn/error.""" + subs = [ + (re.sub(pat, repl, source_key), transform) + for pat, (repl, transform) in mapping.items() + if re.match(pat, source_key) + ] + if not subs: + logging.warning(f"No mapping found for key: {source_key!r}") + return None, None + if len(subs) > 1: + keys = [s for s, _ in subs] + raise ValueError(f"Multiple mappings found for {source_key!r}: {keys}") + return subs[0] + + +def _assign_weights(keys, tensor, state_dict, st_key, transform): + """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor.""" + key, *rest = keys + if not rest: + if transform is not None: + permute, reshape = transform + if permute: + tensor = tensor.transpose(permute) + if reshape: + tensor = tensor.reshape(reshape) + if tensor.shape != state_dict[key].shape: + raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}") + state_dict[key] = tensor + else: + _assign_weights(rest, tensor, state_dict[key], st_key, transform) + + +def _stoi(s): + try: + return int(s) + except ValueError: + return s + + +# TODO: Update to include sharding +def create_gemma3_from_pretrained( + file_dir: str, + *, + mesh: jax.sharding.Mesh | None = None, +): + """ + Load safetensor weights from a file, then convert & merge into a flax.nnx ViT model. + + Returns: + A flax.nnx.Model instance with loaded parameters. 
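+
+    Example (sketch; the directory below is a placeholder for wherever the safetensors
+    files were downloaded, e.g. via huggingface_hub.snapshot_download):
+
+        model = create_gemma3_from_pretrained("/path/to/gemma-3-4b-it")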
+ """ + files = list(epath.Path(file_dir).expanduser().glob("*.safetensors")) + if not files: + raise ValueError(f"No safetensors found in {file_dir}") + + tensor_dict = {} + for f in files: + tensor_dict |= safetensors.load_file(f) + + gemma3 = model_lib.Gemma3Model(model_lib.ModelConfig(), rngs=nnx.Rngs(0)) + graph_def, abs_state = nnx.split(gemma3) + jax_state = abs_state.to_pure_dict() + + mapping = _get_key_and_transform_mapping() + for st_key, tensor in tensor_dict.items(): + jax_key, transform = _st_key_to_jax_key(mapping, st_key) + if jax_key is None: + continue + keys = [_stoi(k) for k in jax_key.split(r"\.")] + try: + _assign_weights(keys, tensor, jax_state, st_key, transform.value) + except KeyError as e: + print(f"Key error: {keys} at {e}") + except ValueError as e: + print(e) + except Exception as e: + print(keys) + raise e + + return nnx.merge(graph_def, jax_state) diff --git a/bonsai/models/gemma3/tests/run_model.py b/bonsai/models/gemma3/tests/run_model.py new file mode 100644 index 00000000..51b7f8a9 --- /dev/null +++ b/bonsai/models/gemma3/tests/run_model.py @@ -0,0 +1,121 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run a small inference example for Gemma3.""" + +import time + +import jax +import jax.numpy as jnp +import numpy as np +import tqdm +from flax import nnx + +from bonsai.models.gemma3 import modeling, params + +try: + from huggingface_hub import snapshot_download +except Exception: + snapshot_download = None + + +import os + +import torch +from transformers import AutoModel, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration + + +def make_input(processor, dtype=torch.float32, msg1=True): + if msg1: + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", + }, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + else: + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ] + + out = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + out["pixel_values"] = out["pixel_values"].to(dtype=dtype) + + return out + + +def run_model(): + model_name: str = "google/gemma-3-4b-it" + cfg = modeling.ModelConfig() + access_token = os.environ["HF_TOKEN"] + processor = AutoProcessor.from_pretrained(model_name, token=access_token, use_fast=False) + model_ckpt_path = snapshot_download(model_name) + bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path) + + # # Dummy token ids + t_inputs = 
make_input(processor) + + t_lm_head = Gemma3ForConditionalGeneration.from_pretrained( + "google/gemma-3-4b-it", + dtype=torch.float32, + ).lm_head + + full_tokens = t_inputs["input_ids"] + + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + + gen_steps = 30 + for i in tqdm.trange(gen_steps): + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens)) + cache = modeling.init_cache(cfg, batch_size, num_tokens, 1, jnp.float32) + + out = bonsai_model(n_text, n_img, cache, segment_ids, n_tti) + out = torch.tensor(out) + + out = t_lm_head(out[:, -1:None, :]) + + out = torch.argmax(out, axis=-1) + + full_tokens = torch.concat([full_tokens, out], axis=-1) + n_text = full_tokens.detach().cpu().numpy() + n_tti = jnp.concatenate([n_tti, n_tti[:, -1:None]], axis=-1) + + out_tokens = processor.decode(full_tokens[0], skip_special_tokens=True) + print(out_tokens) + + +if __name__ == "__main__": + run_model() diff --git a/bonsai/models/gemma3/tests/test_outputs_gemma3.py b/bonsai/models/gemma3/tests/test_outputs_gemma3.py new file mode 100644 index 00000000..20768824 --- /dev/null +++ b/bonsai/models/gemma3/tests/test_outputs_gemma3.py @@ -0,0 +1,679 @@ +import os +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +import torch +from absl.testing import absltest +from huggingface_hub import snapshot_download +from jax.typing import DTypeLike +from tqdm import trange +from transformers import AutoModel, AutoProcessor, AutoTokenizer, Gemma3Model, SiglipModel +from transformers.cache_utils import DynamicCache +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask +from transformers.models.gemma3 import Gemma3ForCausalLM +from transformers.models.gemma3.modeling_gemma3 import token_type_ids_mask_function + +from bonsai.models.gemma3 import modeling, params + + +class TestModuleForwardPasses(absltest.TestCase): + # using this for faster testing. This way we can avoid reloading the model. + # Make sure not to modify the Gemma3 model in inconsistent ways between tests. 
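+    # NOTE: setUpClass requires the HF_TOKEN environment variable and downloads the
+    # google/gemma-3-4b-it checkpoint, so these tests are slow and need network access.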
+ @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_name: str = "google/gemma-3-4b-it" + # self.model_name: str = "google/gemma-3-270m" # This is text only + access_token = os.environ["HF_TOKEN"] + cls.processor = AutoProcessor.from_pretrained(cls.model_name, token=access_token, use_fast=False) + cls.torch_device = "cpu" + + ## models + cls.torch_model = ( + AutoModel.from_pretrained(cls.model_name, dtype="auto") + .to(device=cls.torch_device, dtype=torch.float32) + .eval() + ) + cls.torch_config = cls.torch_model.config + + cls.bonsai_config = modeling.ModelConfig() + model_ckpt_path = snapshot_download(cls.model_name) + cls.bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path) + + cls.batch_size = 1 + cls.cache_size, cls.gen_steps = 512, 10 + + def _upgrade_dtypes(self): + self.bonsai_model.language_model.embed_tokens.weight.embedding.value = ( + self.bonsai_model.language_model.embed_tokens.weight.embedding.value.astype(jnp.float32) + ) + return + + def _make_torch_input(self): + # returns model inputs: + # KEY SHAPE DTYPE + # input_ids torch.Size([1, 281]) int64 + # attention_mask torch.Size([1, 281]) int64 + # token_type_ids torch.Size([1, 281]) int64 + # pixel_values torch.Size([1, 3, 896, 896]) bfloat16 -> float32 + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ] + + out = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + out["pixel_values"] = out["pixel_values"].to(dtype=torch.float32) + + return {k: v.to(device=self.torch_device) for k, v in out.items()} + + def _make_bonsai_input(self, torch_inputs): + out = dict() + for k, v in torch_inputs.items(): + tmp = v.detach().cpu().numpy() + if k == "pixel_values": + tmp = np.permute_dims(tmp, (0, 2, 3, 1)) + out[k] = tmp + return out + + # This should be correct for unbatched inputs + def _process_torch_inputs( + self, + input_ids=None, + pixel_values=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + token_type_ids=None, + cache_position=None, + inputs_embeds=None, + labels=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + **lm_kwargs, + ): + # Replace image id with PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.torch_config.image_token_id >= self.torch_config.text_config.vocab_size: + special_image_mask = input_ids == self.torch_config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.torch_model.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # Merge text and images + if pixel_values is not None: + image_features = self.torch_model.get_image_features(pixel_values) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.torch_model.get_placeholder_mask( + input_ids, 
inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.torch_config.get_text_config(), + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized + # (e.g. compiled prefill) AND `pixel_values` are not provided. Determining prefill in that case requires + # checking data values, which is not compile-compatible. + is_prefill = ( + not use_cache + or past_key_values is None + or not past_key_values.is_initialized + or pixel_values is not None + ) + if token_type_ids is not None and is_prefill: + # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + + # First find where a new image block starts: 1 if image and previous not image + # The images cannot attend to future images, but can attend to all prev images and to itself + # bidirectionally + is_image = (token_type_ids == 1).to(cache_position.device) + new_image_start = is_image & ~torch.nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] + image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 + image_group_ids = torch.where( + is_image, image_group_ids, torch.full_like(token_type_ids, -1, device=is_image.device) + ) + mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + token_type_ids.to(cache_position.device), image_group_ids, self.torch_config.mm_tokens_per_image + ) + + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + return dict( + attention_mask=causal_mask_mapping, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + # This should be correct for unbatched inputs + def _process_torch_inputs_for_decoder_text_model( + self, + attn_type, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + **kwargs, + ): + training = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None and not training: + past_key_values = DynamicCache(config=self.torch_model.config.text_config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + sliding_mask_kwargs = mask_kwargs.copy() + + # if self.config.use_bidirectional_attention: + # mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool) + # sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window) + + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + position_embeddings_global = self.torch_model.language_model.rotary_emb(inputs_embeds, position_ids) + position_embeddings_local = self.torch_model.language_model.rotary_emb_local(inputs_embeds, position_ids) + return dict( + hidden_states=inputs_embeds, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask_mapping[attn_type], + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + def _init_nnx_cache(self, batch_size: int, token_len: int, generate_steps: int, dtype): + return modeling.init_cache( + cfg=self.bonsai_config, + batch_size=batch_size, + token_len=token_len, + generate_steps=generate_steps, + dtype=dtype, + ) + + # Vision tests + # @unittest.skip("Done") + def test_image_emb(self): + tm = self.torch_model.vision_tower.vision_model.embeddings + nm = self.bonsai_model.vision_tower.embeddings + + t_inputs = self._make_torch_input() + n_inputs = self._make_bonsai_input(t_inputs) + tx = t_inputs["pixel_values"] + nx = n_inputs["pixel_values"] + + with torch.no_grad(): + ty = tm(tx) + ny = nm(nx) + + # (1, 4096, 1152) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5) + + # @unittest.skip("Done") + def test_siglip_encoder_layer(self): + tm = self.torch_model.vision_tower.vision_model.encoder.layers[0] + nm = self.bonsai_model.vision_tower.encoder.layers[0] + + tx = torch.randn((1, 4096, 1152), device=self.torch_device) + nx = tx.detach().cpu().numpy() + + with torch.no_grad(): + ty = tm(tx, None) + ny = nm(nx, None) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4) + + # @unittest.skip("Done but double check") + def test_vision_model(self): + # only have deviations on .0567% of the entries and on order 7e-3 + tm = self.torch_model.vision_tower + nm = self.bonsai_model.vision_tower + + t_inputs = self._make_torch_input() + n_inputs = self._make_bonsai_input(t_inputs) + tx = t_inputs["pixel_values"] + nx = n_inputs["pixel_values"] + + with torch.no_grad(): + ty = tm(tx).last_hidden_state + ny = nm(nx) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-2, atol=1e-2) + + # Language tests + # @unittest.skip("Done") + def test_text_embedding(self): + self._upgrade_dtypes() + tm = self.torch_model.language_model.embed_tokens + nm = self.bonsai_model.language_model.embed_tokens + + torch.testing.assert_close(torch.tensor(nm.weight.embedding.value), tm.weight.cpu()) + torch.testing.assert_close(torch.tensor(nm.embed_scale), tm.embed_scale.cpu()) + + t_inputs = self._make_torch_input() + n_inputs = 
self._make_bonsai_input(t_inputs) + tx = t_inputs["input_ids"] + nx = n_inputs["input_ids"] + + with torch.no_grad(): + ty = tm(tx) + ny = nm(nx) + + np.testing.assert_allclose(ny, ty.detach().cpu().numpy()) + + # @unittest.skip("Done") + def test_attn_projs(self): + tm = self.torch_model.language_model.layers[0].self_attn + nm = self.bonsai_model.language_model.layers[0].self_attn + + tx = torch.randn((1, 281, 2560), device=self.torch_device) + nx = tx.detach().cpu().numpy() + + ty = tm.q_proj(tx) + ny = nm.q_proj(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="q") + + ty = tm.k_proj(tx) + ny = nm.k_proj(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="k") + + ty = tm.v_proj(tx) + ny = nm.v_proj(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="v") + + tx = torch.randn((1, 281, 2048), device=self.torch_device) + nx = tx.detach().cpu().numpy() + ty = tm.o_proj(tx) + ny = nm.o_proj(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="o") + + # @unittest.skip("Done") + def test_attn_norms(self): + tm = self.torch_model.language_model.layers[0].self_attn + nm = self.bonsai_model.language_model.layers[0].self_attn + + tx = torch.randn((1, 281, 2048), device=self.torch_device).reshape(1, 281, -1, 256) + nx = tx.detach().cpu().numpy() + + np.testing.assert_allclose( + nm.q_norm.scale.value, tm.q_norm.weight.detach().cpu().numpy(), err_msg="q_norm weights" + ) + + ty = tm.q_norm(tx) + ny = nm.q_norm(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5, err_msg="q") + + tx = torch.randn((1, 281, 1024), device=self.torch_device).reshape(1, 281, -1, 256) + nx = tx.detach().cpu().numpy() + + ty = tm.k_norm(tx) + ny = nm.k_norm(nx) + np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5, err_msg="k") + + # @unittest.skip("Done") + def test_sin_cos(self): + batch_size, seq_len, dim = 2, 10, 256 + hidden_states = torch.ones((batch_size, seq_len, dim)) + jp = jnp.stack([jnp.arange(seq_len), jnp.arange(seq_len)]) + + # local uses default + rt = self.bonsai_config.text_config.rope_local_base_freq + js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=1.0) + rot_emb = self.torch_model.language_model.rotary_emb_local + tc, ts = rot_emb(hidden_states, torch.tensor(jp)) + tc, ts = tc[:, :, : dim // 2], ts[:, :, : dim // 2] + torch.testing.assert_close(torch.tensor(js), ts) + torch.testing.assert_close(torch.tensor(jc), tc) + + # global uses linear + rt = self.bonsai_config.text_config.rope_theta + js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=8.0) + rot_emb = self.torch_model.language_model.rotary_emb + tc, ts = rot_emb(hidden_states, torch.tensor(jp)) + tc, ts = tc[:, :, : dim // 2], ts[:, :, : dim // 2] + torch.testing.assert_close(torch.tensor(js), ts) + torch.testing.assert_close(torch.tensor(jc), tc) + + # @unittest.skip("Done") + def test_text_decoder_layer(self): + start_t_inputs = self._make_torch_input() + start_t_inputs = self._process_torch_inputs(**start_t_inputs) + + for test_layer in trange(34): + # Models + tm = self.torch_model.language_model.layers[test_layer] + nm = self.bonsai_model.language_model.layers[test_layer] + attn_type = tm.attention_type + + # Inputs + t_inputs = self._process_torch_inputs_for_decoder_text_model(attn_type, **start_t_inputs) + nx = t_inputs["hidden_states"].detach().cpu().numpy() + 
batch_size, num_tokens, _ = nx.shape + nnx_cache = self._init_nnx_cache( + batch_size=batch_size, token_len=num_tokens, generate_steps=1, dtype=jnp.float32 + ) + + # NOTE: Using the HF attention mask for this test + mask = t_inputs["attention_mask"].detach().cpu().numpy()[:, :, :, :num_tokens] + + # run models + ty = tm(**t_inputs) + ny = nm(nx, nnx_cache[test_layer], jnp.ones((batch_size, num_tokens)), mask=mask) + + t_inputs["hidden_states"] = ty[0] + + found_exception = False + try: + np.testing.assert_allclose( + ny, ty[0].detach().cpu().numpy(), rtol=5e-3, atol=5e-3, err_msg=f"{test_layer}" + ) + except Exception as e: + print(e) + found_exception = True + assert not found_exception, "FOUND EXCEPTION" + + # multi modal tests + + # @unittest.skip("Done") + def test_multi_modal_projector(self): + t_inputs = self._make_torch_input() + tm = self.torch_model + nm = self.bonsai_model.multi_modal_projector + + tx = tm.vision_tower(t_inputs["pixel_values"]).last_hidden_state + nx = tx.detach().cpu().numpy() + + ty = tm.multi_modal_projector(tx) + ny = nm(nx) + + torch.testing.assert_close(torch.tensor(ny), ty, rtol=1e-4, atol=1e-4) + + # @unittest.skip("Done") + def test_text_image_merge(self): + nm = self.bonsai_model + t_inputs = self._make_torch_input() + t_out = self._process_torch_inputs(**t_inputs) + + # answer is input_embeds + t_ans = t_out["inputs_embeds"] + + tmp = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_text = nm.language_model.embed_tokens(tmp) + + # return + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_img = nm.vision_tower(n_img) + n_img = nm.multi_modal_projector(n_img) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + + n_ans = modeling.batched_merge_modalities(n_img, n_text, n_tti) + + np.testing.assert_allclose(n_ans, t_ans.detach().cpu().numpy(), rtol=1e-3, atol=1e-3) + + # @unittest.skip("Done") + def test_text_layers_in_order(self): + start_t_inputs = self._make_torch_input() + start_t_inputs = self._process_torch_inputs(**start_t_inputs) + ny, ty = None, None + + for test_layer in trange(34): + # Models + tm = self.torch_model.language_model.layers[test_layer] + nm = self.bonsai_model.language_model.layers[test_layer] + attn_type = tm.attention_type + + # Inputs + t_inputs = self._process_torch_inputs_for_decoder_text_model(attn_type, **start_t_inputs) + if ty is not None: + t_inputs["hidden_states"] = ty[0] + if ny is None: + nx = t_inputs["hidden_states"].detach().cpu().numpy() + else: + nx = ny + batch_size, num_tokens, _ = nx.shape + nnx_cache = self._init_nnx_cache( + batch_size=batch_size, token_len=num_tokens, generate_steps=1, dtype=jnp.float32 + ) + + # NOTE: Using the HF attention mask here + mask = t_inputs["attention_mask"].detach().cpu().numpy()[:, :, :, :num_tokens] + + # run models + ty = tm(**t_inputs) + ny = nm(nx, nnx_cache[test_layer], jnp.ones((batch_size, num_tokens)), mask=mask) + + found_exception = False + try: + np.testing.assert_allclose(ny, ty[0].detach().cpu().numpy(), rtol=1, atol=1, err_msg=f"{test_layer}") + except Exception as e: + print(e) + found_exception = True + assert not found_exception, "FOUND EXCEPTION" + + # @unittest.skip("Done") + def test_masks(self): + # Make a really long input so we can test the sliding window + # This only tests for the pre-fill stage + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + 
"image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + }, + {"type": "text", "text": "Describe this image in detail." + "hello " * 1500}, + ], + }, + ] + + t_inputs = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + t_inputs["pixel_values"] = t_inputs["pixel_values"].to(dtype=torch.float32) + + batch_size, num_tokens = t_inputs["input_ids"].shape + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + n_mask = modeling.make_causal_mask(batch_size, num_tokens, n_tti) + + # Full attention + t_inputs = self._process_torch_inputs(**t_inputs) + t_mask = t_inputs["attention_mask"]["full_attention"] + + np.testing.assert_allclose(n_mask, t_mask.detach().cpu().numpy()) + + # Sliding attention + t_mask = t_inputs["attention_mask"]["sliding_attention"] + n_mask = modeling.make_window_mask(batch_size, num_tokens, n_tti, self.bonsai_config.text_config.sliding_window) + + np.testing.assert_allclose(n_mask, t_mask.detach().cpu().numpy()) + + # @unittest.skip("Done but come back to this") + def test_full_in_order(self): + tm = self.torch_model + nm = self.bonsai_model + + # Torch inputs + t_inputs = self._make_torch_input() + + # NNX inputs + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens)) + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32) + + # Get masks + n_causal_mask = modeling.make_causal_mask(n_text.shape[0], n_text.shape[1], n_tti) + n_sliding_mask = modeling.make_window_mask(n_text.shape[0], n_text.shape[1], n_tti, 1024) + + # text embeds + t_inputs_embeds = tm.language_model.embed_tokens(t_inputs["input_ids"]) + n_inputs_embeds = nm.language_model.embed_tokens(n_text) + np.testing.assert_allclose(n_inputs_embeds, t_inputs_embeds.detach().cpu().numpy(), err_msg="text emb") + + # Vision part + t_vis = tm.vision_tower(t_inputs["pixel_values"]).last_hidden_state + n_vis = nm.vision_tower(n_img) + # Mismatched elements: 4608354 / 4718592 (97.7%) + # Max absolute difference among violations: 0.00756264 + # Max relative difference among violations: 15.521739 + np.testing.assert_allclose(n_vis, t_vis.detach().cpu().numpy(), rtol=1e-3, atol=1e-3, err_msg="vis tower") + + # MM Proj part + t_img_feat = tm.multi_modal_projector(t_vis) + n_img_feat = nm.multi_modal_projector(n_vis) + # Mismatched elements: 648574 / 655360 (99%) + # Max absolute difference among violations: 0.00063944 + # Max relative difference among violations: 20.392141 + np.testing.assert_allclose( + n_img_feat, t_img_feat.detach().cpu().numpy(), rtol=1e-3, atol=1e-3, err_msg="mm proj" + ) + + # Merging part + special_image_mask = tm.get_placeholder_mask( + t_inputs["input_ids"], inputs_embeds=t_inputs_embeds, image_features=t_img_feat + ) + t_inputs_embeds = t_inputs_embeds.masked_scatter(special_image_mask, t_img_feat) + n_inputs_embeds = modeling.batched_merge_modalities(n_img_feat, n_inputs_embeds, n_tti) + # Mismatched elements: 648574 / 719360 (90.2%) + # Max absolute difference among violations: 0.00063944 + # Max relative difference among violations: 20.392141 + np.testing.assert_allclose( + n_inputs_embeds, t_inputs_embeds.detach().cpu().numpy(), rtol=1e-3, atol=1e-3, 
err_msg="merge" + ) + + # NOTE: Text part in order + t_inputs["output_hidden_states"] = True + t_text_inputs = self._process_torch_inputs(**t_inputs) + t_hidden_states = tm.language_model(**t_text_inputs).hidden_states + assert len(t_hidden_states) - 1 == len(nm.language_model.layers), ( + f"{len(t_hidden_states)} vs {len(nm.language_model.layers)}" + ) + + # check inputs + nx = n_inputs_embeds + + n_hidden_states = [] + for i, layer in enumerate(nm.language_model.layers): + attn_type = tm.language_model.layers[i].attention_type + n_mask = n_causal_mask if attn_type == "full_attention" else n_sliding_mask + n_hidden_states.append(nx) + nx = layer(nx, cache[i], segment_ids, n_mask) + nx = nm.language_model.norm(nx) + n_hidden_states.append(nx) + + for i, (nval, tval) in enumerate(zip(n_hidden_states, t_hidden_states)): + try: + np.testing.assert_allclose(nval, tval.detach().cpu().numpy(), err_msg=f"text {i}") + except Exception as e: + print(e) + found_error = True + assert not found_error, "Found errors in text decoder layers" + # NOTE: some errors are expected here since errors compound with layer + + # @unittest.skip("Done") + def test_full(self): + tm = self.torch_model + nm = self.bonsai_model + + t_inputs = self._make_torch_input() + + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + batch_size, num_tokens = n_text.shape + segment_ids = jnp.ones((batch_size, num_tokens)) + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32) + + ny = nm(n_text, n_img, cache, segment_ids, n_tti) + ty = tm(**t_inputs) + + torch.testing.assert_close(torch.tensor(ny), ty.last_hidden_state, rtol=5e-2, atol=5e-2) + + @unittest.skip("TODO") + def test_full_batched(self): + # TODO: This isn't implemented yet + raise NotImplementedError("Need to test against batched inputs") + + +if __name__ == "__main__": + absltest.main() diff --git a/pyproject.toml b/pyproject.toml index ec325a3d..029cd390 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ qwen3 = [] resnet50 = [] sam2 = ["pillow>=11.3.0"] vgg19 = ["h5py", "keras_hub", "tensorflow"] +gemma3 = ["sentencepiece"] dev = [ "xprof", From 9434fbe74708786ce8bdcd91288662eb8d3973c2 Mon Sep 17 00:00:00 2001 From: James Chapman Date: Thu, 11 Dec 2025 00:51:51 -0800 Subject: [PATCH 2/3] kv cache --- bonsai/models/gemma3/README.md | 17 ++-- bonsai/models/gemma3/modeling.py | 53 ++++++------ bonsai/models/gemma3/params.py | 2 - bonsai/models/gemma3/tests/run_model.py | 36 +++----- .../gemma3/tests/test_outputs_gemma3.py | 86 +++++++++++-------- 5 files changed, 97 insertions(+), 97 deletions(-) diff --git a/bonsai/models/gemma3/README.md b/bonsai/models/gemma3/README.md index 8b5cbd04..066cc6c8 100644 --- a/bonsai/models/gemma3/README.md +++ b/bonsai/models/gemma3/README.md @@ -21,17 +21,12 @@ python3 -m bonsai.models.gemma3.tests.run_model ### Remaining Tasks -1. Implement KV caching to speed up inference -2. JIT Compile forward pass -3. Finish the `run_model.py` example. Add timing and profiling. -4. Optimize based on the profiling. -5. Implement sharding. -6. Get the `lm_head` from the weights. -7. Update to include other model sizes - - -### Implementation Notes -The implementation matches the HF one pretty well. To get KV caching working, we have to pad things. We also have to pad the token_type_ids on the right with 0's. +1. 
Finish the `run_model.py` example. Add timing and profiling. +2. Optimize based on the profiling. +3. Implement sharding. +4. Update to include other model sizes +5. Clean up code (variable names, etc.) +6. Implement with batching diff --git a/bonsai/models/gemma3/modeling.py b/bonsai/models/gemma3/modeling.py index 8001dd5f..b085e09e 100644 --- a/bonsai/models/gemma3/modeling.py +++ b/bonsai/models/gemma3/modeling.py @@ -26,8 +26,6 @@ from jax.interpreters import pxla from jaxtyping import Array, Float -_K_MASK = jax._src.nn.functions._get_large_negative(jax.numpy.float32).item() - class AttentionType(Enum): FULL = "full_attention" @@ -373,17 +371,13 @@ def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, mask: # Update cache slice_indices = (0, cache.cur_ind.value, 0, 0) - cache.v_cache.value = jax.lax.dynamic_update_slice(cache.v_cache.value, v, slice_indices) cache.k_cache.value = jax.lax.dynamic_update_slice(cache.k_cache.value, k, slice_indices) - t = q.shape[1] - cache.cur_ind.value += x.shape[1] + cache.v_cache.value = jax.lax.dynamic_update_slice(cache.v_cache.value, v, slice_indices) - # TODO: Need to do this with the kv cache next - k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) - qkv = jax.nn.dot_product_attention(q, k, v, is_causal=False, mask=mask[:, :, :, :t], scale=self.scale) - # k, v = repeat_kv(cache.k_cache.value, self.n_rep), repeat_kv(cache.v_cache.value, self.n_rep) - # qkv = jax.nn.dot_product_attention(q, k, v, is_causal=False, mask=mask[:, :, :, :t], scale=self.scale) + k, v = repeat_kv(cache.k_cache.value, self.n_rep), repeat_kv(cache.v_cache.value, self.n_rep) + qkv = jax.nn.dot_product_attention(q, k, v, is_causal=False, mask=mask, scale=self.scale) + t = x.shape[1] cache.cur_ind.value = cache.cur_ind.value + t return self.o_proj(qkv.reshape(*x.shape[:-1], -1)) @@ -483,22 +477,25 @@ def __call__(self, vision_outputs: Array) -> Array: return x.astype(vision_outputs.dtype) -# def make_causal_mask(cache_layer: LayerCache, token_type_ids): -# pass - - -def make_causal_mask(b: int, t: int, token_type_ids: Array): - my_mask = nnx.make_causal_mask(jnp.ones((b, t))) +def make_causal_mask(x: Array, layer_cache: LayerCache, token_type_ids: Array): + _, t = x.shape + c = layer_cache.size + tmp1 = jnp.arange(t) + tmp2 = jnp.arange(c) + my_mask = tmp1[:, None] - tmp2[None, :] >= -layer_cache.cur_ind tti = token_type_ids.astype(jnp.bool_) - or_mask = tti[:, None, None, :] & tti[:, None, :, None] + tmp3 = jnp.concat([tti, jnp.zeros((1, c - t), dtype=jnp.bool_)], axis=-1) + or_mask = tti[:, None, :, None] & tmp3[:, None, None, :] my_mask = my_mask.astype(jnp.bool_) | or_mask return my_mask -def make_window_mask(b: int, t: int, token_type_ids: Array, slide_size: int): - my_mask = make_causal_mask(b, t, token_type_ids) - tmp = jnp.arange(my_mask.shape[-1]) - slide = tmp[:, None] - tmp[None, :] < slide_size +def make_window_mask(x, layer_cache, token_type_ids, slide_size: int): + my_mask = make_causal_mask(x, layer_cache, token_type_ids) + *_, t, c = my_mask.shape + tmp1 = jnp.arange(t) + tmp2 = jnp.arange(c) + slide = tmp1[:, None] - tmp2[None, :] < slide_size return my_mask & slide @@ -525,6 +522,7 @@ def batched_merge_modalities(img_emb: Array, text_emb: Array, token_mask: Array) class Gemma3Model(nnx.Module): def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs): + self.sliding_window_size = cfg.text_config.sliding_window self.vision_tower = SiglipVisionTransformer(cfg.vision_config, rngs=rngs) self.multi_modal_projector = 
Gemma3MultiModalProjector(cfg, rngs=rngs) self.language_model = Gemma3TextModel(cfg.text_config, rngs=rngs) @@ -532,8 +530,9 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs): def __call__( self, input_ids: Array, pixel_values: Array, cache: Cache, segment_ids: Array, token_type_ids: Array ) -> Array: - causal_mask = make_causal_mask(input_ids.shape[0], input_ids.shape[1], token_type_ids) - sliding_mask = make_causal_mask(input_ids.shape[0], input_ids.shape[1], token_type_ids) + assert input_ids.shape == token_type_ids.shape + causal_mask = make_causal_mask(input_ids, cache[0], token_type_ids) + sliding_mask = make_window_mask(input_ids, cache[0], token_type_ids, slide_size=self.sliding_window_size) inputs_embeds = self.language_model.embed_tokens(input_ids) @@ -546,7 +545,13 @@ def __call__( inputs_embeds = batched_merge_modalities(image_features, inputs_embeds, token_type_ids) out = self.language_model(inputs_embeds, cache, segment_ids, sliding_mask, causal_mask) + out = self.language_model.embed_tokens.weight.attend(out) return out -# TODO: Implement a jitted forward method +@jax.jit +def forward( + model: nnx.Module, cache: Cache, input_ids: Array, pixel_values: Array, segment_ids: Array, token_type_ids +) -> tuple[Array, nnx.Cache]: + logits = model(input_ids, pixel_values, cache, segment_ids, token_type_ids) + return logits[:, -1:None, :], cache diff --git a/bonsai/models/gemma3/params.py b/bonsai/models/gemma3/params.py index 29ca0f51..c84275c0 100644 --- a/bonsai/models/gemma3/params.py +++ b/bonsai/models/gemma3/params.py @@ -43,7 +43,6 @@ class Transform(Enum): EMBED = None -# TODO: Need to get lm_head. It currently isn't used. def _get_key_and_transform_mapping(): # Mapping st_keys -> (nnx_keys, (permute_rule, reshape_rule)). return { @@ -180,7 +179,6 @@ def _get_key_and_transform_mapping(): r"vision_tower\.post_layernorm\.scale", Transform.DEFAULT, ), - r"lm_head\.weight": ("lm_head.kernel", Transform.LINEAR), } diff --git a/bonsai/models/gemma3/tests/run_model.py b/bonsai/models/gemma3/tests/run_model.py index 51b7f8a9..27185310 100644 --- a/bonsai/models/gemma3/tests/run_model.py +++ b/bonsai/models/gemma3/tests/run_model.py @@ -20,7 +20,6 @@ import jax.numpy as jnp import numpy as np import tqdm -from flax import nnx from bonsai.models.gemma3 import modeling, params @@ -33,7 +32,7 @@ import os import torch -from transformers import AutoModel, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration +from transformers import AutoProcessor def make_input(processor, dtype=torch.float32, msg1=True): @@ -82,37 +81,28 @@ def run_model(): model_ckpt_path = snapshot_download(model_name) bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path) - # # Dummy token ids + # Make inputs t_inputs = make_input(processor) - - t_lm_head = Gemma3ForConditionalGeneration.from_pretrained( - "google/gemma-3-4b-it", - dtype=torch.float32, - ).lm_head - - full_tokens = t_inputs["input_ids"] - n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) - gen_steps = 30 + gen_steps = 200 + batch_size, num_tokens = n_text.shape + cache = modeling.init_cache(cfg, batch_size, num_tokens, gen_steps, jnp.float32) + + all_tokens = [n_text] for i in tqdm.trange(gen_steps): batch_size, num_tokens = n_text.shape segment_ids = jnp.ones((batch_size, num_tokens)) - cache = modeling.init_cache(cfg, batch_size, 
num_tokens, 1, jnp.float32) - - out = bonsai_model(n_text, n_img, cache, segment_ids, n_tti) - out = torch.tensor(out) - - out = t_lm_head(out[:, -1:None, :]) - - out = torch.argmax(out, axis=-1) + out, cache = modeling.forward(bonsai_model, cache, n_text, n_img, segment_ids, n_tti) - full_tokens = torch.concat([full_tokens, out], axis=-1) - n_text = full_tokens.detach().cpu().numpy() - n_tti = jnp.concatenate([n_tti, n_tti[:, -1:None]], axis=-1) + n_text = jnp.argmax(out, axis=-1) + all_tokens.append(n_text) + n_tti = n_tti[:, -1:None] # this assumes that the input prompt ends with text + n_img = None + full_tokens = torch.tensor(jnp.concat(all_tokens, axis=1)) out_tokens = processor.decode(full_tokens[0], skip_special_tokens=True) print(out_tokens) diff --git a/bonsai/models/gemma3/tests/test_outputs_gemma3.py b/bonsai/models/gemma3/tests/test_outputs_gemma3.py index 20768824..5ecdce38 100644 --- a/bonsai/models/gemma3/tests/test_outputs_gemma3.py +++ b/bonsai/models/gemma3/tests/test_outputs_gemma3.py @@ -12,7 +12,7 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer, Gemma3Model, SiglipModel from transformers.cache_utils import DynamicCache from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask -from transformers.models.gemma3 import Gemma3ForCausalLM +from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3ForConditionalGeneration from transformers.models.gemma3.modeling_gemma3 import token_type_ids_mask_function from bonsai.models.gemma3 import modeling, params @@ -32,7 +32,7 @@ def setUpClass(cls): ## models cls.torch_model = ( - AutoModel.from_pretrained(cls.model_name, dtype="auto") + Gemma3ForConditionalGeneration.from_pretrained(cls.model_name, dtype="auto") .to(device=cls.torch_device, dtype=torch.float32) .eval() ) @@ -115,7 +115,7 @@ def _process_torch_inputs( llm_input_ids = input_ids if inputs_embeds is None: - inputs_embeds = self.torch_model.get_input_embeddings()(llm_input_ids) + inputs_embeds = self.torch_model.model.get_input_embeddings()(llm_input_ids) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -125,9 +125,9 @@ def _process_torch_inputs( # Merge text and images if pixel_values is not None: - image_features = self.torch_model.get_image_features(pixel_values) + image_features = self.torch_model.model.get_image_features(pixel_values) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - special_image_mask = self.torch_model.get_placeholder_mask( + special_image_mask = self.torch_model.model.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_features ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) @@ -208,7 +208,7 @@ def _process_torch_inputs_for_decoder_text_model( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None and not training: - past_key_values = DynamicCache(config=self.torch_model.config.text_config) + past_key_values = DynamicCache(config=self.torch_model.model.config.text_config) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -243,8 +243,8 @@ def _process_torch_inputs_for_decoder_text_model( "full_attention": create_causal_mask(**mask_kwargs), "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), } - position_embeddings_global = self.torch_model.language_model.rotary_emb(inputs_embeds, position_ids) - 
position_embeddings_local = self.torch_model.language_model.rotary_emb_local(inputs_embeds, position_ids) + position_embeddings_global = self.torch_model.model.language_model.rotary_emb(inputs_embeds, position_ids) + position_embeddings_local = self.torch_model.model.language_model.rotary_emb_local(inputs_embeds, position_ids) return dict( hidden_states=inputs_embeds, position_embeddings_global=position_embeddings_global, @@ -270,7 +270,7 @@ def _init_nnx_cache(self, batch_size: int, token_len: int, generate_steps: int, # Vision tests # @unittest.skip("Done") def test_image_emb(self): - tm = self.torch_model.vision_tower.vision_model.embeddings + tm = self.torch_model.model.vision_tower.vision_model.embeddings nm = self.bonsai_model.vision_tower.embeddings t_inputs = self._make_torch_input() @@ -287,7 +287,7 @@ def test_image_emb(self): # @unittest.skip("Done") def test_siglip_encoder_layer(self): - tm = self.torch_model.vision_tower.vision_model.encoder.layers[0] + tm = self.torch_model.model.vision_tower.vision_model.encoder.layers[0] nm = self.bonsai_model.vision_tower.encoder.layers[0] tx = torch.randn((1, 4096, 1152), device=self.torch_device) @@ -299,10 +299,10 @@ def test_siglip_encoder_layer(self): np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4) - # @unittest.skip("Done but double check") + # @unittest.skip("Done") def test_vision_model(self): # only have deviations on .0567% of the entries and on order 7e-3 - tm = self.torch_model.vision_tower + tm = self.torch_model.model.vision_tower nm = self.bonsai_model.vision_tower t_inputs = self._make_torch_input() @@ -320,7 +320,7 @@ def test_vision_model(self): # @unittest.skip("Done") def test_text_embedding(self): self._upgrade_dtypes() - tm = self.torch_model.language_model.embed_tokens + tm = self.torch_model.model.language_model.embed_tokens nm = self.bonsai_model.language_model.embed_tokens torch.testing.assert_close(torch.tensor(nm.weight.embedding.value), tm.weight.cpu()) @@ -339,7 +339,7 @@ def test_text_embedding(self): # @unittest.skip("Done") def test_attn_projs(self): - tm = self.torch_model.language_model.layers[0].self_attn + tm = self.torch_model.model.language_model.layers[0].self_attn nm = self.bonsai_model.language_model.layers[0].self_attn tx = torch.randn((1, 281, 2560), device=self.torch_device) @@ -365,7 +365,7 @@ def test_attn_projs(self): # @unittest.skip("Done") def test_attn_norms(self): - tm = self.torch_model.language_model.layers[0].self_attn + tm = self.torch_model.model.language_model.layers[0].self_attn nm = self.bonsai_model.language_model.layers[0].self_attn tx = torch.randn((1, 281, 2048), device=self.torch_device).reshape(1, 281, -1, 256) @@ -395,7 +395,7 @@ def test_sin_cos(self): # local uses default rt = self.bonsai_config.text_config.rope_local_base_freq js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=1.0) - rot_emb = self.torch_model.language_model.rotary_emb_local + rot_emb = self.torch_model.model.language_model.rotary_emb_local tc, ts = rot_emb(hidden_states, torch.tensor(jp)) tc, ts = tc[:, :, : dim // 2], ts[:, :, : dim // 2] torch.testing.assert_close(torch.tensor(js), ts) @@ -404,7 +404,7 @@ def test_sin_cos(self): # global uses linear rt = self.bonsai_config.text_config.rope_theta js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=8.0) - rot_emb = self.torch_model.language_model.rotary_emb + rot_emb = self.torch_model.model.language_model.rotary_emb tc, ts = rot_emb(hidden_states, torch.tensor(jp)) tc, ts = 
tc[:, :, : dim // 2], ts[:, :, : dim // 2] torch.testing.assert_close(torch.tensor(js), ts) @@ -412,12 +412,12 @@ def test_sin_cos(self): # @unittest.skip("Done") def test_text_decoder_layer(self): - start_t_inputs = self._make_torch_input() - start_t_inputs = self._process_torch_inputs(**start_t_inputs) + first_t_inputs = self._make_torch_input() + start_t_inputs = self._process_torch_inputs(**first_t_inputs) for test_layer in trange(34): # Models - tm = self.torch_model.language_model.layers[test_layer] + tm = self.torch_model.model.language_model.layers[test_layer] nm = self.bonsai_model.language_model.layers[test_layer] attn_type = tm.attention_type @@ -428,9 +428,12 @@ def test_text_decoder_layer(self): nnx_cache = self._init_nnx_cache( batch_size=batch_size, token_len=num_tokens, generate_steps=1, dtype=jnp.float32 ) + n_tti = first_t_inputs["token_type_ids"].detach().cpu().numpy() - # NOTE: Using the HF attention mask for this test - mask = t_inputs["attention_mask"].detach().cpu().numpy()[:, :, :, :num_tokens] + if attn_type == "full_attention": + mask = modeling.make_causal_mask(n_tti, nnx_cache[test_layer], n_tti) + else: + mask = modeling.make_window_mask(n_tti, nnx_cache[test_layer], n_tti, 1024) # run models ty = tm(**t_inputs) @@ -453,7 +456,7 @@ def test_text_decoder_layer(self): # @unittest.skip("Done") def test_multi_modal_projector(self): t_inputs = self._make_torch_input() - tm = self.torch_model + tm = self.torch_model.model nm = self.bonsai_model.multi_modal_projector tx = tm.vision_tower(t_inputs["pixel_values"]).last_hidden_state @@ -488,13 +491,13 @@ def test_text_image_merge(self): # @unittest.skip("Done") def test_text_layers_in_order(self): - start_t_inputs = self._make_torch_input() - start_t_inputs = self._process_torch_inputs(**start_t_inputs) + first_t_inputs = self._make_torch_input() + start_t_inputs = self._process_torch_inputs(**first_t_inputs) ny, ty = None, None for test_layer in trange(34): # Models - tm = self.torch_model.language_model.layers[test_layer] + tm = self.torch_model.model.language_model.layers[test_layer] nm = self.bonsai_model.language_model.layers[test_layer] attn_type = tm.attention_type @@ -510,9 +513,12 @@ def test_text_layers_in_order(self): nnx_cache = self._init_nnx_cache( batch_size=batch_size, token_len=num_tokens, generate_steps=1, dtype=jnp.float32 ) + n_tti = first_t_inputs["token_type_ids"].detach().cpu().numpy() - # NOTE: Using the HF attention mask here - mask = t_inputs["attention_mask"].detach().cpu().numpy()[:, :, :, :num_tokens] + if attn_type == "full_attention": + mask = modeling.make_causal_mask(n_tti, nnx_cache[test_layer], n_tti) + else: + mask = modeling.make_window_mask(n_tti, nnx_cache[test_layer], n_tti, 1024) # run models ty = tm(**t_inputs) @@ -550,24 +556,28 @@ def test_masks(self): t_inputs["pixel_values"] = t_inputs["pixel_values"].to(dtype=torch.float32) batch_size, num_tokens = t_inputs["input_ids"].shape + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) - n_mask = modeling.make_causal_mask(batch_size, num_tokens, n_tti) + gen_steps = 10 + cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, gen_steps) + n_mask = modeling.make_causal_mask(n_text, cache[0], n_tti) # Full attention t_inputs = self._process_torch_inputs(**t_inputs) t_mask = t_inputs["attention_mask"]["full_attention"] + size_for_comp = t_mask.shape[-1] - np.testing.assert_allclose(n_mask, t_mask.detach().cpu().numpy()) + 
np.testing.assert_allclose(n_mask[:, :, :, :size_for_comp], t_mask.detach().cpu().numpy()) # Sliding attention t_mask = t_inputs["attention_mask"]["sliding_attention"] - n_mask = modeling.make_window_mask(batch_size, num_tokens, n_tti, self.bonsai_config.text_config.sliding_window) + n_mask = modeling.make_window_mask(n_text, cache[0], n_tti, self.bonsai_config.text_config.sliding_window) - np.testing.assert_allclose(n_mask, t_mask.detach().cpu().numpy()) + np.testing.assert_allclose(n_mask[:, :, :, :size_for_comp], t_mask.detach().cpu().numpy()) - # @unittest.skip("Done but come back to this") + @unittest.skip("Skipping - this test is just to observe errors over full model evaluation") def test_full_in_order(self): - tm = self.torch_model + tm = self.torch_model.model nm = self.bonsai_model # Torch inputs @@ -582,8 +592,10 @@ def test_full_in_order(self): cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32) # Get masks - n_causal_mask = modeling.make_causal_mask(n_text.shape[0], n_text.shape[1], n_tti) - n_sliding_mask = modeling.make_window_mask(n_text.shape[0], n_text.shape[1], n_tti, 1024) + n_causal_mask = modeling.make_causal_mask(n_text, cache[0], n_tti) + n_sliding_mask = modeling.make_window_mask( + n_text, cache[0], n_tti, self.bonsai_config.text_config.sliding_window + ) # text embeds t_inputs_embeds = tm.language_model.embed_tokens(t_inputs["input_ids"]) @@ -667,7 +679,7 @@ def test_full(self): ny = nm(n_text, n_img, cache, segment_ids, n_tti) ty = tm(**t_inputs) - torch.testing.assert_close(torch.tensor(ny), ty.last_hidden_state, rtol=5e-2, atol=5e-2) + torch.testing.assert_close(torch.tensor(ny), ty.logits, rtol=5e-2, atol=5e-2) @unittest.skip("TODO") def test_full_batched(self): From 3ddf0b0512f72dc45e88a03e384c1b40c01bc7f3 Mon Sep 17 00:00:00 2001 From: James Chapman Date: Fri, 12 Dec 2025 16:29:27 -0800 Subject: [PATCH 3/3] Added partial support for sharding Updated configs Moved embed_tokens to more natural place Updated run_model to use sampler and stop at end_of_turn token Added test_sharding_gemma3 Added batched forward test. Need more complex behavior and testing --- bonsai/models/gemma3/README.md | 14 +- bonsai/models/gemma3/modeling.py | 454 +++++++++++++----- bonsai/models/gemma3/params.py | 5 +- bonsai/models/gemma3/tests/run_model.py | 54 ++- .../gemma3/tests/test_outputs_gemma3.py | 99 ++-- .../gemma3/tests/test_sharding_gemma3.py | 96 ++++ 6 files changed, 528 insertions(+), 194 deletions(-) create mode 100644 bonsai/models/gemma3/tests/test_sharding_gemma3.py diff --git a/bonsai/models/gemma3/README.md b/bonsai/models/gemma3/README.md index 066cc6c8..6f1e6685 100644 --- a/bonsai/models/gemma3/README.md +++ b/bonsai/models/gemma3/README.md @@ -21,12 +21,8 @@ python3 -m bonsai.models.gemma3.tests.run_model ### Remaining Tasks -1. Finish the `run_model.py` example. Add timing and profiling. -2. Optimize based on the profiling. -3. Implement sharding. -4. Update to include other model sizes -5. Clean up code (variable names, etc.) -6. Implement with batching - - - +1. Properly implement sharding (vision, then text) +2. Implement with batching. Need this for FSDP. +3. Optimize based on the profiling. +4. Clean up code (variable names, etc.). Simplify unused configs (marked these with TODO) +5. 
Update to include other model sizes diff --git a/bonsai/models/gemma3/modeling.py b/bonsai/models/gemma3/modeling.py index b085e09e..d3680a35 100644 --- a/bonsai/models/gemma3/modeling.py +++ b/bonsai/models/gemma3/modeling.py @@ -21,82 +21,262 @@ import flax import jax import jax.numpy as jnp -import jax.sharding as shd from flax import nnx +from jax import P from jax.interpreters import pxla +from jax.sharding import PartitionSpec, get_abstract_mesh, reshard from jaxtyping import Array, Float -class AttentionType(Enum): +class AttentionMode(Enum): FULL = "full_attention" SLIDE = "sliding_attention" -def _make_attn_types(): +class ShardMode(Enum): + FSDP = "fsdp" + TP = "tp" + + +def _make_attn_types(full_freq: int): # Fix this (5x slide 1x full) x 5 + (4x slide) - return [AttentionType.FULL if i % 6 == 5 else AttentionType.SLIDE for i in range(34)] + return [AttentionMode.FULL if i % full_freq == full_freq - 1 else AttentionMode.SLIDE for i in range(34)] + + +# TODO: Need to get these implemented and follow the activations +@dataclass(slots=True, frozen=True) +class VisionShardingCfg: + qk: PartitionSpec + qb: PartitionSpec + kk: PartitionSpec + kb: PartitionSpec + vk: PartitionSpec + vb: PartitionSpec + ok: PartitionSpec + ob: PartitionSpec + fc1k: PartitionSpec + fc1b: PartitionSpec + fc2k: PartitionSpec + fc2b: PartitionSpec + + @staticmethod + def no_sharding(): + return VisionShardingCfg.default(False, False) + + @staticmethod + def default(use_fsdp: bool, use_tp: bool): + fsdp = ShardMode.FSDP.value if use_fsdp else None + tp = ShardMode.TP.value if use_tp else None + return VisionShardingCfg( + qk=P(tp, fsdp), + qb=P(fsdp), + kk=P(tp, fsdp), + kb=P(), # TODO + vk=P(tp, fsdp), + vb=P(), # TODO + ok=P(tp, fsdp), + ob=P(), # TODO + fc1k=P(fsdp, tp), + fc1b=P(), # TODO + fc2k=P(tp, fsdp), + fc2b=P(), # TODO + ) + + +@dataclass(slots=True, frozen=True) +class TextShardingCfg: + # attn + qk: PartitionSpec + qb: PartitionSpec + kk: PartitionSpec + kb: PartitionSpec + vk: PartitionSpec + vb: PartitionSpec + ok: PartitionSpec + ob: PartitionSpec + # mlp + dpk: PartitionSpec + dpb: PartitionSpec + gpk: PartitionSpec + gpb: PartitionSpec + upk: PartitionSpec + upb: PartitionSpec + # cache + # TODO + + @staticmethod + def no_sharding(): + return VisionShardingCfg.default(False, False) + + @staticmethod + def default(use_fsdp: bool, use_tp: bool): + fsdp = ShardMode.FSDP.value if use_fsdp else None + tp = ShardMode.TP.value if use_tp else None + return TextShardingCfg( + qk=P(tp, fsdp), + qb=P(fsdp), + kk=P(tp, fsdp), + kb=P(), # TODO + vk=P(tp, fsdp), + vb=P(), # TODO + ok=P(tp, fsdp), + ob=P(), # TODO + dpk=P(tp, fsdp), + dpb=P(), + gpk=P(tp, fsdp), + gpb=P(), + upk=P(tp, fsdp), + upb=P(), + ) -@dataclass +@dataclass(slots=True, frozen=True) +class ShardingCfg: + pass + + +@dataclass(frozen=True) class VisionConfig: - attention_dropout: float = 0.0 - hidden_act: str = "gelu_pytorch_tanh" - hidden_size: int = 1152 - image_size: int = 896 - intermediate_size: int = 4304 - layer_norm_eps: float = 1e-6 - num_attention_heads: int = 16 - num_channels: int = 3 - num_hidden_layers: int = 27 - patch_size: int = 14 - vision_use_head: bool = False - - -@dataclass + attention_dropout: float # TODO: unused + hidden_size: int + image_size: int + intermediate_size: int + layer_norm_eps: float + num_attention_heads: int + num_channels: int + num_hidden_layers: int + patch_size: int + vision_use_head: bool + shd_cfg: VisionShardingCfg + + @classmethod + def gemma3_4b(cls, use_fsdp: bool, use_tp: bool): + 
return cls( + attention_dropout=0.0, + hidden_size=1152, + image_size=896, + intermediate_size=4304, + layer_norm_eps=1e-6, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=27, + patch_size=14, + vision_use_head=False, + shd_cfg=VisionShardingCfg.default(use_fsdp, use_tp), + ) + + +@dataclass(frozen=True) class TextConfig: - _sliding_window_pattern: int = 6 - attention_bias: bool = False - attention_dropout: float = 0.0 - attn_logit_softcapping: Optional[float] = None - final_logit_softcapping: Optional[float] = None - head_dim: int = 256 - hidden_activation: str = "gelu_pytorch_tanh" - hidden_size: int = 2560 - initializer_range: float = 0.02 - intermediate_size: int = 10240 - layer_types: list[AttentionType] = field(default_factory=lambda: _make_attn_types()) - max_position_embeddings: int = 131072 - num_attention_heads: int = 8 - num_hidden_layers: int = 34 - num_key_value_heads: int = 4 - query_pre_attn_scalar: int = 256 - rms_norm_eps: float = 1e-6 - rope_local_base_freq: float = 10000.0 - rope_scaling: dict[str, Any] = field(default_factory=lambda: {"factor": 8.0, "rope_type": "linear"}) - rope_theta: float = 1000000.0 - sliding_window: int = 1024 - use_cache: bool = True - vocab_size: int = 262208 - - -@dataclass + attention_bias: bool + attention_dropout: float # TODO: unused + attn_logit_softcapping: Optional[float] # TODO: unused + final_logit_softcapping: Optional[float] # TODO: unused + head_dim: int + hidden_size: int + initializer_range: float # TODO: unused + intermediate_size: int + layer_types: list[AttentionMode] + max_position_embeddings: int # TODO: unused + num_attention_heads: int + num_hidden_layers: int + num_key_value_heads: int + query_pre_attn_scalar: int # TODO: unused + rms_norm_eps: float + rope_full_factor: float + rope_full_theta: float + rope_slide_factor: float + rope_slide_theta: float + sliding_window: int + use_cache: bool # TODO: unused + vocab_size: int + shd_cfg: TextShardingCfg + + @classmethod + def gemma3_4b(cls, use_fsdp: bool, use_tp: bool): + return cls( + attention_bias=False, + attention_dropout=0.0, # TODO: unused + attn_logit_softcapping=None, # TODO: unused + final_logit_softcapping=None, # TODO: unused + head_dim=256, + hidden_size=2560, + initializer_range=0.02, # TODO: unused + intermediate_size=10240, + layer_types=_make_attn_types(6), + max_position_embeddings=131072, # TODO: unused + num_attention_heads=8, + num_hidden_layers=34, + num_key_value_heads=4, + query_pre_attn_scalar=256, # TODO: unused + rms_norm_eps=1e-6, + rope_full_factor=8.0, + rope_full_theta=1000000.0, + rope_slide_factor=1.0, + rope_slide_theta=10000.0, + sliding_window=1024, + use_cache=True, # TODO: unused + vocab_size=262208, + shd_cfg=TextShardingCfg.default(use_fsdp, use_tp), + ) + + +@dataclass(frozen=True) class ModelConfig: - vision_config: VisionConfig = field(default_factory=lambda: VisionConfig()) - text_config: TextConfig = field(default_factory=lambda: TextConfig()) - mm_tokens_per_image: int = 256 - boi_token_index: int = 255999 - dtype: str = "bfloat16" - eoi_token_index: int = 256000 - eos_token_id: list[int] = field(default_factory=lambda: [1, 106]) - image_token_index: int = 262144 - initializer_range: float = 0.02 - mm_tokens_per_image: int = 256 + vision_config: VisionConfig + text_config: TextConfig + mm_tokens_per_image: int + boi_token_index: int # TODO: unused + dtype: str # TODO: unused + eoi_token_index: int # TODO: unused + eos_token_id: list[int] # TODO: unused + image_token_index: int # TODO: unused + initializer_range: 
float # TODO: unused + + @classmethod + def gemma3_4b(cls, use_fsdp: bool = False, use_tp: bool = False): + return cls( + vision_config=VisionConfig.gemma3_4b(use_fsdp, use_tp), + text_config=TextConfig.gemma3_4b(use_fsdp, use_tp), + mm_tokens_per_image=256, + boi_token_index=255999, # TODO: unused + dtype="bfloat16", # TODO: unused + eoi_token_index=256000, # TODO: unused + eos_token_id=field(default_factory=lambda: [1, 106]), # TODO: unused + image_token_index=262144, # TODO: unused + initializer_range=0.02, # TODO: unused + ) -## GENERAL +# General Components -## VISION +class SHLinear(nnx.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + *, + use_bias: bool = True, + kernel_sharding=P(), + bias_sharding=P(), + dtype=None, # TODO: Use this + rngs, + ): + kernel_initializer = jax.nn.initializers.lecun_normal() + self.kernel = nnx.Param( + kernel_initializer(rngs.params(), (in_dim, out_dim), dtype=dtype, out_sharding=kernel_sharding) + ) + if use_bias: + self.bias = nnx.Param(jnp.zeros((out_dim,), dtype=dtype, out_sharding=bias_sharding)) + else: + self.bias = nnx.data(jnp.zeros((out_dim,), dtype=dtype, out_sharding=bias_sharding)) + + def __call__(self, x, *, out_sharding=P()): + return jnp.matmul(x, self.kernel, out_sharding=out_sharding) + self.bias + + +# Vision Components # TODO: update to include interpolate_pos_encoding @@ -112,6 +292,7 @@ def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): padding="valid", rngs=rngs, ) + # TODO: shard this self.position_embedding = nnx.Embed(self.num_patches, config.hidden_size, rngs=rngs) self.position_ids = jnp.expand_dims(jnp.arange(self.num_patches), 0) @@ -127,10 +308,12 @@ def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): self.config = config self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads - self.k_proj = nnx.Linear(config.hidden_size, config.hidden_size, rngs=rngs) - self.v_proj = nnx.Linear(config.hidden_size, config.hidden_size, rngs=rngs) - self.q_proj = nnx.Linear(config.hidden_size, config.hidden_size, rngs=rngs) - self.out_proj = nnx.Linear(config.hidden_size, config.hidden_size, rngs=rngs) + hs = config.hidden_size + shd = config.shd_cfg + self.k_proj = SHLinear(hs, hs, kernel_sharding=shd.kk, bias_sharding=shd.kb, rngs=rngs) + self.v_proj = SHLinear(hs, hs, kernel_sharding=shd.vk, bias_sharding=shd.vb, rngs=rngs) + self.q_proj = SHLinear(hs, hs, kernel_sharding=shd.qk, bias_sharding=shd.qb, rngs=rngs) + self.out_proj = SHLinear(hs, hs, kernel_sharding=shd.ok, bias_sharding=shd.ob, rngs=rngs) def __call__(self, x: Array, attn_mask: Array | None): batch_size, seq_length, _ = x.shape @@ -146,12 +329,11 @@ def __call__(self, x: Array, attn_mask: Array | None): class SiglipMLP(nnx.Module): def __init__(self, config: VisionConfig, *, rngs: nnx.Rngs): self.config = config - self.act = jax.nn.gelu - self.fc1 = nnx.Linear(config.hidden_size, config.intermediate_size, rngs=rngs) - self.fc2 = nnx.Linear(config.intermediate_size, config.hidden_size, rngs=rngs) + self.fc1 = SHLinear(config.hidden_size, config.intermediate_size, rngs=rngs) + self.fc2 = SHLinear(config.intermediate_size, config.hidden_size, rngs=rngs) def __call__(self, x: Array): - return self.fc2(self.act(self.fc1(x))) + return self.fc2(jax.nn.gelu(self.fc1(x))) class SiglipEncoderLayer(nnx.Module): @@ -215,7 +397,7 @@ def __call__(self, pixel_values: Array): return x -## LANGUAGE +# Language components # from qwen3 @@ -229,13 +411,6 @@ def __init__(self, cfg: ModelConfig, 
batch_size: int, cache_size: int, dtype: jn self.cur_ind = nnx.Variable(jnp.zeros((), dtype=jnp.int32)) # scalar for compute efficiency. -def shard(x: jnp.ndarray, s: Tuple[str, ...]): - mesh = pxla.thread_resources.env.physical_mesh - if mesh.empty or jax.devices()[0].platform == "cpu": - return x - return jax.lax.with_sharding_constraint(x, shd.NamedSharding(mesh, shd.PartitionSpec(*s))) - - Cache: TypeAlias = list[LayerCache] @@ -263,11 +438,9 @@ def __call__(self, x: Array) -> Array: class Gemma3TextScaledWordEmbedding(nnx.Module): - def __init__( - self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0, *, rngs: nnx.Rngs - ): - self.weight = nnx.Embed(num_embeddings, embedding_dim, rngs=rngs) - self.embed_scale = jnp.array(embed_scale, dtype=jnp.bfloat16).astype(jnp.float32) + def __init__(self, cfg: TextConfig, *, rngs: nnx.Rngs): + self.weight = nnx.Embed(cfg.vocab_size, cfg.hidden_size, rngs=rngs) + self.embed_scale = jnp.array(cfg.hidden_size**0.5, dtype=jnp.bfloat16).astype(jnp.float32) def __call__(self, input_ids: Array): return self.weight(input_ids) * self.embed_scale @@ -330,26 +503,47 @@ class Gemma3Attention(nnx.Module): def __init__(self, config: TextConfig, layer_idx: int, *, rngs: nnx.Rngs): self.config = config self.layer_idx = layer_idx - self.use_sliding = config.layer_types[layer_idx] == AttentionType.SLIDE + self.use_sliding = config.layer_types[layer_idx] == AttentionMode.SLIDE self.num_kv_heads = self.config.num_key_value_heads self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.q_proj = nnx.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, use_bias=config.attention_bias, rngs=rngs + shd = config.shd_cfg + self.q_proj = SHLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + use_bias=config.attention_bias, + kernel_sharding=shd.qk, + bias_sharding=shd.qb, + rngs=rngs, ) - self.k_proj = nnx.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, use_bias=config.attention_bias, rngs=rngs + self.k_proj = SHLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + use_bias=config.attention_bias, + kernel_sharding=shd.kk, + bias_sharding=shd.kb, + rngs=rngs, ) - self.v_proj = nnx.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, use_bias=config.attention_bias, rngs=rngs + self.v_proj = SHLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + use_bias=config.attention_bias, + kernel_sharding=shd.vk, + bias_sharding=shd.vb, + rngs=rngs, ) - self.o_proj = nnx.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, use_bias=config.attention_bias, rngs=rngs + self.o_proj = SHLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + use_bias=config.attention_bias, + kernel_sharding=shd.ok, + bias_sharding=shd.ob, + rngs=rngs, ) self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, rngs=rngs) self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, rngs=rngs) - self.rope_theta = config.rope_local_base_freq if self.use_sliding else config.rope_theta - self.factor = 1.0 if self.use_sliding else 8.0 + self.rope_theta = config.rope_slide_theta if self.use_sliding else config.rope_full_theta + self.factor = config.rope_slide_factor if self.use_sliding else config.rope_full_factor self.n_rep = config.num_attention_heads // config.num_key_value_heads self.scale = config.head_dim**-0.5 @@ -385,9 
+579,31 @@ def __call__(self, x: Array, cache: LayerCache | None, segment_ids: Array, mask: class Gemma3MLP(nnx.Module): def __init__(self, config: TextConfig, *, rngs: nnx.Rngs): self.config = config - self.gate_proj = nnx.Linear(config.hidden_size, config.intermediate_size, use_bias=False, rngs=rngs) - self.up_proj = nnx.Linear(config.hidden_size, config.intermediate_size, use_bias=False, rngs=rngs) - self.down_proj = nnx.Linear(config.intermediate_size, config.hidden_size, use_bias=False, rngs=rngs) + shd = config.shd_cfg + self.gate_proj = SHLinear( + config.hidden_size, + config.intermediate_size, + use_bias=False, + kernel_sharding=shd.gpk, + bias_sharding=shd.gpb, + rngs=rngs, + ) + self.up_proj = SHLinear( + config.hidden_size, + config.intermediate_size, + use_bias=False, + kernel_sharding=shd.upk, + bias_sharding=shd.upb, + rngs=rngs, + ) + self.down_proj = SHLinear( + config.intermediate_size, + config.hidden_size, + use_bias=False, + kernel_sharding=shd.dpk, + bias_sharding=shd.dpb, + rngs=rngs, + ) def __call__(self, x: Array): return self.down_proj(jax.nn.gelu(self.gate_proj(x)) * self.up_proj(x)) @@ -427,23 +643,14 @@ def head_dim(self): class Gemma3TextModel(nnx.Module): def __init__(self, config: TextConfig, *, rngs: nnx.Rngs): self.config = config - # TODO: Move this out of this class into the larger class - self.embed_tokens = Gemma3TextScaledWordEmbedding( - config.vocab_size, - config.hidden_size, - "self.padding_idx", - embed_scale=self.config.hidden_size**0.5, - rngs=rngs, - ) self.layers = nnx.List( [Gemma3DecoderLayer(config, layer_idx, rngs=rngs) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps, rngs=rngs) def __call__(self, x, cache: Cache, segment_ids: Array, sliding_mask: Array | None, causal_mask: Array | None): - # x = self.embed_tokens(x) # done in higher layer now for lt, c, layer in zip(self.config.layer_types, cache, self.layers): - mask = sliding_mask if lt == AttentionType.SLIDE else causal_mask + mask = sliding_mask if lt == AttentionMode.SLIDE else causal_mask x = layer(x, c, segment_ids, mask) return self.norm(x) @@ -477,26 +684,26 @@ def __call__(self, vision_outputs: Array) -> Array: return x.astype(vision_outputs.dtype) -def make_causal_mask(x: Array, layer_cache: LayerCache, token_type_ids: Array): - _, t = x.shape +def make_causal_mask(layer_cache: LayerCache, token_type_ids: Array): + b, t = token_type_ids.shape c = layer_cache.size - tmp1 = jnp.arange(t) - tmp2 = jnp.arange(c) - my_mask = tmp1[:, None] - tmp2[None, :] >= -layer_cache.cur_ind + seq_arange = jnp.arange(t) + cache_arange = jnp.arange(c) + causal_mask = seq_arange[:, None] - cache_arange[None, :] >= -layer_cache.cur_ind tti = token_type_ids.astype(jnp.bool_) - tmp3 = jnp.concat([tti, jnp.zeros((1, c - t), dtype=jnp.bool_)], axis=-1) - or_mask = tti[:, None, :, None] & tmp3[:, None, None, :] - my_mask = my_mask.astype(jnp.bool_) | or_mask - return my_mask + cache_padded_tti = jnp.concat([tti, jnp.zeros((b, c - t), dtype=jnp.bool_)], axis=-1) + image_or_mask = tti[:, None, :, None] & cache_padded_tti[:, None, None, :] + causal_mask = causal_mask.astype(jnp.bool_) | image_or_mask + return causal_mask -def make_window_mask(x, layer_cache, token_type_ids, slide_size: int): - my_mask = make_causal_mask(x, layer_cache, token_type_ids) - *_, t, c = my_mask.shape - tmp1 = jnp.arange(t) - tmp2 = jnp.arange(c) - slide = tmp1[:, None] - tmp2[None, :] < slide_size - return my_mask & slide +def 
make_window_mask(layer_cache, token_type_ids, slide_size: int): + causal_mask = make_causal_mask(layer_cache, token_type_ids) + *_, t, c = causal_mask.shape + seq_arange = jnp.arange(t) + cache_arange = jnp.arange(c) + slide = seq_arange[:, None] - cache_arange[None, :] < slide_size + return causal_mask & slide def merge_modalities(img_emb: Array, text_emb: Array, token_mask: Array) -> Array: @@ -523,6 +730,7 @@ def batched_merge_modalities(img_emb: Array, text_emb: Array, token_mask: Array) class Gemma3Model(nnx.Module): def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs): self.sliding_window_size = cfg.text_config.sliding_window + self.embed_tokens = Gemma3TextScaledWordEmbedding(cfg.text_config, rngs=rngs) self.vision_tower = SiglipVisionTransformer(cfg.vision_config, rngs=rngs) self.multi_modal_projector = Gemma3MultiModalProjector(cfg, rngs=rngs) self.language_model = Gemma3TextModel(cfg.text_config, rngs=rngs) @@ -531,10 +739,10 @@ def __call__( self, input_ids: Array, pixel_values: Array, cache: Cache, segment_ids: Array, token_type_ids: Array ) -> Array: assert input_ids.shape == token_type_ids.shape - causal_mask = make_causal_mask(input_ids, cache[0], token_type_ids) - sliding_mask = make_window_mask(input_ids, cache[0], token_type_ids, slide_size=self.sliding_window_size) + causal_mask = make_causal_mask(cache[0], token_type_ids) + sliding_mask = make_window_mask(cache[0], token_type_ids, slide_size=self.sliding_window_size) - inputs_embeds = self.language_model.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) # Merge text and images if pixel_values is not None: @@ -545,7 +753,7 @@ def __call__( inputs_embeds = batched_merge_modalities(image_features, inputs_embeds, token_type_ids) out = self.language_model(inputs_embeds, cache, segment_ids, sliding_mask, causal_mask) - out = self.language_model.embed_tokens.weight.attend(out) + out = self.embed_tokens.weight.attend(out) return out @@ -554,4 +762,4 @@ def forward( model: nnx.Module, cache: Cache, input_ids: Array, pixel_values: Array, segment_ids: Array, token_type_ids ) -> tuple[Array, nnx.Cache]: logits = model(input_ids, pixel_values, cache, segment_ids, token_type_ids) - return logits[:, -1:None, :], cache + return logits[:, -1, :], cache diff --git a/bonsai/models/gemma3/params.py b/bonsai/models/gemma3/params.py index c84275c0..8b37fdcd 100644 --- a/bonsai/models/gemma3/params.py +++ b/bonsai/models/gemma3/params.py @@ -47,7 +47,7 @@ def _get_key_and_transform_mapping(): # Mapping st_keys -> (nnx_keys, (permute_rule, reshape_rule)). 
return { r"^language_model\.model\.embed_tokens\.weight$": ( - r"language_model\.embed_tokens\.weight\.embedding", + r"embed_tokens\.weight\.embedding", Transform.EMBED, ), r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight$": ( @@ -225,6 +225,7 @@ def _stoi(s): # TODO: Update to include sharding def create_gemma3_from_pretrained( file_dir: str, + cfg: model_lib.ModelConfig, *, mesh: jax.sharding.Mesh | None = None, ): @@ -242,7 +243,7 @@ def create_gemma3_from_pretrained( for f in files: tensor_dict |= safetensors.load_file(f) - gemma3 = model_lib.Gemma3Model(model_lib.ModelConfig(), rngs=nnx.Rngs(0)) + gemma3 = model_lib.Gemma3Model(cfg, rngs=nnx.Rngs(0)) graph_def, abs_state = nnx.split(gemma3) jax_state = abs_state.to_pure_dict() diff --git a/bonsai/models/gemma3/tests/run_model.py b/bonsai/models/gemma3/tests/run_model.py index 27185310..8646db44 100644 --- a/bonsai/models/gemma3/tests/run_model.py +++ b/bonsai/models/gemma3/tests/run_model.py @@ -14,28 +14,22 @@ """Run a small inference example for Gemma3.""" +import os import time import jax import jax.numpy as jnp import numpy as np +import torch import tqdm +from huggingface_hub import snapshot_download +from transformers import Gemma3Processor from bonsai.models.gemma3 import modeling, params - -try: - from huggingface_hub import snapshot_download -except Exception: - snapshot_download = None - - -import os - -import torch -from transformers import AutoProcessor +from bonsai.utils import Sampler -def make_input(processor, dtype=torch.float32, msg1=True): +def make_input(processor, dtype=torch.float32, msg1=False): if msg1: messages = [ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, @@ -65,41 +59,51 @@ def make_input(processor, dtype=torch.float32, msg1=True): }, ] - out = processor.apply_chat_template( + t_inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ) - out["pixel_values"] = out["pixel_values"].to(dtype=dtype) + t_inputs["pixel_values"] = t_inputs["pixel_values"].to(dtype=dtype) - return out + n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) + n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) + n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + + return n_text, n_img, n_tti def run_model(): model_name: str = "google/gemma-3-4b-it" cfg = modeling.ModelConfig() access_token = os.environ["HF_TOKEN"] - processor = AutoProcessor.from_pretrained(model_name, token=access_token, use_fast=False) + processor = Gemma3Processor.from_pretrained(model_name, token=access_token, use_fast=False) model_ckpt_path = snapshot_download(model_name) - bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path) + bonsai_config = modeling.ModelConfig.gemma3_4b() + bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path, bonsai_config) + eot_token_id = processor.tokenizer.convert_tokens_to_ids("") # Make inputs - t_inputs = make_input(processor) - n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) - n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) - n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) + n_text, n_img, n_tti = make_input(processor) - gen_steps = 200 + gen_steps = 500 batch_size, num_tokens = n_text.shape cache = modeling.init_cache(cfg, batch_size, num_tokens, gen_steps, jnp.float32) + source_key = jax.random.key(0) + 
sampler = jax.jit(Sampler(temperature=1.0, top_p=0.8, top_k=10)) + all_tokens = [n_text] - for i in tqdm.trange(gen_steps): + for _ in tqdm.trange(gen_steps): batch_size, num_tokens = n_text.shape segment_ids = jnp.ones((batch_size, num_tokens)) out, cache = modeling.forward(bonsai_model, cache, n_text, n_img, segment_ids, n_tti) - n_text = jnp.argmax(out, axis=-1) + source_key, key = jax.random.split(source_key) + n_text = sampler(out, key=key) + if jnp.all(n_text == eot_token_id): + print("Hit end of token.") + break all_tokens.append(n_text) - n_tti = n_tti[:, -1:None] # this assumes that the input prompt ends with text + n_tti = jnp.zeros((batch_size, 1), dtype=jnp.int32) n_img = None full_tokens = torch.tensor(jnp.concat(all_tokens, axis=1)) diff --git a/bonsai/models/gemma3/tests/test_outputs_gemma3.py b/bonsai/models/gemma3/tests/test_outputs_gemma3.py index 5ecdce38..12bb5731 100644 --- a/bonsai/models/gemma3/tests/test_outputs_gemma3.py +++ b/bonsai/models/gemma3/tests/test_outputs_gemma3.py @@ -7,6 +7,7 @@ import torch from absl.testing import absltest from huggingface_hub import snapshot_download +from jax.sharding import AxisType from jax.typing import DTypeLike from tqdm import trange from transformers import AutoModel, AutoProcessor, AutoTokenizer, Gemma3Model, SiglipModel @@ -17,6 +18,8 @@ from bonsai.models.gemma3 import modeling, params +SKIP_DONE: bool = True + class TestModuleForwardPasses(absltest.TestCase): # using this for faster testing. This way we can avoid reloading the model. @@ -29,6 +32,8 @@ def setUpClass(cls): access_token = os.environ["HF_TOKEN"] cls.processor = AutoProcessor.from_pretrained(cls.model_name, token=access_token, use_fast=False) cls.torch_device = "cpu" + cls.mesh = jax.make_mesh(((1, 1)), ("fsdp", "tp"), axis_types=(AxisType.Explicit, AxisType.Explicit)) + jax.set_mesh(cls.mesh) ## models cls.torch_model = ( @@ -38,16 +43,16 @@ def setUpClass(cls): ) cls.torch_config = cls.torch_model.config - cls.bonsai_config = modeling.ModelConfig() + cls.bonsai_config = modeling.ModelConfig.gemma3_4b() model_ckpt_path = snapshot_download(cls.model_name) - cls.bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path) + cls.bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path, cls.bonsai_config) cls.batch_size = 1 cls.cache_size, cls.gen_steps = 512, 10 def _upgrade_dtypes(self): - self.bonsai_model.language_model.embed_tokens.weight.embedding.value = ( - self.bonsai_model.language_model.embed_tokens.weight.embedding.value.astype(jnp.float32) + self.bonsai_model.embed_tokens.weight.embedding.value = ( + self.bonsai_model.embed_tokens.weight.embedding.value.astype(jnp.float32) ) return @@ -268,7 +273,7 @@ def _init_nnx_cache(self, batch_size: int, token_len: int, generate_steps: int, ) # Vision tests - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_image_emb(self): tm = self.torch_model.model.vision_tower.vision_model.embeddings nm = self.bonsai_model.vision_tower.embeddings @@ -285,7 +290,7 @@ def test_image_emb(self): # (1, 4096, 1152) np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5) - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_siglip_encoder_layer(self): tm = self.torch_model.model.vision_tower.vision_model.encoder.layers[0] nm = self.bonsai_model.vision_tower.encoder.layers[0] @@ -299,7 +304,7 @@ def test_siglip_encoder_layer(self): np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4) - # @unittest.skip("Done") + 
@unittest.skipIf(SKIP_DONE, "Done") def test_vision_model(self): # only have deviations on .0567% of the entries and on order 7e-3 tm = self.torch_model.model.vision_tower @@ -317,11 +322,11 @@ def test_vision_model(self): np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-2, atol=1e-2) # Language tests - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_text_embedding(self): self._upgrade_dtypes() tm = self.torch_model.model.language_model.embed_tokens - nm = self.bonsai_model.language_model.embed_tokens + nm = self.bonsai_model.embed_tokens torch.testing.assert_close(torch.tensor(nm.weight.embedding.value), tm.weight.cpu()) torch.testing.assert_close(torch.tensor(nm.embed_scale), tm.embed_scale.cpu()) @@ -337,7 +342,7 @@ def test_text_embedding(self): np.testing.assert_allclose(ny, ty.detach().cpu().numpy()) - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_attn_projs(self): tm = self.torch_model.model.language_model.layers[0].self_attn nm = self.bonsai_model.language_model.layers[0].self_attn @@ -363,7 +368,7 @@ def test_attn_projs(self): ny = nm.o_proj(nx) np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-4, atol=1e-4, err_msg="o") - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_attn_norms(self): tm = self.torch_model.model.language_model.layers[0].self_attn nm = self.bonsai_model.language_model.layers[0].self_attn @@ -386,14 +391,14 @@ def test_attn_norms(self): ny = nm.k_norm(nx) np.testing.assert_allclose(ny, ty.detach().cpu().numpy(), rtol=1e-5, atol=1e-5, err_msg="k") - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_sin_cos(self): batch_size, seq_len, dim = 2, 10, 256 hidden_states = torch.ones((batch_size, seq_len, dim)) jp = jnp.stack([jnp.arange(seq_len), jnp.arange(seq_len)]) # local uses default - rt = self.bonsai_config.text_config.rope_local_base_freq + rt = self.bonsai_config.text_config.rope_slide_theta js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=1.0) rot_emb = self.torch_model.model.language_model.rotary_emb_local tc, ts = rot_emb(hidden_states, torch.tensor(jp)) @@ -402,7 +407,7 @@ def test_sin_cos(self): torch.testing.assert_close(torch.tensor(jc), tc) # global uses linear - rt = self.bonsai_config.text_config.rope_theta + rt = self.bonsai_config.text_config.rope_full_theta js, jc = modeling._generate_pos_embeddings(jp, dim, rope_theta=rt, factor=8.0) rot_emb = self.torch_model.model.language_model.rotary_emb tc, ts = rot_emb(hidden_states, torch.tensor(jp)) @@ -410,7 +415,7 @@ def test_sin_cos(self): torch.testing.assert_close(torch.tensor(js), ts) torch.testing.assert_close(torch.tensor(jc), tc) - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_text_decoder_layer(self): first_t_inputs = self._make_torch_input() start_t_inputs = self._process_torch_inputs(**first_t_inputs) @@ -431,9 +436,9 @@ def test_text_decoder_layer(self): n_tti = first_t_inputs["token_type_ids"].detach().cpu().numpy() if attn_type == "full_attention": - mask = modeling.make_causal_mask(n_tti, nnx_cache[test_layer], n_tti) + mask = modeling.make_causal_mask(nnx_cache[test_layer], n_tti) else: - mask = modeling.make_window_mask(n_tti, nnx_cache[test_layer], n_tti, 1024) + mask = modeling.make_window_mask(nnx_cache[test_layer], n_tti, 1024) # run models ty = tm(**t_inputs) @@ -453,7 +458,7 @@ def test_text_decoder_layer(self): # multi modal tests - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def 
test_multi_modal_projector(self): t_inputs = self._make_torch_input() tm = self.torch_model.model @@ -467,7 +472,7 @@ def test_multi_modal_projector(self): torch.testing.assert_close(torch.tensor(ny), ty, rtol=1e-4, atol=1e-4) - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_text_image_merge(self): nm = self.bonsai_model t_inputs = self._make_torch_input() @@ -477,7 +482,7 @@ def test_text_image_merge(self): t_ans = t_out["inputs_embeds"] tmp = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) - n_text = nm.language_model.embed_tokens(tmp) + n_text = nm.embed_tokens(tmp) # return n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1))) @@ -489,7 +494,7 @@ def test_text_image_merge(self): np.testing.assert_allclose(n_ans, t_ans.detach().cpu().numpy(), rtol=1e-3, atol=1e-3) - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_text_layers_in_order(self): first_t_inputs = self._make_torch_input() start_t_inputs = self._process_torch_inputs(**first_t_inputs) @@ -516,9 +521,9 @@ def test_text_layers_in_order(self): n_tti = first_t_inputs["token_type_ids"].detach().cpu().numpy() if attn_type == "full_attention": - mask = modeling.make_causal_mask(n_tti, nnx_cache[test_layer], n_tti) + mask = modeling.make_causal_mask(nnx_cache[test_layer], n_tti) else: - mask = modeling.make_window_mask(n_tti, nnx_cache[test_layer], n_tti, 1024) + mask = modeling.make_window_mask(nnx_cache[test_layer], n_tti, 1024) # run models ty = tm(**t_inputs) @@ -532,7 +537,7 @@ def test_text_layers_in_order(self): found_exception = True assert not found_exception, "FOUND EXCEPTION" - # @unittest.skip("Done") + @unittest.skipIf(SKIP_DONE, "Done") def test_masks(self): # Make a really long input so we can test the sliding window # This only tests for the pre-fill stage @@ -556,11 +561,10 @@ def test_masks(self): t_inputs["pixel_values"] = t_inputs["pixel_values"].to(dtype=torch.float32) batch_size, num_tokens = t_inputs["input_ids"].shape - n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy()) n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy()) gen_steps = 10 cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, gen_steps) - n_mask = modeling.make_causal_mask(n_text, cache[0], n_tti) + n_mask = modeling.make_causal_mask(cache[0], n_tti) # Full attention t_inputs = self._process_torch_inputs(**t_inputs) @@ -571,7 +575,7 @@ def test_masks(self): # Sliding attention t_mask = t_inputs["attention_mask"]["sliding_attention"] - n_mask = modeling.make_window_mask(n_text, cache[0], n_tti, self.bonsai_config.text_config.sliding_window) + n_mask = modeling.make_window_mask(cache[0], n_tti, self.bonsai_config.text_config.sliding_window) np.testing.assert_allclose(n_mask[:, :, :, :size_for_comp], t_mask.detach().cpu().numpy()) @@ -592,14 +596,12 @@ def test_full_in_order(self): cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32) # Get masks - n_causal_mask = modeling.make_causal_mask(n_text, cache[0], n_tti) - n_sliding_mask = modeling.make_window_mask( - n_text, cache[0], n_tti, self.bonsai_config.text_config.sliding_window - ) + n_causal_mask = modeling.make_causal_mask(cache[0], n_tti) + n_sliding_mask = modeling.make_window_mask(cache[0], n_tti, self.bonsai_config.text_config.sliding_window) # text embeds t_inputs_embeds = tm.language_model.embed_tokens(t_inputs["input_ids"]) - n_inputs_embeds = nm.language_model.embed_tokens(n_text) + n_inputs_embeds = 
@@ -592,14 +596,12 @@ def test_full_in_order(self):
         cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32)
 
         # Get masks
-        n_causal_mask = modeling.make_causal_mask(n_text, cache[0], n_tti)
-        n_sliding_mask = modeling.make_window_mask(
-            n_text, cache[0], n_tti, self.bonsai_config.text_config.sliding_window
-        )
+        n_causal_mask = modeling.make_causal_mask(cache[0], n_tti)
+        n_sliding_mask = modeling.make_window_mask(cache[0], n_tti, self.bonsai_config.text_config.sliding_window)
 
         # text embeds
         t_inputs_embeds = tm.language_model.embed_tokens(t_inputs["input_ids"])
-        n_inputs_embeds = nm.language_model.embed_tokens(n_text)
+        n_inputs_embeds = nm.embed_tokens(n_text)
         np.testing.assert_allclose(n_inputs_embeds, t_inputs_embeds.detach().cpu().numpy(), err_msg="text emb")
 
         # Vision part
@@ -662,7 +664,7 @@ def test_full_in_order(self):
         assert not found_error, "Found errors in text decoder layers"
 
     # NOTE: some errors are expected here since errors compound with layer
-    # @unittest.skip("Done")
+    # @unittest.skipIf(SKIP_DONE, "Done")
     def test_full(self):
         tm = self.torch_model
         nm = self.bonsai_model
@@ -683,8 +685,35 @@ def test_full(self):
 
     @unittest.skip("TODO")
     def test_full_batched(self):
-        # TODO: This isn't implemented yet
-        raise NotImplementedError("Need to test against batched inputs")
+        tm = self.torch_model
+        nm = self.bonsai_model
+
+        t_inputs = self._make_torch_input()
+
+        n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1)))
+        n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy())
+        n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy())
+
+        # Test simple batching
+        n_img = jnp.concat([n_img, n_img])
+        n_text = jnp.concat([n_text, n_text])
+        n_tti = jnp.concat([n_tti, n_tti])
+
+        batch_size, num_tokens = n_text.shape
+        segment_ids = jnp.ones((batch_size, num_tokens))
+        cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32)
+
+        ny = nm(n_text, n_img, cache, segment_ids, n_tti)
+        ty = tm(**t_inputs)
+
+        # Each batch element should match the unbatched torch logits
+        torch.testing.assert_close(torch.tensor(ny)[0:1], ty.logits, rtol=5e-2, atol=5e-2)
+        torch.testing.assert_close(torch.tensor(ny)[1:2], ty.logits, rtol=5e-2, atol=5e-2)
+
+        raise NotImplementedError("Need to get more complex batched inputs working")
+        # When batching, prompts can each contain a different number of images (possibly zero) -> batched_merge_modalities must change
+        # for this, we might also need to keep track of which prompt each image came from
+        # We also need to update the left padding to deal with different padding for each prompt
 
 
 if __name__ == "__main__":
diff --git a/bonsai/models/gemma3/tests/test_sharding_gemma3.py b/bonsai/models/gemma3/tests/test_sharding_gemma3.py
new file mode 100644
index 00000000..e9405bb1
--- /dev/null
+++ b/bonsai/models/gemma3/tests/test_sharding_gemma3.py
@@ -0,0 +1,96 @@
+import os
+import unittest
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import torch
+from absl.testing import absltest
+from huggingface_hub import snapshot_download
+from jax import P, make_mesh, set_mesh
+from jax.sharding import AxisType
+from jax.typing import DTypeLike
+from tqdm import trange
+from transformers import AutoProcessor
+
+from bonsai.models.gemma3 import modeling, params
+
+# artificial cpu devices
+jax.config.update("jax_num_cpu_devices", 2)
+
+
+class TestSharding(absltest.TestCase):
+    # Load the processor and model once in setUpClass so we avoid reloading them for every test.
+    # Make sure not to modify the Gemma3 model in inconsistent ways between tests.
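+    # The suite runs on the two artificial CPU devices configured above, arranged as a
+    # (2, 1) mesh over the (FSDP, TP) axes from modeling.ShardMode, so only the FSDP
+    # axis is actually split in this test.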
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.model_name: str = "google/gemma-3-4b-it"
+        # cls.model_name: str = "google/gemma-3-270m"  # This is text only
+        access_token = os.environ["HF_TOKEN"]
+        cls.processor = AutoProcessor.from_pretrained(cls.model_name, token=access_token, use_fast=False)
+        cls.torch_device = "cpu"
+
+        fsdp, tp = modeling.ShardMode.FSDP.value, modeling.ShardMode.TP.value
+
+        cls.mesh = jax.make_mesh((2, 1), (fsdp, tp), axis_types=(AxisType.Explicit, AxisType.Explicit))
+        jax.set_mesh(cls.mesh)
+
+        cls.bonsai_config = modeling.ModelConfig.gemma3_4b()
+        model_ckpt_path = snapshot_download(cls.model_name)
+        cls.bonsai_model = params.create_gemma3_from_pretrained(model_ckpt_path, cls.bonsai_config)
+
+        cls.batch_size = 1
+        cls.cache_size, cls.gen_steps = 512, 10
+
+    def _make_torch_input(self):
+        # returns model inputs:
+        # KEY              SHAPE                          DTYPE
+        # input_ids        torch.Size([1, 281])           int64
+        # attention_mask   torch.Size([1, 281])           int64
+        # token_type_ids   torch.Size([1, 281])           int64
+        # pixel_values     torch.Size([1, 3, 896, 896])   bfloat16 -> float32
+        messages = [
+            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+                    },
+                    {"type": "text", "text": "Describe this image in detail."},
+                ],
+            },
+        ]
+
+        out = self.processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+        )
+        out["pixel_values"] = out["pixel_values"].to(dtype=torch.float32)
+
+        return {k: v.to(device=self.torch_device) for k, v in out.items()}
+
+    def test_full(self):
+        nm = self.bonsai_model
+
+        t_inputs = self._make_torch_input()
+
+        n_img = jnp.array(np.permute_dims(t_inputs["pixel_values"].detach().cpu().numpy(), (0, 2, 3, 1)))
+        n_text = jnp.array(t_inputs["input_ids"].detach().cpu().numpy())
+        n_tti = jnp.array(t_inputs["token_type_ids"].detach().cpu().numpy())
+
+        # Test simple batching
+        n_img = jnp.concat([n_img, n_img])
+        n_text = jnp.concat([n_text, n_text])
+        n_tti = jnp.concat([n_tti, n_tti])
+
+        batch_size, num_tokens = n_text.shape
+        segment_ids = jnp.ones((batch_size, num_tokens))
+        cache = modeling.init_cache(self.bonsai_config, batch_size, num_tokens, 1, jnp.float32)
+
+        nm(n_text, n_img, cache, segment_ids, n_tti)
+
+
+if __name__ == "__main__":
+    absltest.main()
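For reference (not part of the patch), every parameter created by create_gemma3_from_pretrained is a jax.Array, so the layout it received on the explicit (fsdp, tp) mesh can be inspected directly. A minimal sketch reusing only the entry points exercised by TestSharding above; the embed_tokens.weight.embedding path mirrors what test_outputs_gemma3 reads and is an assumption about how the weights are exposed, so adjust it if the module tree differs:

```python
import jax
from huggingface_hub import snapshot_download
from jax.sharding import AxisType

from bonsai.models.gemma3 import modeling, params

jax.config.update("jax_num_cpu_devices", 2)  # must run before JAX initializes its backend

fsdp, tp = modeling.ShardMode.FSDP.value, modeling.ShardMode.TP.value
mesh = jax.make_mesh((2, 1), (fsdp, tp), axis_types=(AxisType.Explicit, AxisType.Explicit))
jax.set_mesh(mesh)

config = modeling.ModelConfig.gemma3_4b()
ckpt_path = snapshot_download("google/gemma-3-4b-it")  # gated repo: needs HF_TOKEN or a cached login
model = params.create_gemma3_from_pretrained(ckpt_path, config)

emb = model.embed_tokens.weight.embedding.value  # (vocab, hidden) jax.Array
print(emb.shape, emb.sharding)  # how this table is laid out across the (fsdp, tp) mesh
jax.debug.visualize_array_sharding(emb)  # ASCII picture of the per-device tiling
```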