
Conversation

@MohammedTaherMcW

Ticket

Link to JIRA Issue

Problem description

Enable support for the Gemma-3-27b-it model.

What's changed

  • Added support for the gemma-3-27b-it model.
  • Updated model_config.py to support gemma-3-27b-it, including end-of-sequence (EoS) token handling.
  • Updated load_checkpoints.py to support gemma-3-27b-it weight loading.
  • Modified the apply_scaling logic to handle both LLaMA and gemma-3-27b-it models (see the sketch after this list).
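
For illustration, a minimal sketch of how the apply_scaling routing could look, assuming HF-style rope_scaling keys; the repository's actual parameter names and call sites may differ:

import math
import torch

def apply_scaling(freqs: torch.Tensor, rope_scaling: dict | None) -> torch.Tensor:
    # Hypothetical sketch: route RoPE frequency scaling by the configured scaling type.
    if not rope_scaling:
        return freqs
    rope_type = rope_scaling.get("rope_type") or rope_scaling.get("type")
    if rope_type == "linear":
        # Gemma-3 global-attention layers use plain linear scaling.
        return freqs / rope_scaling["factor"]
    if rope_type == "llama3":
        # Llama-3.x wavelength-dependent scaling.
        factor = rope_scaling["factor"]
        low = rope_scaling["low_freq_factor"]
        high = rope_scaling["high_freq_factor"]
        old_len = rope_scaling["original_max_position_embeddings"]
        wavelen = 2 * math.pi / freqs
        smooth = (old_len / wavelen - low) / (high - low)
        scaled = torch.where(
            wavelen > old_len / low,
            freqs / factor,
            (1 - smooth) * freqs / factor + smooth * freqs,
        )
        return torch.where(wavelen < old_len / high, freqs, scaled)
    return freqs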

Checklist

@MohammedTaherMcW MohammedTaherMcW force-pushed the mcw/gemma_3_27b/pr_1_experimental branch from 265f912 to 6ba3246 Compare August 27, 2025 19:15
@jennychristopher jennychristopher changed the title Add Experimental Support for Gemma-3-27b-it Add Experimental Support for Gemma variants [1B, 27b] Sep 12, 2025
@jennychristopher jennychristopher changed the title Add Experimental Support for Gemma variants [1B, 27b] Add Experimental Support for Gemma variants [1B, 27B] Sep 12, 2025

Copilot AI left a comment


Pull Request Overview

This PR adds experimental support for Gemma variants (1B and 27B) while refactoring the Gemma3 model architecture. The implementation moves from a Gemma3-4B-specific structure to a generalized Gemma3 architecture that supports multiple model sizes, with updated attention mechanisms (including sliding-window attention) and enhanced multimodal capabilities.

Key Changes

  • Added support for gemma-3-1b and gemma-3-27b model configurations with device-specific parameters
  • Refactored attention mechanism to use causal masks and sliding window patterns
  • Enhanced multimodal support with improved vision-text integration
  • Updated RMSNorm implementation with distributed computation capabilities

Reviewed Changes

Copilot reviewed 43 out of 43 changed files in this pull request and generated 6 comments.

Summary per file:

  • models/tt_transformers/tt/model_config.py: Added configuration support for Gemma 1B and 27B variants with device-specific parameters
  • models/experimental/gemma3/tt/text_model.py: Refactored Gemma3Transformer with enhanced attention masking and multimodal support
  • models/experimental/gemma3/tt/attention.py: Updated attention mechanism with causal mask support and sliding window functionality
  • models/experimental/gemma3/tt/decoder.py: Enhanced TransformerBlock with improved attention routing and residual connections
  • models/experimental/gemma3/tt/mlp.py: Updated MLP with device scaling and distributed computation

Comments suppressed due to low confidence (3)

models/experimental/gemma3/tt/text_model.py:1

  • [nitpick] Returning None as the 6th element without clear documentation makes the return tuple unclear. Consider using a named tuple or adding a comment explaining what this None represents.
"""

models/experimental/gemma3/tt/text_model.py:1

  • [nitpick] Adding **kwargs to forward methods without documenting the expected kwargs could lead to confusion and potential misuse. Consider explicitly defining the expected parameters or adding documentation.
"""

models/experimental/gemma3/tt/text_model.py:1

  • [nitpick] Adding **kwargs to forward methods without documenting the expected kwargs could lead to confusion and potential misuse. Consider explicitly defining the expected parameters or adding documentation.
"""


  x_1BSH,
  device=self.mesh_device,
- dtype=ttnn.bfloat16,
+ dtype=ttnn.bfloat8_b,

Copilot AI Sep 24, 2025


Changing from ttnn.bfloat16 to ttnn.bfloat8_b reduces precision which may impact model accuracy. Ensure this change has been validated through testing.

Suggested change
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat16,

@flexaihq flexaihq deleted a comment from Copilot AI Sep 24, 2025
@flexaihq flexaihq deleted a comment from Copilot AI Sep 24, 2025
@flexaihq flexaihq deleted a comment from Copilot AI Sep 24, 2025
@flexaihq flexaihq deleted a comment from Copilot AI Sep 24, 2025
@flexaihq flexaihq deleted a comment from Copilot AI Sep 24, 2025

@jschuhmacher jschuhmacher left a comment


I'm not entirely done with the review, but I'll post these already.

if is_gemma3:
    self.rms_norm_add_unit_offset = True
    self.embed_scale = self.dim**0.5
    self.sliding_window = 512


The sliding window parameter is present in the HF configuration (1024 by default). It would be better to get it from there.
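
A hedged sketch of reading it from the HF config instead of hard-coding it, assuming a transformers release with Gemma-3 support (the multimodal config nests the text config):

from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("google/gemma-3-27b-it")
text_config = getattr(hf_config, "text_config", hf_config)  # text-only checkpoints are not nested
sliding_window = getattr(text_config, "sliding_window", 1024)  # fall back to the default noted above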

Returns the number of tokens per chunk, accounting for the extra class token
"""
- return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1
+ return (self.image_size // self.vision_patch_size) ** 2 + 1


This is common code. Have you verified other uses of vision_chunk_size in e.g. Llama models? I'd suggest keeping the previous calculation intact, and making a new one for image chunks.
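
For instance, keeping the existing property intact and adding a separate one for image-based chunking (the new property name is illustrative; both would sit on the shared model-args class):

@property
def vision_tokens_per_chunk(self) -> int:
    # Existing behaviour, still used by the Llama vision models.
    return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1

@property
def vision_tokens_per_image(self) -> int:
    # New Gemma-3 style calculation based on image_size.
    return (self.image_size // self.vision_patch_size) ** 2 + 1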

Comment on lines 2437 to 2588
+ if hasattr(model.model, "rotary_emb_local") and model.model.rotary_emb_local is not None:
+     wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb, model.model.rotary_emb_local)
+ else:
+     rotary_emb_local = None
+     wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb, rotary_emb_local)
- wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb)


How about:

rotary_emb_local = getattr(model.model, "rotary_emb_local", None)
wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb, rotary_emb_local=rotary_emb_local)

num_layers=None,
):
from models.tt_transformers.tt.model import Transformer
if "HF_MODEL" in os.environ and "gemma-3" in os.environ["HF_MODEL"].lower():


This is common code, also used for models that are supposed to be more production quality. Loading a module from experimental does not seem like the right thing to do here. Could you think of another way of accomplishing this? For inspiration, maybe look at gemma3/demo/text_demo.py#L71.
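
One possible shape for this, without claiming it matches text_demo.py: let callers inject the transformer class so the shared code never imports from models.experimental (the function name and constructor arguments below are placeholders):

def create_transformer(model_args, mesh_device, num_layers=None, transformer_cls=None):
    # Shared code only knows about the production Transformer.
    if transformer_cls is None:
        from models.tt_transformers.tt.model import Transformer
        transformer_cls = Transformer
    return transformer_cls(model_args, mesh_device=mesh_device, num_layers=num_layers)

# In the Gemma demo only, never in the shared path:
# from models.experimental.gemma3.tt.text_model import Gemma3Transformer
# model = create_transformer(args, mesh_device, transformer_cls=Gemma3Transformer)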

),
)



I'd suggest moving all the masking utilities to their own module. And since they are Gemma specific, they could be located with the rest of the Gemma code.
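
A rough sketch of what such a module could contain (the module path and function names are suggestions, and torch is used only to illustrate the mask shapes before they are converted to ttnn tensors):

# models/experimental/gemma3/tt/attention_masks.py  (suggested location)
import torch

def global_causal_mask(seq_len: int, dtype=torch.float32) -> torch.Tensor:
    # Standard lower-triangular causal mask: -inf above the diagonal.
    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype)
    return torch.triu(mask, diagonal=1)

def sliding_window_causal_mask(seq_len: int, window: int, dtype=torch.float32) -> torch.Tensor:
    # Causal mask that additionally blocks positions more than `window` tokens back.
    mask = global_causal_mask(seq_len, dtype)
    too_old = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=-window)
    return mask.masked_fill(too_old, float("-inf"))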

chunk_start_idx=None,
get_last_token=-1,
kv_cache=None,
**kwargs,


This makes it hard to see which arguments could be expected. Why not list the potential arguments with default values above?
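
For example, the **kwargs catch-all could be replaced with the optional arguments the Gemma path actually passes; in this sketch both the leading positional parameters and the extra keyword names are placeholders:

def forward(
    self,
    x,
    current_pos,
    rot_mats=None,
    page_table=None,
    chunk_start_idx=None,
    get_last_token=-1,
    kv_cache=None,
    attention_masks=None,  # e.g. (sliding_causal_mask, global_causal_mask)
):
    ...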

page_table=None,
kv_cache=None,
argmax_on_device=False,
**kwargs,


This makes it hard to see which arguments could be expected. Why not list the potential arguments with default values above?



# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.


General comment for all the files: it seems Tenstorrent prefers the

# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

text instead.

model_args.rope_theta,
model_args.rope_scaling,
)


You're loading the first layer, which is a sliding-attention layer. The rotary setup and precompute_freqs are set up for a global-attention layer, so that should not match the reference output?


Maybe another test with a global attention layer would also be interesting?
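
A hedged way to pick one layer of each kind for such tests, assuming a recent transformers release where the Gemma-3 text config exposes layer_types:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/gemma-3-27b-it").text_config
global_layers = [i for i, t in enumerate(cfg.layer_types) if t == "full_attention"]
sliding_layers = [i for i, t in enumerate(cfg.layer_types) if t == "sliding_attention"]
# Parametrizing the existing test over e.g. sliding_layers[0] and global_layers[0]
# would exercise both the local and the global RoPE setups.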


seqlen = 1

cos, sin = precompute_freqs(


Same as above: the rope setup and the layer seem to be mismatched.

if mode == "decode" and not self.args.is_galaxy:
    x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"], activation_dtype)
elif activation_dtype is not None and x.dtype != activation_dtype:
    x = ttnn.typecast(x, activation_dtype)


Nitpick, naming the masks first makes it a bit easier to read, for instance:

Suggested change
x = ttnn.typecast(x, activation_dtype)
causal_mask = None
if attention_masks is not None:
    sliding_causal_mask, global_causal_mask = attention_masks
    causal_mask = sliding_causal_mask if getattr(layer.attention, "is_sliding", False) else global_causal_mask

@MohammedTaherMcW MohammedTaherMcW force-pushed the mcw/gemma_3_27b/pr_1_experimental branch from fafc552 to 7ee5ddc Compare October 13, 2025 10:46