Add Experimental Support for Gemma variants [1B, 27B] #27
base: main
Conversation
Force-pushed 265f912 to 6ba3246
Pull Request Overview
This PR adds experimental support for the Gemma 3 1B and 27B variants while refactoring the Gemma3 model architecture. The implementation moves from a Gemma3-4B-specific structure to a generalized Gemma3 architecture that supports multiple model sizes, with updated attention mechanisms (including sliding-window attention) and enhanced multimodal capabilities.
Key Changes
- Added support for gemma-3-1b and gemma-3-27b model configurations with device-specific parameters
- Refactored attention mechanism to use causal masks and sliding window patterns
- Enhanced multimodal support with improved vision-text integration
- Updated RMSNorm implementation with distributed computation capabilities
Reviewed Changes
Copilot reviewed 43 out of 43 changed files in this pull request and generated 6 comments.
Summary per file:
| File | Description |
|---|---|
| models/tt_transformers/tt/model_config.py | Added configuration support for Gemma 1B and 27B variants with device-specific parameters |
| models/experimental/gemma3/tt/text_model.py | Refactored Gemma3Transformer with enhanced attention masking and multimodal support |
| models/experimental/gemma3/tt/attention.py | Updated attention mechanism with causal mask support and sliding window functionality |
| models/experimental/gemma3/tt/decoder.py | Enhanced TransformerBlock with improved attention routing and residual connections |
| models/experimental/gemma3/tt/mlp.py | Updated MLP with device scaling and distributed computation |
Comments suppressed due to low confidence (3)
models/experimental/gemma3/tt/text_model.py:1
- [nitpick] Returning None as the 6th element without clear documentation makes the return tuple unclear. Consider using a named tuple or adding a comment explaining what this None represents.

models/experimental/gemma3/tt/text_model.py:1
- [nitpick] Adding **kwargs to forward methods without documenting the expected kwargs could lead to confusion and potential misuse. Consider explicitly defining the expected parameters or adding documentation.

models/experimental/gemma3/tt/text_model.py:1
- [nitpick] Adding **kwargs to forward methods without documenting the expected kwargs could lead to confusion and potential misuse. Consider explicitly defining the expected parameters or adding documentation.
```diff
      x_1BSH,
      device=self.mesh_device,
-     dtype=ttnn.bfloat16,
+     dtype=ttnn.bfloat8_b,
```
Copilot AI (Sep 24, 2025):
Changing from ttnn.bfloat16 to ttnn.bfloat8_b reduces precision which may impact model accuracy. Ensure this change has been validated through testing.
Suggested change:
```diff
-     dtype=ttnn.bfloat8_b,
+     dtype=ttnn.bfloat16,
```
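As an illustration of what that validation could look like, here is a minimal, hypothetical sketch (not from this PR) that compares the low-precision output against a bfloat16 reference with a Pearson-correlation (PCC) check; the function names and the 0.99 threshold are assumptions.

```python
# Hypothetical sketch: validate the bfloat8_b change against a bfloat16 reference.
import torch

def pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    """Pearson correlation coefficient between two tensors, flattened to 1D."""
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

def check_dtype_change(reference_out: torch.Tensor, bfp8_out: torch.Tensor, threshold: float = 0.99) -> None:
    """Fail if the low-precision output drifts too far from the reference."""
    score = pcc(reference_out, bfp8_out)
    assert score >= threshold, f"PCC {score:.4f} is below {threshold} after the bfloat8_b change"
```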
jschuhmacher left a comment:
I'm not entirely done with the review, but I'll post these already.
```python
if is_gemma3:
    self.rms_norm_add_unit_offset = True
    self.embed_scale = self.dim**0.5
    self.sliding_window = 512
```
The sliding window parameter is present in the HF configuration (1024 by default). It would be better to get it from there.
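A minimal sketch of that suggestion, assuming the loaded transformers config is reachable here as `hf_config` (the variable name is an assumption; `sliding_window` is the attribute Gemma-3 configs expose):

```python
if is_gemma3:
    self.rms_norm_add_unit_offset = True
    self.embed_scale = self.dim**0.5
    # Read the window size from the HF config (1024 by default) instead of hard-coding 512.
    self.sliding_window = getattr(hf_config, "sliding_window", 1024)
```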
```diff
          Returns the number of tokens per chunk, accounting for the extra class token
          """
-         return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1
+         return (self.image_size // self.vision_patch_size) ** 2 + 1
```
This is common code. Have you verified other uses of vision_chunk_size in e.g. Llama models? I'd suggest keeping the previous calculation intact, and making a new one for image chunks.
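A sketch of keeping both calculations side by side; the property names here are illustrative, not the ones used in the PR:

```python
@property
def vision_chunk_ntok(self):
    """Tokens per vision chunk plus the extra class token (existing Llama-style path)."""
    return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1

@property
def image_ntok(self):
    """Tokens per full image plus the extra class token (new Gemma-3 path)."""
    return (self.image_size // self.vision_patch_size) ** 2 + 1
```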
```diff
- wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb)
+ if hasattr(model.model, "rotary_emb_local") and model.model.rotary_emb_local is not None:
+     wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb, model.model.rotary_emb_local)
+ else:
+     rotary_emb_local = None
+     wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb, rotary_emb_local)
```
How about?
```python
rotary_emb_local = getattr(model.model, "rotary_emb_local", None)
wrapper = HfDecoderWrapper(layer, self.head_dim, model.model.rotary_emb, rotary_emb_local=rotary_emb_local)
```

```python
    num_layers=None,
):
    from models.tt_transformers.tt.model import Transformer
    if "HF_MODEL" in os.environ and "gemma-3" in os.environ["HF_MODEL"].lower():
```
This is common code, also used for models that are supposed to be more production quality. Loading a module from experimental does not seem like the right thing to do here. Could you think of another way of accomplishing this? For inspiration, maybe look at gemma3/demo/text_demo.py#L71.
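One hypothetical alternative is to let the caller inject the model class instead of sniffing HF_MODEL inside common code; the function name and signature below are illustrative, not the repo's actual API:

```python
def create_transformer(model_args, mesh_device, num_layers=None, transformer_cls=None):
    # Callers that want the experimental Gemma-3 model pass its class in explicitly
    # (e.g. from the Gemma demo), so common code never imports from models/experimental.
    if transformer_cls is None:
        from models.tt_transformers.tt.model import Transformer as transformer_cls
    return transformer_cls(model_args, mesh_device, num_layers=num_layers)
```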
models/tt_transformers/tt/common.py (Outdated)
I'd suggest moving all the masking utilities to their own module. And since they are Gemma specific, they could be located with the rest of the Gemma code.
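For illustration, a dedicated masking module could look roughly like the torch-level sketch below; the module path, function name, and the additive-mask convention are assumptions (the real code would build ttnn tensors):

```python
# e.g. models/experimental/gemma3/tt/attention_masks.py (path illustrative)
import torch

def make_causal_masks(seq_len: int, sliding_window: int, dtype=torch.float32):
    """Return (sliding_causal_mask, global_causal_mask) as additive attention masks."""
    neg_inf = torch.finfo(dtype).min
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]                        # attend to self and the past only
    in_window = (idx[:, None] - idx[None, :]) < sliding_window   # sliding layers also stay within the window
    global_mask = torch.zeros(seq_len, seq_len, dtype=dtype).masked_fill(~causal, neg_inf)
    sliding_mask = torch.zeros(seq_len, seq_len, dtype=dtype).masked_fill(~(causal & in_window), neg_inf)
    return sliding_mask, global_mask
```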
models/tt_transformers/tt/model.py (Outdated)
```python
    chunk_start_idx=None,
    get_last_token=-1,
    kv_cache=None,
    **kwargs,
```
This makes it hard to see which arguments could be expected. Why not list the potential arguments with default values above?
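For example, the decode-mode forward could name the optional arguments it actually consumes; this is a sketch, and the exact parameter set (including `attention_masks`) is an assumption based on how the masks are passed around elsewhere in the PR:

```python
def forward(
    self,
    x,
    current_pos,
    rot_mats=None,
    page_table=None,
    chunk_start_idx=None,
    get_last_token=-1,
    kv_cache=None,
    attention_masks=None,  # e.g. (sliding_causal_mask, global_causal_mask) for Gemma-3
):
    ...
```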
models/tt_transformers/tt/model.py (Outdated)
```python
    page_table=None,
    kv_cache=None,
    argmax_on_device=False,
    **kwargs,
```
This makes it hard to see which arguments could be expected. Why not list the potential arguments with default values above?
```diff
- # SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
+ # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
```
General comment for all the files: it seems Tenstorrent prefers the `# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC` text instead.
```python
    model_args.rope_theta,
    model_args.rope_scaling,
)
```
You're loading the first layer, which is a sliding-attention layer. The rotary setup and precompute_freqs are set up for a global-attention layer, so that should not match the reference output?
Maybe another test with a global-attention layer would also be interesting?
```python
seqlen = 1

cos, sin = precompute_freqs(
```
Same as above: the RoPE setup and the layer seem to be mismatched.
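To make the local/global distinction concrete, here is a self-contained sketch using a standard RoPE precompute (not the repo's precompute_freqs helper); the theta and linear-scaling values reflect Gemma-3's local-vs-global split, while head_dim and max_seq_len are illustrative:

```python
import torch

def rope_cos_sin(head_dim, max_seq_len, theta, linear_factor=None):
    """Precompute RoPE cos/sin tables; linear_factor, if given, rescales positions."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    if linear_factor is not None:
        positions = positions / linear_factor
    freqs = torch.outer(positions, inv_freq)
    return freqs.cos(), freqs.sin()

# Sliding (local) layers and global layers need different tables, so a test should
# build the pair that matches the layer it loads.
local_cos, local_sin = rope_cos_sin(head_dim=128, max_seq_len=4096, theta=10_000.0)
global_cos, global_sin = rope_cos_sin(head_dim=128, max_seq_len=4096, theta=1_000_000.0, linear_factor=8.0)
```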
```python
if mode == "decode" and not self.args.is_galaxy:
    x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"], activation_dtype)
elif activation_dtype is not None and x.dtype != activation_dtype:
    x = ttnn.typecast(x, activation_dtype)
```
Nitpick, naming the masks first makes it a bit easier to read, for instance:
```diff
  x = ttnn.typecast(x, activation_dtype)
+ causal_mask = None
+ if attention_masks is not None:
+     sliding_causal_mask, global_causal_mask = attention_masks
+     causal_mask = sliding_causal_mask if getattr(layer.attention, "is_sliding", False) else global_causal_mask
```
Force-pushed fafc552 to 7ee5ddc
Ticket
Link to JIRA Issue

Problem description
Enable support for the Gemma-3-27b-it model.

What's changed
- Added support for the gemma-3-27b-it model.
- Updated model_config.py to support gemma-3-27b-it, including end-of-sequence (EoS) token handling.
- Updated load_checkpoints.py to support gemma-3-27b-it weight loading.
- Modified the apply_scaling logic to handle both LLaMA and gemma-3-27b-it models.
Checklist