From 2518c22d6976a1cbcbd3d1bf7e327f25173331ca Mon Sep 17 00:00:00 2001 From: Ganesh Nanduru Date: Fri, 5 Dec 2025 16:45:28 -0800 Subject: [PATCH] zaya rebase --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/zaya.md | 80 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/zaya/__init__.py | 28 + .../models/zaya/configuration_zaya.py | 548 +++++ src/transformers/models/zaya/modeling_zaya.py | 2082 +++++++++++++++++ src/transformers/models/zaya/modular_zaya.py | 2022 ++++++++++++++++ tests/models/zaya/__init__.py | 0 tests/models/zaya/test_modeling_zaya.py | 172 ++ utils/check_docstrings.py | 3 + 11 files changed, 4941 insertions(+) create mode 100644 docs/source/en/model_doc/zaya.md create mode 100644 src/transformers/models/zaya/__init__.py create mode 100755 src/transformers/models/zaya/configuration_zaya.py create mode 100755 src/transformers/models/zaya/modeling_zaya.py create mode 100755 src/transformers/models/zaya/modular_zaya.py create mode 100644 tests/models/zaya/__init__.py create mode 100644 tests/models/zaya/test_modeling_zaya.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 451f99eec071..c1d6f4e63420 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -736,6 +736,8 @@ title: Zamba - local: model_doc/zamba2 title: Zamba2 + - local: model_doc/zaya + title: Zaya title: Text models - sections: - local: model_doc/aimv2 diff --git a/docs/source/en/model_doc/zaya.md b/docs/source/en/model_doc/zaya.md new file mode 100644 index 000000000000..0f6d1f37dfa2 --- /dev/null +++ b/docs/source/en/model_doc/zaya.md @@ -0,0 +1,80 @@ +--- +license: apache-2.0 +library_name: transformers +--- + +# Model Card for ZAYA1-base + +ZAYA1 is an 800m active/8.3B total parameter MoE model, and the first trained entirely end-to-end on AMD’s hardware, software, and networking stack. + +Our ZAYA1 base model benchmark performance is extremely competitive with the SoTA Qwen3 series of models of comparable scale, and outperforms comparable western open-source models such as SmolLM3, and Phi4. ZAYA1-base excels especially at complex and challenging mathematical and STEM reasoning tasks, nearly matching the performance of SoTA Qwen3 thinking models under high pass@k settings even prior to explicit post-training for reasoning, and exceeds other strong reasoning models such as Phi4-reasoning, and Deepseek-R1-Distill. + +Details of our pretraining efforts, hardware specific optimizations, and ZAYA1 base model benchmarks are described in the [accompanying technical report](https://arxiv.org/abs/2511.17127). + + +## Model Details + +ZAYA1's architecture includes several innovations developed at Zyphra. These include: + +- **Compressed Convolutional Attention (CCA)**: [This novel attention](https://arxiv.org/abs/2510.04476) mechanism performs attention entirely in the latent space enabling significant reductions in parameter count, prefill compute, and KV cache size compared to alternative attention mechanisms, while also being more performant in loss/flop. +- **ZAYA1 Router**: The ZAYA1 router makes fundamental improvements to the linear router used in almost all existing large-scale MoE models. The ZAYA1 router replaces the linear with a downprojection followed by a depth-mixing EDA layer then a three-layer MLP per expert to add significant nonlinear expressivity to the router. +- **Residual Scaling**: We add learnable scalar gates and biases to the residual stream and the outputs of each block. This provides a lightweight method to allow the model to carefully control its own norm and degree of forgetting across depth. + + +![zaya_arch](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/Ih8RnOPNbtRzaVcH16ar-.png) + +ZAYA1-base uses the [Gemma3](https://ai.google.dev/gemma/terms) tokenizer. + + +## Performance + +ZAYA1-base performs extremely competitively against other base models of a similar and even greater scale. + +![mmlu_pro_vs_ttft](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/nyWieuzXks9H4GM71XAzn.png) + +![Screenshot 2025-11-20 at 00.44.44](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/tsdgc4KWWs4SXfo4orOp4.png) + +## Quick start + +### Prerequisites + +To use ZAYA1, install `zaya` branch from our fork of `transformers` library, which is based on the v4.57.1 of `transformers`: +```bash +pip install "transformers @ git+https://github.com/Zyphra/transformers.git@zaya" +``` + +The command above relies on requirements for `transformers v4.57.1` being installed in your environment. If you're installing in a fresh Python environment, you might want to specify a specific extra, like `[dev-torch]`, to install all the dependencies: +```bash +pip install "transformers[dev-torch] @ git+https://github.com/Zyphra/transformers.git@zaya" +``` + + +### Inference + +```python +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + +tokenizer = AutoTokenizer.from_pretrained("Zyphra/ZAYA1-base") +model = AutoModelForCausalLM.from_pretrained("Zyphra/ZAYA1-base", device_map="cuda", dtype=torch.bfloat16) + +input_text = "What factors contributed to the fall of the Roman Empire?" +input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + +outputs = model.generate(**input_ids, max_new_tokens=100) +print(tokenizer.decode(outputs[0])) +``` + +## ZayaConfig + +[[autodoc]] ZayaConfig + +## ZayaModel + +[[autodoc]] ZayaModel + - forward + +## ZayaForCausalLM + +[[autodoc]] ZayaForCausalLM + - forward diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 3df37eeb3468..a4b5d8ff0a4e 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -454,6 +454,7 @@ ("yoso", "YosoConfig"), ("zamba", "ZambaConfig"), ("zamba2", "Zamba2Config"), + ("zaya", "ZayaConfig"), ("zoedepth", "ZoeDepthConfig"), ] ) @@ -913,6 +914,7 @@ ("yoso", "YOSO"), ("zamba", "Zamba"), ("zamba2", "Zamba2"), + ("zaya", "Zaya"), ("zoedepth", "ZoeDepth"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index dd4997392617..584bd9a242f5 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -436,6 +436,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("yoso", "YosoModel"), ("zamba", "ZambaModel"), ("zamba2", "Zamba2Model"), + ("zaya", "ZayaModel"), ] ) @@ -757,6 +758,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("xmod", "XmodForCausalLM"), ("zamba", "ZambaForCausalLM"), ("zamba2", "Zamba2ForCausalLM"), + ("zaya", "ZayaForCausalLM"), ] ) diff --git a/src/transformers/models/zaya/__init__.py b/src/transformers/models/zaya/__init__.py new file mode 100644 index 000000000000..54cc0c89f303 --- /dev/null +++ b/src/transformers/models/zaya/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 Zyphra and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_zaya import * + from .modeling_zaya import * + +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/zaya/configuration_zaya.py b/src/transformers/models/zaya/configuration_zaya.py new file mode 100755 index 000000000000..58889ea59310 --- /dev/null +++ b/src/transformers/models/zaya/configuration_zaya.py @@ -0,0 +1,548 @@ +# coding=utf-8 +# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Zaya model.""" + +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ZayaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ZayaModel`]. It is used to instantiate an + Zaya model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of [Zyphra/ZAYA1-base](https://huggingface.co/Zyphra/ZAYA1-base) + or [Zyphra/ZAYA1-reasoning-base](https://huggingface.co/Zyphra/ZAYA1-reasoning-base). + + Example: + ```python + >>> from transformers import ZayaModel, ZayaConfig + + >>> # Initializing an Zaya configuration + >>> configuration = ZayaConfig() + + >>> # Initializing a model from the ZAYA1-base style configuration + >>> model = ZayaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "zaya" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + cca=True, + cca_num_q_heads=[ + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + 8, + 0, + ], + num_query_groups_list=[ + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + 2, + 0, + ], + use_cache=True, + tie_word_embeddings: Optional[bool] = True, + attention_bias=False, + lm_head_bias=False, + vocab_size=262272, + hidden_size=2048, + ffn_hidden_size_list=[ + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + 0, + 4096, + ], + num_hidden_layers=120, + num_attention_heads=16, + activation_func="swiglu", + max_position_embeddings=4096, + norm_epsilon=1e-05, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + rope_theta=10000, + attention_dropout=0.0, + moe_router_topk=1, + zaya_layers=[ + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + "a", + 16, + ], + zaya_mlp_expansion=[ + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + 0, + 256, + ], + zaya_use_mod=True, + zaya_use_eda=True, + add_bias_linear=False, + gated_linear_unit=True, + scale_residual_merge=True, + residual_in_fp32=False, + bias_activation_fusion=True, + activation_func_fp8_input_store=False, + sliding_window=None, + partial_rotary_factor=0.5, + num_key_value_heads=2, + _attn_implementation="eager", + rope_parameters=None, + **kwargs, + ): + self.cca = cca + self.cca_num_q_heads = cca_num_q_heads + self.num_query_groups_list = num_query_groups_list + self.use_cache = use_cache + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size_list = ffn_hidden_size_list + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + assert self.hidden_size % self.num_attention_heads == 0 + self.kv_channels = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.activation_func = activation_func + self.max_position_embeddings = max_position_embeddings + self.norm_epsilon = norm_epsilon + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.attention_dropout = attention_dropout + self.moe_router_topk = moe_router_topk + self.zaya_layers = zaya_layers + self.zaya_mlp_expansion = zaya_mlp_expansion + self.zaya_use_mod = zaya_use_mod + self.zaya_use_eda = zaya_use_eda + self.add_bias_linear = add_bias_linear + self.gated_linear_unit = gated_linear_unit + self.scale_residual_merge = scale_residual_merge + self.residual_in_fp32 = residual_in_fp32 + self.bias_activation_fusion = bias_activation_fusion + self.activation_func_fp8_input_store = activation_func_fp8_input_store + self.sliding_window = sliding_window + self.partial_rotary_factor = partial_rotary_factor + # self.rope_theta = rope_theta + self.num_key_value_heads = num_key_value_heads + self._attn_implementation = _attn_implementation + self.rope_parameters = { + "rope_theta": rope_theta, + "rope_type": "linear", + "factor": 1.0, + } + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["ZayaConfig"] diff --git a/src/transformers/models/zaya/modeling_zaya.py b/src/transformers/models/zaya/modeling_zaya.py new file mode 100755 index 000000000000..421d627d581d --- /dev/null +++ b/src/transformers/models/zaya/modeling_zaya.py @@ -0,0 +1,2082 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zaya/modular_zaya.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zaya.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_zaya import ZayaConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ZayaConfig" + + +jit_fuser = torch.jit.script + + +def swiglu(y): + y_1, y_2 = torch.chunk(y, 2, -1) + return F.silu(y_1) * y_2 + + +@jit_fuser +def bias_swiglu(y, bias): + y = y + bias + return swiglu(y) + + +# @jit_fuser + + +def swiglu_back(g, y): + y_1, y_2 = torch.chunk(y, 2, -1) + return torch.cat( + ( + g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, + g * F.silu(y_1), + ), + -1, + ) + + +@jit_fuser +def bias_swiglu_back(g, y, bias): + y = y + bias + return swiglu_back(g, y) + + +class BiasSwiGLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias, fp8_input_store): + input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input + ctx.save_for_backward(input_for_backward, bias) + ctx.ori_input_dtype = input.dtype + ctx.fp8_input_store = fp8_input_store + return bias_swiglu(input, bias) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + tmp = bias_swiglu_back(grad_output, input, bias) + return tmp, tmp, None + + +class SwiGLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, fp8_input_store): + input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input + ctx.save_for_backward(input_for_backward) + ctx.ori_input_dtype = input.dtype + ctx.fp8_input_store = fp8_input_store + return swiglu(input) + + @staticmethod + def backward(ctx, grad_output): + input = ctx.saved_tensors[0] + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + tmp = swiglu_back(grad_output, input) + return tmp, None + + +class ZayaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: ZayaConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[ZayaConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@use_kernel_forward_from_hub("RMSNorm") +class ZayaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + ZayaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class ZayaDynamicCache(DynamicCache): + """ + Cache that includes both the KV cache and the CCA cache. + """ + + def __init__( + self, + config: ZayaConfig, + batch_size: int, + conv_kernel_size: int = 2, + dtype: torch.dtype = torch.float16, + device: Optional[str] = None, + ): + super().__init__() + self.config = config + self.batch_size = batch_size + self.dtype = dtype + self.device = device + num_k_heads = config.num_query_groups_list[0] + num_q_heads = config.cca_num_q_heads[0] + head_dim = config.hidden_size // config.num_attention_heads + self.conv_kernel_size = conv_kernel_size + self.num_layers = len(config.zaya_layers) + self.latent_k_dim = num_k_heads * head_dim + self.latent_q_dim = num_q_heads * head_dim + self.in_out_ch = self.latent_k_dim + self.latent_q_dim + self.has_previous_state = False + + self.conv_states = torch.zeros( + self.num_layers, + batch_size, + self.in_out_ch, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + + self.prev_hs = torch.zeros( + self.num_layers, + batch_size, + config.hidden_size, + device=device, + dtype=dtype, + ) + + def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor: + if not self.has_previous_state: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.prev_hs.zero_() + + +class CCA(nn.Module): + def __init__( + self, + config: ZayaConfig, + cca_num_kv_heads: int = 2, + cca_num_q_heads: int = 8, + cca_num_heads: int = 16, + hidden_size: Optional[int] = None, + cca_time0: int = 2, + cca_time1: int = 2, + layer_number: int = 0, + ): + super().__init__() + self.debug_level = 4 + self.config = config + self.layer_number = layer_number + + # Use the model's true hidden size unless explicitly overridden. + self.hidden_size = int(hidden_size or config.hidden_size) + + self.cca_time0 = cca_time0 + self.cca_time1 = cca_time1 + self.padding0 = cca_time0 - 1 + self.padding1 = cca_time1 - 1 + self.total_padding = self.padding0 + self.padding1 + + self.num_kv_heads = int(cca_num_kv_heads) + self.num_q_heads = int(cca_num_q_heads) + self.num_heads = int(cca_num_heads) + + # Geometry + self.head_dim = self.hidden_size // self.num_heads + self.latent_k_dim = self.num_kv_heads * self.head_dim + self.latent_q_dim = self.num_q_heads * self.head_dim + self.sqrt_head_dim = float(np.sqrt(self.head_dim)) + self.gqa_groups = self.num_q_heads // self.num_kv_heads + assert self.num_q_heads % self.num_kv_heads == 0, "q_heads must be a multiple of k_heads" + assert (self.latent_k_dim + self.latent_q_dim) == (self.num_kv_heads + self.num_q_heads) * self.head_dim + + # Projections + self.linear_q = nn.Linear(self.hidden_size, self.latent_q_dim, bias=self.config.attention_bias) + self.linear_k = nn.Linear(self.hidden_size, self.latent_k_dim, bias=self.config.attention_bias) + self.val_proj1 = nn.Linear(self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias) + self.val_proj2 = nn.Linear(self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias) + + # Depthwise + grouped conv along sequence + in_out_ch = self.latent_k_dim + self.latent_q_dim + self.conv_qk = nn.Sequential( + nn.Conv1d( + in_channels=in_out_ch, + out_channels=in_out_ch, + kernel_size=self.cca_time0, + groups=in_out_ch, + padding=0, + stride=1, + ), + nn.Conv1d( + in_channels=in_out_ch, + out_channels=in_out_ch, + kernel_size=self.cca_time1, + groups=(self.num_kv_heads + self.num_q_heads), + padding=0, + stride=1, + ), + ) + + # Per-k head temperature + self.temp = nn.Parameter(torch.zeros(self.num_kv_heads)) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[ZayaDynamicCache], + cca_mask, + ): + """ + hidden_states: [B, S, E] (HF layout) + returns: + query: [B, S, num_q_heads*head_dim] + key : [B, S, num_k_heads*head_dim] + value: [B, S, num_k_heads*head_dim] + """ + if cca_mask is not None and hidden_states.shape[1] > 1: + # Only applying in prefill + dtype = hidden_states.dtype + hidden_states = (hidden_states * cca_mask[:, :, None]).to(dtype) + + # ---- Switch to [S, B, H] ---- + hs = hidden_states.transpose(0, 1).contiguous() # [S, B, H] + # Time-shifted stream for v2 (pad one at the front along sequence) + hs_d = F.pad(hs[:-1], pad=(0, 0, 0, 0, 1, 0)) # [S, B, H] + + # Q/K in the full space + q = self.linear_q(hs) # [S, B, latent_q_dim] + k = self.linear_k(hs) # [S, B, latent_k_dim] + qk_packed0 = torch.cat([q, k], dim=-1) # [S, B, latent_q + latent_k] + + # Pre-mean tensors in head form (for "qk_mean_{q,k}" calc) + query_pre = qk_packed0[..., : self.latent_q_dim].view( + *qk_packed0.shape[:2], self.num_q_heads, self.head_dim + ) # [S, B, qh, dh] + + key_pre = qk_packed0[..., self.latent_q_dim :].view( + *qk_packed0.shape[:2], self.num_kv_heads, self.head_dim + ) # [S, B, kh, dh] + key_pre = ( + key_pre.unsqueeze(-2) + .repeat(1, 1, 1, self.gqa_groups, 1) + .view(*qk_packed0.shape[:2], self.num_q_heads, self.head_dim) + ) # [S, B, qh, dh] + + # Means for residual mixing + qk_mean_q = (query_pre + key_pre) / 2 + qk_mean_k = qk_mean_q.view(*qk_mean_q.shape[:2], self.num_kv_heads, self.gqa_groups, -1).mean(dim=-2) + + if past_key_values is not None: + if past_key_values.has_previous_state: + # Generation + qk_packed0 = qk_packed0.transpose(0, 1) # [B, 1, H] + qk_packed0_cached = past_key_values.conv_states[self.layer_number] # [B, H, 2] + qk_packed0_cat = torch.cat([qk_packed0_cached, qk_packed0.transpose(1, 2)], dim=-1) # [B, H, 3] + qk_packed3 = self.conv_qk(qk_packed0_cat).permute(2, 0, 1) # [S, B, E] + past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_packed0) # [B, H, 2] + + else: + # Prefill + qk_packed0_transposed = qk_packed0.permute(1, 2, 0) # [S, B, H] -> [B, H, S] + conv_states = nn.functional.pad( + qk_packed0_transposed, + ( + past_key_values.conv_kernel_size - qk_packed0_transposed.shape[-1], + 0, + ), + ) + past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=conv_states) + # Convs over sequence: [S, B, E] -> [B, E, S] -> pad -> conv -> + # [S, B, E] + qk_packed1 = qk_packed0.permute(1, 2, 0) # [B, E, S] + qk_packed2 = F.pad(qk_packed1, (self.total_padding, 0)) + qk_packed3 = self.conv_qk(qk_packed2).permute(2, 0, 1) # [S, B, E] + + else: + # Convs over sequence: [S, B, E] -> [B, E, S] -> pad -> conv -> [S, + # B, E] + qk_packed1 = qk_packed0.permute(1, 2, 0) # [B, E, S] + qk_packed2 = F.pad(qk_packed1, (self.total_padding, 0)) + qk_packed3 = self.conv_qk(qk_packed2).permute(2, 0, 1) # [S, B, E] + + # Build queries/keys from conv output + means + query = ( + qk_packed3[..., : self.latent_q_dim].view(*qk_packed3.shape[:2], self.num_q_heads, self.head_dim) + + qk_mean_q + ) # [S, B, qh, dh] + + key = ( + qk_packed3[..., self.latent_q_dim :].view(*qk_packed3.shape[:2], self.num_kv_heads, self.head_dim) + + qk_mean_k + ) # [S, B, kh, dh] + + # Values from the two time streams + v1 = self.val_proj1(hs) # [S, B, latent_k_dim/2] + if past_key_values is not None: + if past_key_values.has_previous_state: + # Generation + # [B, H] + hs_d = past_key_values.prev_hs[self.layer_number].clone() + hs_d = hs_d.unsqueeze(0) # [1, B, H] + past_key_values.prev_hs[self.layer_number].copy_(hs[-1, :, :]) + + v2 = self.val_proj2(hs_d) # [S, B, latent_k_dim/2] + value = ( + torch.cat([v1, v2], dim=-1).contiguous().view(*hs.shape[:2], self.num_kv_heads, self.head_dim) + ) # [S, B, kh, dh] + + # L2-normalize per head, then scale + query_norm = query.norm(p=2, dim=-1, keepdim=True) + key_norm = key.norm(p=2, dim=-1, keepdim=True) + + key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1) + query = query * (self.sqrt_head_dim / query_norm) + + # Flatten head axis, then return to HF layout [B, S, ...] + query = query.view(*query.shape[:2], self.num_q_heads * self.head_dim).transpose(0, 1).contiguous() + key = key.view(*key.shape[:2], self.num_kv_heads * self.head_dim).transpose(0, 1).contiguous() + value = value.view(*value.shape[:2], self.num_kv_heads * self.head_dim).transpose(0, 1).contiguous() + return query, key, value + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Applies Rotary Position Embedding to the query and key tensors. This version + correctly handles cases where the rotary dimension is smaller than the head dimension. + """ + # The rotary dimension is determined by the size of the cos/sin tensors. + rotary_dim = cos.shape[-1] + + # Unsqueeze cos and sin for broadcasting + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Slice the query and key into the part to be rotated and the part to be + # passed through. + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply RoPE to the first part + q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate the rotated and passed-through parts back together + q_embed = torch.cat((q_rot, q_pass), dim=-1) + k_embed = torch.cat((k_rot, k_pass), dim=-1) + + return q_embed, k_embed + + +class ZayaAttention(nn.Module): + def __init__(self, config: ZayaConfig, layer_n): + super().__init__() + self.debug_level = 3 + self.config = config + self.layer_n = layer_n + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.o_proj = nn.Linear( + (self.num_heads // 2) * self.head_dim, + self.hidden_size, + bias=self.config.attention_bias, + ) # hardcoded query compression for now + self.qkv = CCA( + config=self.config, + cca_num_q_heads=self.config.cca_num_q_heads[layer_n], + cca_num_kv_heads=self.config.num_query_groups_list[layer_n], + cca_num_heads=self.num_heads, + hidden_size=self.hidden_size, + layer_number=layer_n, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + batch_size, seq_length, _ = hidden_states.shape + query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, cca_mask) + query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads // 2, self.head_dim) + key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups // 2) + value_states = repeat_kv(value_states, self.num_key_value_groups // 2) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != ( + batch_size, + self.num_heads // 2, + seq_length, + self.head_dim, + ): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, seq_length, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_length, self.hidden_size // 2) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_values + + +class ZayaSdpaAttention(ZayaAttention): + """ + Zaya attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `ZayaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from ZayaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. + # `model.config.attn_implementation = "manual"` once this is + # implemented. + logger.warning_once( + "ZayaModel is using zayaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + ) + + batch_size, seq_length, _ = hidden_states.shape + + query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, cca_mask) + + query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads // 2, self.head_dim) + key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups // 2) + value_states = repeat_kv(value_states, self.num_key_value_groups // 2) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with + # AttentionMaskConverter.to_causal_4d that does not create a causal + # mask in case q_len == 1. + is_causal = bool(causal_mask is None and seq_length > 1) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_length, self.hidden_size // 2) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_values + + +class ZayaFlashAttention2(ZayaAttention): + """ + Zaya flash attention module. This module inherits from `ZayaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ): + if output_attentions: + # TODO: Improve this warning with e.g. + # `model.config.attn_implementation = "manual"` once this is + # implemented. + logger.warning_once( + "ZayaModel is using zayaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + ) + + batch_size, seq_length, _ = hidden_states.shape + + query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, cca_mask) + + query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads // 2, self.head_dim) + key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n, cache_kwargs) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as + # expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.o_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + seq_length, + position_ids=position_ids, + dropout=self.attention_dropout if self.training else 0.0, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size // 2).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_values + + +Zaya_ATTENTION_CLASSES = { + "eager": ZayaAttention, + "flash_attention_2": ZayaFlashAttention2, + "sdpa": ZayaSdpaAttention, +} + + +class ZayaDecoderATTLayer(nn.Module): + def __init__(self, config: ZayaConfig, layer_n: int, training: bool): + super().__init__() + self.debug_level = 2 + self.config = config + self.layer_n = layer_n + self.training = self.training + self.self_attn = Zaya_ATTENTION_CLASSES[config._attn_implementation](config, layer_n) + + self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + + if self.config.scale_residual_merge: + self.res_scale = ResidualScaling(config, layer_n) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + prev_router_hidden_states: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + residual (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + past_key_values (`tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`torch.FloatTensor`): Positional embedding used. + prev_router_hidden_states (`torch.FloatTensor`): Activations from the previous router to do DWA. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + if self.config.scale_residual_merge: + residual, hidden_states = self.res_scale(residual, hidden_states) + if residual is None: + residual = hidden_states.to(torch.float32) if (self.config.residual_in_fp32) else hidden_states + else: + residual = hidden_states + residual + hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype)) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + cca_mask=cca_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs, residual, prev_router_hidden_states + + +class ResidualScaling(nn.Module): + def __init__(self, config, layer_n): + super().__init__() + self.config = config + self.not_first_layer = layer_n != 0 + self.hidden_states_scale = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + self.hidden_states_bias = torch.nn.Parameter(torch.zeros(self.config.hidden_size)) + + if self.not_first_layer: + self.residual_scale = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + self.residual_bias = torch.nn.Parameter(torch.zeros(self.config.hidden_size)) + + def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor): + """Apply residual scaling with optional activation checkpointing.""" + hidden_states = (hidden_states + self.hidden_states_bias.expand(1, 1, -1)) * self.hidden_states_scale.expand( + 1, 1, -1 + ) + if self.not_first_layer: + residual = (residual + self.residual_bias.expand(1, 1, -1)) * self.residual_scale.expand(1, 1, -1) + return residual, hidden_states + + +class ZayaRouter(nn.Module): + """ + Key features: + - Down-projection to `mlp_expansion` then RMSNorm. + - Optional EDA (depth-wise averaging) via `router_states_scale` and prior `router_states`. + - Three-layer MLP with GELU producing `num_experts` logits per token. + - Top-k expert selection with balancing biases and MOD (skip expert) handling. + + Returns (route_prob, expert_choice_t, router_hidden_states_next) where: + - route_prob: (batch*seq, topk) gathered probabilities of chosen experts + - expert_choice_t: (batch*seq, topk) chosen expert indices (with MOD post-processing) + - router_hidden_states_next: the pre-norm router hidden states (B, S, mlp_expansion), + for feeding forward to the MoE layer. + """ + + def __init__( + self, + config, + layer_n: int, + num_moe_experts: int, + moe_router_topk: int, + mlp_expansion: int, + hidden_size: Optional[int] = None, + layer_number: Optional[int] = None, + ) -> None: + super().__init__() + + self.debug_level = 4 + # ---- Config / shape ---- + self.config = config + self.layer_n = layer_n + self.hidden_size = int(hidden_size or getattr(config, "hidden_size")) + self.layer_number = layer_number if layer_number is not None else 0 + + # MOD + self.use_mod = bool(getattr(config, "zaya_use_mod", False)) + + # Expert counts (extra 'skip' expert if MOD) + self.num_experts = (num_moe_experts + 1) if self.use_mod else num_moe_experts + self.topk = int(moe_router_topk) + + # Router hidden dim + self.mlp_expansion = int(mlp_expansion) + + # ---- Layers ---- + self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True) + + # EDA (depth-wise averaging) + zaya_first_layer = 1 + use_eda_cfg = bool(getattr(config, "zaya_use_eda", False)) + self.use_eda = use_eda_cfg and (zaya_first_layer is not None) and (self.layer_number != zaya_first_layer) + + ln_eps = float(getattr(config, "layernorm_epsilon", 1e-6)) + self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=ln_eps) + if self.use_eda: + # eda + self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion)) + + # routermlp + D = self.mlp_expansion + E = self.num_experts + self.non_linearity = nn.GELU() + self.router_mlp = nn.Sequential( + nn.Linear(D, D, bias=True), + self.non_linearity, + nn.Linear(D, D, bias=True), + self.non_linearity, + nn.Linear(D, E, bias=False), + ) + + # Balancing biases + self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) + if self.use_mod: + self.balancing_biases[-1] = -1.0 + + def forward( + self, + hidden_states: torch.Tensor, # (B, S, H) + router_states: Optional[torch.Tensor] = None, # (B, S, D) previous router states for EDA + eos_mask: Optional[torch.Tensor] = None, # unused here; kept for API compatibility + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute per-token expert probabilities and choose top-k experts. + + Args: + hidden_states: (batch, seq, hidden_size) + router_states: (batch, seq, mlp_expansion) from prior step/layer (for EDA). Optional. + eos_mask: kept for API compatibility; not used. + + Returns: + route_prob: (batch*seq, topk) + expert_choice_t: (batch*seq, topk) int64 + router_hidden_states_next: (batch, seq, mlp_expansion) + """ + B, S, _ = hidden_states.shape + + # eda + hs = self.down_proj(hidden_states) + + if self.use_eda and (router_states is not None): + hs = hs + router_states * self.router_states_scale + + # Stash the pre-norm states for the caller + router_hidden_states_next = hs[:, -S:].clone() + + # 2) RMSNorm eda + hs_norm = self.rmsnorm_eda(hs) + + # 3) Expert probability distribution + logits = self.router_mlp(hs_norm) + expert_prob = torch.softmax(logits, dim=-1) + + # 4) expert choice with balancing biases (biases affect choice only, + # not the probabilities) + biased = expert_prob.detach().to(torch.float32) + self.balancing_biases + _, expert_choice_t = torch.topk(biased, self.topk, dim=-1) # (B, S, topk) + + # 5) If MOD and topk>1, once skip expert is selected, force all + # subsequent choices to skip as well, but this never happens since we + # use topk=1 + if (self.topk > 1) and self.use_mod: + skip_idx = self.num_experts - 1 + n_mask = expert_choice_t == skip_idx + cumsum_mask = torch.cumsum(n_mask, dim=-1) + expert_choice_t = expert_choice_t.masked_fill(cumsum_mask > 0, skip_idx) + + # Gather the probabilities for the selected experts + route_prob = torch.gather(expert_prob, dim=2, index=expert_choice_t) + expert_choice_flat = expert_choice_t.reshape(-1, self.topk) + route_prob_flat = route_prob.reshape(-1, self.topk) + + return route_prob_flat, expert_choice_flat, router_hidden_states_next + + +def bias_swiglu_impl(input, bias, fp8_input_store=False): + ori_shape = input.shape + assert len(ori_shape) in [2, 3] + input = input.view(-1, ori_shape[-1]) + if bias is not None: + output = BiasSwiGLUFunction.apply(input, bias, fp8_input_store) + else: + output = SwiGLUFunction.apply(input, fp8_input_store) + + return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) + + +class MLP(nn.Module): + """ + MLP will take the input with h hidden state, project it to another + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + + Returns an output and a bias to be added to the output. + If config.add_bias_linear is False, the bias returned is None. + + We use the following notation: + h: hidden size + p: number of tensor model parallel partitions + b: batch size + s: sequence length + """ + + def __init__(self, config, ffn_hidden_size): + super().__init__() + self.config = config + + # Double the output width with gated linear unit, see + # https://arxiv.org/pdf/2002.05202.pdf + if self.config.gated_linear_unit: + ffn_hidden_size_out = ffn_hidden_size // 2 + else: + ffn_hidden_size_out = ffn_hidden_size + + # Set the activation function. + if self.config.activation_func == "swiglu": + self.activation_func = F.silu + else: + self.activation_func = F.gelu + + self.linear_fc1 = nn.Linear( + in_features=self.config.hidden_size, + out_features=ffn_hidden_size, + bias=self.config.add_bias_linear, + ) + self.linear_fc2 = nn.Linear( + in_features=ffn_hidden_size_out, + out_features=self.config.hidden_size, + bias=self.config.add_bias_linear, + ) + + def forward(self, hidden_states): + # [s, b, 4 * h/p] + if self.config.add_bias_linear: + intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states) + else: + intermediate_parallel = self.linear_fc1(hidden_states) + bias_parallel = None + + if self.config.bias_activation_fusion: + if self.activation_func == F.silu and self.config.gated_linear_unit: + intermediate_parallel = bias_swiglu_impl( + intermediate_parallel, + bias_parallel, + self.config.activation_func_fp8_input_store, + ) + else: + raise ValueError("Only support fusion of swiglu") + else: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + if self.config.gated_linear_unit: + + def glu(x): + x = torch.chunk(x, 2, dim=-1) + return self.config.activation_func(x[0]) * x[1] + + intermediate_parallel = glu(intermediate_parallel) + else: + intermediate_parallel = self.activation_func(intermediate_parallel) + + # [s, b, h] + if self.config.add_bias_linear: + output, output_bias = self.linear_fc2(intermediate_parallel) + else: + output = self.linear_fc2(intermediate_parallel) + output_bias = None + + return output, output_bias + + +class SequentialMLP(nn.Module): + """An implementation of the Experts layer using a sequence of MLP layers. + This class executes each expert sequentially. + """ + + def __init__(self, num_local_experts: int, config, ffn_hidden_size: int): + super().__init__() + self.config = config + self.add_bias = config.add_bias_linear + self.num_local_experts = num_local_experts + self.local_experts = torch.nn.ModuleList() + + for _ in range(self.num_local_experts): + expert = MLP(config=self.config, ffn_hidden_size=ffn_hidden_size) + self.local_experts.append(expert) + + def forward( + self, + permuted_local_hidden_states: torch.Tensor, + tokens_per_expert: torch.Tensor, + ): + """Forward step of the SequentialMLP.""" + if self.num_local_experts == 1: + output, output_bias = self.local_experts[0](permuted_local_hidden_states) + return output, output_bias + else: + tokens_per_expert = tokens_per_expert.tolist() + tokens_list = torch.split(permuted_local_hidden_states, tokens_per_expert) + + output_local_list = [] + output_bias_list = [] + + for expert, tokens in zip(self.local_experts, tokens_list): + output, output_bias = expert(tokens) + output_local_list.append(output) + if self.add_bias: + output_bias_list.append(output_bias.expand_as(output)) + + output_local = torch.cat(output_local_list, dim=0) + if self.add_bias: + output_bias_local = torch.cat(output_bias_list, dim=0) + else: + output_bias_local = None + + return output_local, output_bias_local + + +class ZayaBlock(nn.Module): + def __init__( + self, + config, + layer_idx: int, + mlp_expansion: int, + ffn_hidden_size: int, + first_mlp_layer: bool, + layer_n: int, + training: bool, + ): + super().__init__() + self.debug_level = 3 + self.config = config + self.layer_n = layer_n + self.training = training + self.hidden_dim = config.hidden_size + self.num_moe_experts = layer_idx + self.mlp_expansion = mlp_expansion + self.first_mlp_layer = first_mlp_layer + self.router = ZayaRouter( + config=self.config, + layer_n=layer_n, + num_moe_experts=self.num_moe_experts, + moe_router_topk=getattr(self.config, "moe_router_topk", 1), + mlp_expansion=mlp_expansion, + hidden_size=self.hidden_dim, + layer_number=layer_n, + ) + self.experts = SequentialMLP(self.num_moe_experts, self.config, ffn_hidden_size=ffn_hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: Optional[torch.Tensor] = None, + past_router_states: Optional[torch.Tensor] = None, + use_cache=False, + cca_mask: Optional[torch.Tensor] = None, + ): + route_prob, expert_choice, prev_router_hidden_states = self.router( + hidden_states, router_states=prev_router_hidden_states + ) + probs = route_prob + indices = expert_choice + batch_size, seq_length, emb_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) + indices_flat = indices.view(batch_size * seq_length) + sorted_indices, sort_order = torch.sort(indices_flat) + tokens_per_expert = torch.bincount(sorted_indices, minlength=self.router.num_experts) + sorted_hidden_states = hidden_states_flat[sort_order] + original_order = torch.argsort(sort_order) + + if self.config.zaya_use_mod: + expert_output, mlp_bias = self.experts( + sorted_hidden_states[: sum(tokens_per_expert[:-1])], + tokens_per_expert[:-1], + ) + expert_output = torch.cat( + [expert_output, sorted_hidden_states[sum(tokens_per_expert[:-1]) :]], + dim=0, + ) + if mlp_bias is not None: + mlp_bias = torch.cat( + [ + mlp_bias, + torch.zeros_like(sorted_hidden_states[sum(tokens_per_expert[:-1]) :]), + ], + dim=0, + ) + else: + expert_output, mlp_bias = self.experts(sorted_hidden_states, tokens_per_expert) + + expert_output = expert_output[original_order] + expert_output = expert_output.view(batch_size, seq_length, emb_dim) + # print(probs.shape,expert_output.shape) + probs = probs.view(batch_size, seq_length) + expert_output = expert_output * probs.unsqueeze(-1) + + if mlp_bias is not None: + mlp_bias = mlp_bias[original_order] + mlp_bias = mlp_bias.view(batch_size, seq_length, emb_dim) + + return expert_output, mlp_bias, prev_router_hidden_states + + +class ZayaDecoderMLPLayer(nn.Module): + def __init__( + self, + config: ZayaConfig, + layer_idx: int, + mlp_expansion: int, + ffn_hidden_size: int, + first_mlp_layer: bool, + layer_n: int, + training: bool, + ): + super().__init__() + self.debug_level = 2 + self.config = config + self.layer_n = layer_n + self.training = training + self.zaya_block = ZayaBlock( + config, + layer_idx, + mlp_expansion, + ffn_hidden_size, + first_mlp_layer, + layer_n, + self.training, + ) + self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + + if self.config.scale_residual_merge: + self.res_scale = ResidualScaling(config, layer_n) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + prev_router_hidden_states: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + residual (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + past_key_values (`tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`torch.FloatTensor`): Positional embedding used. + prev_router_hidden_states (`torch.FloatTensor`): Activations from the previous router to do DWA. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + if self.config.scale_residual_merge: + residual, hidden_states = self.res_scale(residual, hidden_states) + if residual is None: + residual = hidden_states.to(torch.float32) if (self.config.residual_in_fp32) else hidden_states + else: + residual = hidden_states + residual + hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype)) + + if self.config.add_bias_linear: + hidden_states, bias_states, prev_router_hidden_states = self.zaya_block( + hidden_states, + prev_router_hidden_states, + past_key_values, + use_cache, + cca_mask, + ) + hidden_states = hidden_states + bias_states + else: + hidden_states, _, prev_router_hidden_states = self.zaya_block( + hidden_states, + prev_router_hidden_states, + past_key_values, + use_cache, + cca_mask, + ) + + outputs = (hidden_states,) + + return outputs, residual, prev_router_hidden_states + + +Zaya_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`ZayaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Zaya Model outputting raw hidden-states without any specific head on top.", + Zaya_START_DOCSTRING, +) +class ZayaPreTrainedModel(PreTrainedModel): + config_class = ZayaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ZayaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = False + # MoE models don't work with torch.compile (`torch.where(condition)` not + # supported) + _supports_static_cache = False + + +Zaya_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Zaya Model outputting raw hidden-states without any specific head on top.", + Zaya_START_DOCSTRING, +) +class ZayaModel(ZayaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer can be an attention layer ZayaDecoderATTLayer or an MLP layer ZayaDecoderMLPLayer. + Args: + config: ZayaConfig + """ + + def __init__(self, config: ZayaConfig): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self._attn_implementation = config._attn_implementation + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.padding_idx) + self.layers = [] + first_mlp_layer = True + + for layer_n in range(len(config.zaya_layers)): + if isinstance(config.zaya_layers[layer_n], int): + self.layers.append( + ZayaDecoderMLPLayer( + config, + config.zaya_layers[layer_n], + config.zaya_mlp_expansion[layer_n], + config.ffn_hidden_size_list[layer_n], + first_mlp_layer, + layer_n, + self.training, + ) + ) + first_mlp_layer = False + else: + self.layers.append(ZayaDecoderATTLayer(config, layer_n, self.training)) + self.layers = nn.ModuleList(self.layers) + + self.gradient_checkpointing = False + + if self.config.scale_residual_merge: + self.res_scale = ResidualScaling(config, len(config.zaya_layers)) + + self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + self.rotary_emb = ZayaRotaryEmbedding(config=config) + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(Zaya_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, MoeModelOutputWithPast]: + _, seq_length = input_ids.shape + + if attention_mask is not None: + cca_mask = attention_mask.clone() + else: + cca_mask = None + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = ZayaDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + residual = None + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + prev_router_hidden_states = None + + for layer_n, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs, residual, prev_router_hidden_states = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + residual, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + prev_router_hidden_states, + cca_mask, + ) + else: + layer_outputs, residual, prev_router_hidden_states = decoder_layer( + hidden_states, + residual, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + prev_router_hidden_states=prev_router_hidden_states, + cca_mask=cca_mask, + ) + + hidden_states = layer_outputs[0] + + if isinstance(decoder_layer, ZayaDecoderATTLayer): + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.config.scale_residual_merge: + residual, hidden_states = self.res_scale(residual, hidden_states) + + if residual is None: + residual = hidden_states.to(torch.float32) if (self.config.residual_in_fp32) else hidden_states + else: + residual = hidden_states + residual + + hidden_states = self.final_norm(residual.to(dtype=self.final_norm.weight.dtype)) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from + # transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask + # with Phi3->Zaya + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Zaya. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method + # calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal + # mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from + # transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position + # with Mistral->Zaya + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: ZayaConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`ZayaConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted + # form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was + # trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.config = config + self.tie_word_embeddings = self.config.tie_word_embeddings # so the linter stops complaining + self.model = ZayaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) + self.post_init() + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(Zaya_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, ZayaForCausalLM + >>> model = ZayaForCausalLM.from_pretrained("Zyphra/Zaya-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zaya-8B") + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + if ( + use_cache + and self.config.rope_scaling + and "original_max_position_embeddings" in self.config.rope_scaling + and cache_position is not None + and cache_position[0] == self.config.rope_scaling["original_max_position_embeddings"] + ): + logger.warning( + f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.rope_scaling['original_max_position_embeddings']}th token, as the KV cache needs to be recomputed." + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **loss_kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values if use_cache else None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from + # transformers.models.phi3.modeling_phi3.Phi3ForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwitten -- has a unique cache type, `ZayaDynamicCache` + if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache): + raise ValueError( + f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}." + ) + empty_past_kv = past_key_values is None + + # Omit tokens covered by past_key_values + if not empty_past_kv: + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if inputs_embeds is not None or ( # Exception 1 + is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] + ): # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = ZayaDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st + # generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "logits_to_keep": logits_to_keep, + "cache_position": cache_position, + } + ) + + return model_inputs + + def _prepare_cache_for_generation( + self, + generation_config, + model_kwargs: dict, + generation_mode, + batch_size: int, + max_cache_length: int, + ): + if "past_key_values" not in model_kwargs: + model_kwargs["past_key_values"] = ZayaDynamicCache( + self.config, batch_size, dtype=self.dtype, device=self.device + ) + generation_config.cache_implementation = None + return super()._prepare_cache_for_generation( + generation_config=generation_config, + model_kwargs=model_kwargs, + generation_mode=generation_mode, + batch_size=batch_size, + max_cache_length=max_cache_length, + ) + + +__all__ = ["ZayaPreTrainedModel", "ZayaModel", "ZayaForCausalLM"] diff --git a/src/transformers/models/zaya/modular_zaya.py b/src/transformers/models/zaya/modular_zaya.py new file mode 100755 index 000000000000..b3a1c07a0c3d --- /dev/null +++ b/src/transformers/models/zaya/modular_zaya.py @@ -0,0 +1,2022 @@ +# coding=utf-8 +# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Zaya model.""" + +import math +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from packaging.version import Version as PkgVersion +from torch import nn + +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, +) +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from ...utils.import_utils import is_torch_fx_available +from ..glm4.modeling_glm4 import Glm4RotaryEmbedding +from ..llama.modeling_llama import LlamaRMSNorm, repeat_kv +from .configuration_zaya import ZayaConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear +# as a node in the graph. +if is_torch_fx_available(): + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ZayaConfig" + + +try: + _torch_version = PkgVersion(torch.__version__) +except BaseException: + _torch_version = PkgVersion("0.0.0") + + +def get_torch_version(): + """Get torch version from __version__.""" + + global _torch_version + return _torch_version + + +def is_torch_min_version(version, check_equality=True): + """Check if minimum version of `torch` is installed.""" + if check_equality: + return get_torch_version() >= PkgVersion(version) + return get_torch_version() > PkgVersion(version) + + +jit_fuser = torch.jit.script +if is_torch_min_version("2.2.0a0"): + jit_fuser = torch.compile + + +def swiglu(y): + y_1, y_2 = torch.chunk(y, 2, -1) + return F.silu(y_1) * y_2 + + +@jit_fuser +def bias_swiglu(y, bias): + y = y + bias + return swiglu(y) + + +# @jit_fuser + + +def swiglu_back(g, y): + y_1, y_2 = torch.chunk(y, 2, -1) + return torch.cat( + ( + g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, + g * F.silu(y_1), + ), + -1, + ) + + +@jit_fuser +def bias_swiglu_back(g, y, bias): + y = y + bias + return swiglu_back(g, y) + + +class BiasSwiGLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias, fp8_input_store): + input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input + ctx.save_for_backward(input_for_backward, bias) + ctx.ori_input_dtype = input.dtype + ctx.fp8_input_store = fp8_input_store + return bias_swiglu(input, bias) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + tmp = bias_swiglu_back(grad_output, input, bias) + return tmp, tmp, None + + +class SwiGLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, fp8_input_store): + input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input + ctx.save_for_backward(input_for_backward) + ctx.ori_input_dtype = input.dtype + ctx.fp8_input_store = fp8_input_store + return swiglu(input) + + @staticmethod + def backward(ctx, grad_output): + input = ctx.saved_tensors[0] + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + tmp = swiglu_back(grad_output, input) + return tmp, None + + +def bias_swiglu_impl(input, bias, fp8_input_store=False): + ori_shape = input.shape + assert len(ori_shape) in [2, 3] + input = input.view(-1, ori_shape[-1]) + if bias is not None: + output = BiasSwiGLUFunction.apply(input, bias, fp8_input_store) + else: + output = SwiGLUFunction.apply(input, fp8_input_store) + + return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Applies Rotary Position Embedding to the query and key tensors. This version + correctly handles cases where the rotary dimension is smaller than the head dimension. + """ + # The rotary dimension is determined by the size of the cos/sin tensors. + rotary_dim = cos.shape[-1] + + # Unsqueeze cos and sin for broadcasting + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Slice the query and key into the part to be rotated and the part to be + # passed through. + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply RoPE to the first part + q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate the rotated and passed-through parts back together + q_embed = torch.cat((q_rot, q_pass), dim=-1) + k_embed = torch.cat((k_rot, k_pass), dim=-1) + + return q_embed, k_embed + + +class ZayaRotaryEmbedding(Glm4RotaryEmbedding): + pass + + +class ZayaRMSNorm(LlamaRMSNorm): + pass + + +class ZayaDynamicCache(DynamicCache): + """ + Cache that includes both the KV cache and the CCA cache. + """ + + def __init__( + self, + config: ZayaConfig, + batch_size: int, + conv_kernel_size: int = 2, + dtype: torch.dtype = torch.float16, + device: Optional[str] = None, + ): + super().__init__() + self.config = config + self.batch_size = batch_size + self.dtype = dtype + self.device = device + num_k_heads = config.num_query_groups_list[0] + num_q_heads = config.cca_num_q_heads[0] + head_dim = config.hidden_size // config.num_attention_heads + self.conv_kernel_size = conv_kernel_size + self.num_layers = len(config.zaya_layers) + self.latent_k_dim = num_k_heads * head_dim + self.latent_q_dim = num_q_heads * head_dim + self.in_out_ch = self.latent_k_dim + self.latent_q_dim + self.has_previous_state = False + + self.conv_states = torch.zeros( + self.num_layers, + batch_size, + self.in_out_ch, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + + self.prev_hs = torch.zeros( + self.num_layers, + batch_size, + config.hidden_size, + device=device, + dtype=dtype, + ) + + def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor: + if not self.has_previous_state: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.prev_hs.zero_() + + +class CCA(nn.Module): + def __init__( + self, + config: ZayaConfig, + cca_num_kv_heads: int = 2, + cca_num_q_heads: int = 8, + cca_num_heads: int = 16, + hidden_size: Optional[int] = None, + cca_time0: int = 2, + cca_time1: int = 2, + layer_number: int = 0, + ): + super().__init__() + self.debug_level = 4 + self.config = config + self.layer_number = layer_number + + # Use the model's true hidden size unless explicitly overridden. + self.hidden_size = int(hidden_size or config.hidden_size) + + self.cca_time0 = cca_time0 + self.cca_time1 = cca_time1 + self.padding0 = cca_time0 - 1 + self.padding1 = cca_time1 - 1 + self.total_padding = self.padding0 + self.padding1 + + self.num_kv_heads = int(cca_num_kv_heads) + self.num_q_heads = int(cca_num_q_heads) + self.num_heads = int(cca_num_heads) + + # Geometry + self.head_dim = self.hidden_size // self.num_heads + self.latent_k_dim = self.num_kv_heads * self.head_dim + self.latent_q_dim = self.num_q_heads * self.head_dim + self.sqrt_head_dim = float(np.sqrt(self.head_dim)) + self.gqa_groups = self.num_q_heads // self.num_kv_heads + assert self.num_q_heads % self.num_kv_heads == 0, "q_heads must be a multiple of k_heads" + assert (self.latent_k_dim + self.latent_q_dim) == (self.num_kv_heads + self.num_q_heads) * self.head_dim + + # Projections + self.linear_q = nn.Linear(self.hidden_size, self.latent_q_dim, bias=self.config.attention_bias) + self.linear_k = nn.Linear(self.hidden_size, self.latent_k_dim, bias=self.config.attention_bias) + self.val_proj1 = nn.Linear(self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias) + self.val_proj2 = nn.Linear(self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias) + + # Depthwise + grouped conv along sequence + in_out_ch = self.latent_k_dim + self.latent_q_dim + self.conv_qk = nn.Sequential( + nn.Conv1d( + in_channels=in_out_ch, + out_channels=in_out_ch, + kernel_size=self.cca_time0, + groups=in_out_ch, + padding=0, + stride=1, + ), + nn.Conv1d( + in_channels=in_out_ch, + out_channels=in_out_ch, + kernel_size=self.cca_time1, + groups=(self.num_kv_heads + self.num_q_heads), + padding=0, + stride=1, + ), + ) + + # Per-k head temperature + self.temp = nn.Parameter(torch.zeros(self.num_kv_heads)) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_values: Optional[ZayaDynamicCache], + cca_mask, + ): + """ + hidden_states: [B, S, E] (HF layout) + returns: + query: [B, S, num_q_heads*head_dim] + key : [B, S, num_k_heads*head_dim] + value: [B, S, num_k_heads*head_dim] + """ + if cca_mask is not None and hidden_states.shape[1] > 1: + # Only applying in prefill + dtype = hidden_states.dtype + hidden_states = (hidden_states * cca_mask[:, :, None]).to(dtype) + + # ---- Switch to [S, B, H] ---- + hs = hidden_states.transpose(0, 1).contiguous() # [S, B, H] + # Time-shifted stream for v2 (pad one at the front along sequence) + hs_d = F.pad(hs[:-1], pad=(0, 0, 0, 0, 1, 0)) # [S, B, H] + + # Q/K in the full space + q = self.linear_q(hs) # [S, B, latent_q_dim] + k = self.linear_k(hs) # [S, B, latent_k_dim] + qk_packed0 = torch.cat([q, k], dim=-1) # [S, B, latent_q + latent_k] + + # Pre-mean tensors in head form (for "qk_mean_{q,k}" calc) + query_pre = qk_packed0[..., : self.latent_q_dim].view( + *qk_packed0.shape[:2], self.num_q_heads, self.head_dim + ) # [S, B, qh, dh] + + key_pre = qk_packed0[..., self.latent_q_dim :].view( + *qk_packed0.shape[:2], self.num_kv_heads, self.head_dim + ) # [S, B, kh, dh] + key_pre = ( + key_pre.unsqueeze(-2) + .repeat(1, 1, 1, self.gqa_groups, 1) + .view(*qk_packed0.shape[:2], self.num_q_heads, self.head_dim) + ) # [S, B, qh, dh] + + # Means for residual mixing + qk_mean_q = (query_pre + key_pre) / 2 + qk_mean_k = qk_mean_q.view(*qk_mean_q.shape[:2], self.num_kv_heads, self.gqa_groups, -1).mean(dim=-2) + + if past_key_values is not None: + if past_key_values.has_previous_state: + # Generation + qk_packed0 = qk_packed0.transpose(0, 1) # [B, 1, H] + qk_packed0_cached = past_key_values.conv_states[self.layer_number] # [B, H, 2] + qk_packed0_cat = torch.cat([qk_packed0_cached, qk_packed0.transpose(1, 2)], dim=-1) # [B, H, 3] + qk_packed3 = self.conv_qk(qk_packed0_cat).permute(2, 0, 1) # [S, B, E] + past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=qk_packed0) # [B, H, 2] + + else: + # Prefill + qk_packed0_transposed = qk_packed0.permute(1, 2, 0) # [S, B, H] -> [B, H, S] + conv_states = nn.functional.pad( + qk_packed0_transposed, + ( + past_key_values.conv_kernel_size - qk_packed0_transposed.shape[-1], + 0, + ), + ) + past_key_values.update_conv_state(layer_idx=self.layer_number, new_conv_state=conv_states) + # Convs over sequence: [S, B, E] -> [B, E, S] -> pad -> conv -> + # [S, B, E] + qk_packed1 = qk_packed0.permute(1, 2, 0) # [B, E, S] + qk_packed2 = F.pad(qk_packed1, (self.total_padding, 0)) + qk_packed3 = self.conv_qk(qk_packed2).permute(2, 0, 1) # [S, B, E] + + else: + # Convs over sequence: [S, B, E] -> [B, E, S] -> pad -> conv -> [S, + # B, E] + qk_packed1 = qk_packed0.permute(1, 2, 0) # [B, E, S] + qk_packed2 = F.pad(qk_packed1, (self.total_padding, 0)) + qk_packed3 = self.conv_qk(qk_packed2).permute(2, 0, 1) # [S, B, E] + + # Build queries/keys from conv output + means + query = ( + qk_packed3[..., : self.latent_q_dim].view(*qk_packed3.shape[:2], self.num_q_heads, self.head_dim) + + qk_mean_q + ) # [S, B, qh, dh] + + key = ( + qk_packed3[..., self.latent_q_dim :].view(*qk_packed3.shape[:2], self.num_kv_heads, self.head_dim) + + qk_mean_k + ) # [S, B, kh, dh] + + # Values from the two time streams + v1 = self.val_proj1(hs) # [S, B, latent_k_dim/2] + if past_key_values is not None: + if past_key_values.has_previous_state: + # Generation + # [B, H] + hs_d = past_key_values.prev_hs[self.layer_number].clone() + hs_d = hs_d.unsqueeze(0) # [1, B, H] + past_key_values.prev_hs[self.layer_number].copy_(hs[-1, :, :]) + + v2 = self.val_proj2(hs_d) # [S, B, latent_k_dim/2] + value = ( + torch.cat([v1, v2], dim=-1).contiguous().view(*hs.shape[:2], self.num_kv_heads, self.head_dim) + ) # [S, B, kh, dh] + + # L2-normalize per head, then scale + query_norm = query.norm(p=2, dim=-1, keepdim=True) + key_norm = key.norm(p=2, dim=-1, keepdim=True) + + key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1) + query = query * (self.sqrt_head_dim / query_norm) + + # Flatten head axis, then return to HF layout [B, S, ...] + query = query.view(*query.shape[:2], self.num_q_heads * self.head_dim).transpose(0, 1).contiguous() + key = key.view(*key.shape[:2], self.num_kv_heads * self.head_dim).transpose(0, 1).contiguous() + value = value.view(*value.shape[:2], self.num_kv_heads * self.head_dim).transpose(0, 1).contiguous() + return query, key, value + + +class ZayaAttention(nn.Module): + def __init__(self, config: ZayaConfig, layer_n): + super().__init__() + self.debug_level = 3 + self.config = config + self.layer_n = layer_n + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.o_proj = nn.Linear( + (self.num_heads // 2) * self.head_dim, + self.hidden_size, + bias=self.config.attention_bias, + ) # hardcoded query compression for now + self.qkv = CCA( + config=self.config, + cca_num_q_heads=self.config.cca_num_q_heads[layer_n], + cca_num_kv_heads=self.config.num_query_groups_list[layer_n], + cca_num_heads=self.num_heads, + hidden_size=self.hidden_size, + layer_number=layer_n, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + batch_size, seq_length, _ = hidden_states.shape + query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, cca_mask) + query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads // 2, self.head_dim) + key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups // 2) + value_states = repeat_kv(value_states, self.num_key_value_groups // 2) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != ( + batch_size, + self.num_heads // 2, + seq_length, + self.head_dim, + ): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, seq_length, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_length, self.hidden_size // 2) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_values + + +class ZayaSdpaAttention(ZayaAttention): + """ + Zaya attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `ZayaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from ZayaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. + # `model.config.attn_implementation = "manual"` once this is + # implemented. + logger.warning_once( + "ZayaModel is using zayaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + ) + + batch_size, seq_length, _ = hidden_states.shape + + query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, cca_mask) + + query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads // 2, self.head_dim) + key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups // 2) + value_states = repeat_kv(value_states, self.num_key_value_groups // 2) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with + # AttentionMaskConverter.to_causal_4d that does not create a causal + # mask in case q_len == 1. + is_causal = bool(causal_mask is None and seq_length > 1) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_length, self.hidden_size // 2) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_values + + +class ZayaFlashAttention2(ZayaAttention): + """ + Zaya flash attention module. This module inherits from `ZayaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ): + if output_attentions: + # TODO: Improve this warning with e.g. + # `model.config.attn_implementation = "manual"` once this is + # implemented. + logger.warning_once( + "ZayaModel is using zayaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + ) + + batch_size, seq_length, _ = hidden_states.shape + + query_states, key_states, value_states = self.qkv(hidden_states, past_key_values, cca_mask) + + query_states = query_states.view(batch_size, seq_length, self.config.num_attention_heads // 2, self.head_dim) + key_states = key_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + value_states = value_states.view(batch_size, seq_length, self.config.num_key_value_heads, self.head_dim) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_n, cache_kwargs) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as + # expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.o_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + seq_length, + position_ids=position_ids, + dropout=self.attention_dropout if self.training else 0.0, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size // 2).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_values + + +Zaya_ATTENTION_CLASSES = { + "eager": ZayaAttention, + "flash_attention_2": ZayaFlashAttention2, + "sdpa": ZayaSdpaAttention, +} + + +class ZayaDecoderATTLayer(nn.Module): + def __init__(self, config: ZayaConfig, layer_n: int, training: bool): + super().__init__() + self.debug_level = 2 + self.config = config + self.layer_n = layer_n + self.training = self.training + self.self_attn = Zaya_ATTENTION_CLASSES[config._attn_implementation](config, layer_n) + + self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + + if self.config.scale_residual_merge: + self.res_scale = ResidualScaling(config, layer_n) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + prev_router_hidden_states: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + residual (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + past_key_values (`tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`torch.FloatTensor`): Positional embedding used. + prev_router_hidden_states (`torch.FloatTensor`): Activations from the previous router to do DWA. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + if self.config.scale_residual_merge: + residual, hidden_states = self.res_scale(residual, hidden_states) + if residual is None: + residual = hidden_states.to(torch.float32) if (self.config.residual_in_fp32) else hidden_states + else: + residual = hidden_states + residual + hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype)) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + cca_mask=cca_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs, residual, prev_router_hidden_states + + +class ResidualScaling(nn.Module): + def __init__(self, config, layer_n): + super().__init__() + self.config = config + self.not_first_layer = layer_n != 0 + self.hidden_states_scale = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + self.hidden_states_bias = torch.nn.Parameter(torch.zeros(self.config.hidden_size)) + + if self.not_first_layer: + self.residual_scale = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + self.residual_bias = torch.nn.Parameter(torch.zeros(self.config.hidden_size)) + + def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor): + """Apply residual scaling with optional activation checkpointing.""" + hidden_states = (hidden_states + self.hidden_states_bias.expand(1, 1, -1)) * self.hidden_states_scale.expand( + 1, 1, -1 + ) + if self.not_first_layer: + residual = (residual + self.residual_bias.expand(1, 1, -1)) * self.residual_scale.expand(1, 1, -1) + return residual, hidden_states + + +class ZayaRouter(nn.Module): + """ + Key features: + - Down-projection to `mlp_expansion` then RMSNorm. + - Optional EDA (depth-wise averaging) via `router_states_scale` and prior `router_states`. + - Three-layer MLP with GELU producing `num_experts` logits per token. + - Top-k expert selection with balancing biases and MOD (skip expert) handling. + + Returns (route_prob, expert_choice_t, router_hidden_states_next) where: + - route_prob: (batch*seq, topk) gathered probabilities of chosen experts + - expert_choice_t: (batch*seq, topk) chosen expert indices (with MOD post-processing) + - router_hidden_states_next: the pre-norm router hidden states (B, S, mlp_expansion), + for feeding forward to the MoE layer. + """ + + def __init__( + self, + config, + layer_n: int, + num_moe_experts: int, + moe_router_topk: int, + mlp_expansion: int, + hidden_size: Optional[int] = None, + layer_number: Optional[int] = None, + ) -> None: + super().__init__() + + self.debug_level = 4 + # ---- Config / shape ---- + self.config = config + self.layer_n = layer_n + self.hidden_size = int(hidden_size or getattr(config, "hidden_size")) + self.layer_number = layer_number if layer_number is not None else 0 + + # MOD + self.use_mod = bool(getattr(config, "zaya_use_mod", False)) + + # Expert counts (extra 'skip' expert if MOD) + self.num_experts = (num_moe_experts + 1) if self.use_mod else num_moe_experts + self.topk = int(moe_router_topk) + + # Router hidden dim + self.mlp_expansion = int(mlp_expansion) + + # ---- Layers ---- + self.down_proj = nn.Linear(self.hidden_size, self.mlp_expansion, bias=True) + + # EDA (depth-wise averaging) + zaya_first_layer = 1 + use_eda_cfg = bool(getattr(config, "zaya_use_eda", False)) + self.use_eda = use_eda_cfg and (zaya_first_layer is not None) and (self.layer_number != zaya_first_layer) + + ln_eps = float(getattr(config, "layernorm_epsilon", 1e-6)) + self.rmsnorm_eda = ZayaRMSNorm(self.mlp_expansion, eps=ln_eps) + if self.use_eda: + # eda + self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion)) + + # routermlp + D = self.mlp_expansion + E = self.num_experts + self.non_linearity = nn.GELU() + self.router_mlp = nn.Sequential( + nn.Linear(D, D, bias=True), + self.non_linearity, + nn.Linear(D, D, bias=True), + self.non_linearity, + nn.Linear(D, E, bias=False), + ) + + # Balancing biases + self.register_buffer("balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) + if self.use_mod: + self.balancing_biases[-1] = -1.0 + + def forward( + self, + hidden_states: torch.Tensor, # (B, S, H) + router_states: Optional[torch.Tensor] = None, # (B, S, D) previous router states for EDA + eos_mask: Optional[torch.Tensor] = None, # unused here; kept for API compatibility + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute per-token expert probabilities and choose top-k experts. + + Args: + hidden_states: (batch, seq, hidden_size) + router_states: (batch, seq, mlp_expansion) from prior step/layer (for EDA). Optional. + eos_mask: kept for API compatibility; not used. + + Returns: + route_prob: (batch*seq, topk) + expert_choice_t: (batch*seq, topk) int64 + router_hidden_states_next: (batch, seq, mlp_expansion) + """ + B, S, _ = hidden_states.shape + + # eda + hs = self.down_proj(hidden_states) + + if self.use_eda and (router_states is not None): + hs = hs + router_states * self.router_states_scale + + # Stash the pre-norm states for the caller + router_hidden_states_next = hs[:, -S:].clone() + + # 2) RMSNorm eda + hs_norm = self.rmsnorm_eda(hs) + + # 3) Expert probability distribution + logits = self.router_mlp(hs_norm) + expert_prob = torch.softmax(logits, dim=-1) + + # 4) expert choice with balancing biases (biases affect choice only, + # not the probabilities) + biased = expert_prob.detach().to(torch.float32) + self.balancing_biases + _, expert_choice_t = torch.topk(biased, self.topk, dim=-1) # (B, S, topk) + + # 5) If MOD and topk>1, once skip expert is selected, force all + # subsequent choices to skip as well, but this never happens since we + # use topk=1 + if (self.topk > 1) and self.use_mod: + skip_idx = self.num_experts - 1 + n_mask = expert_choice_t == skip_idx + cumsum_mask = torch.cumsum(n_mask, dim=-1) + expert_choice_t = expert_choice_t.masked_fill(cumsum_mask > 0, skip_idx) + + # Gather the probabilities for the selected experts + route_prob = torch.gather(expert_prob, dim=2, index=expert_choice_t) + expert_choice_flat = expert_choice_t.reshape(-1, self.topk) + route_prob_flat = route_prob.reshape(-1, self.topk) + + return route_prob_flat, expert_choice_flat, router_hidden_states_next + + +class MLP(nn.Module): + """ + MLP will take the input with h hidden state, project it to another + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + + Returns an output and a bias to be added to the output. + If config.add_bias_linear is False, the bias returned is None. + + We use the following notation: + h: hidden size + p: number of tensor model parallel partitions + b: batch size + s: sequence length + """ + + def __init__(self, config, ffn_hidden_size): + super().__init__() + self.config = config + + # Double the output width with gated linear unit, see + # https://arxiv.org/pdf/2002.05202.pdf + if self.config.gated_linear_unit: + ffn_hidden_size_out = ffn_hidden_size // 2 + else: + ffn_hidden_size_out = ffn_hidden_size + + # Set the activation function. + if self.config.activation_func == "swiglu": + self.activation_func = F.silu + else: + self.activation_func = F.gelu + + self.linear_fc1 = nn.Linear( + in_features=self.config.hidden_size, + out_features=ffn_hidden_size, + bias=self.config.add_bias_linear, + ) + self.linear_fc2 = nn.Linear( + in_features=ffn_hidden_size_out, + out_features=self.config.hidden_size, + bias=self.config.add_bias_linear, + ) + + def forward(self, hidden_states): + # [s, b, 4 * h/p] + if self.config.add_bias_linear: + intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states) + else: + intermediate_parallel = self.linear_fc1(hidden_states) + bias_parallel = None + + if self.config.bias_activation_fusion: + if self.activation_func == F.silu and self.config.gated_linear_unit: + intermediate_parallel = bias_swiglu_impl( + intermediate_parallel, + bias_parallel, + self.config.activation_func_fp8_input_store, + ) + else: + raise ValueError("Only support fusion of swiglu") + else: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + if self.config.gated_linear_unit: + + def glu(x): + x = torch.chunk(x, 2, dim=-1) + return self.config.activation_func(x[0]) * x[1] + + intermediate_parallel = glu(intermediate_parallel) + else: + intermediate_parallel = self.activation_func(intermediate_parallel) + + # [s, b, h] + if self.config.add_bias_linear: + output, output_bias = self.linear_fc2(intermediate_parallel) + else: + output = self.linear_fc2(intermediate_parallel) + output_bias = None + + return output, output_bias + + +class SequentialMLP(nn.Module): + """An implementation of the Experts layer using a sequence of MLP layers. + This class executes each expert sequentially. + """ + + def __init__(self, num_local_experts: int, config, ffn_hidden_size: int): + super().__init__() + self.config = config + self.add_bias = config.add_bias_linear + self.num_local_experts = num_local_experts + self.local_experts = torch.nn.ModuleList() + + for _ in range(self.num_local_experts): + expert = MLP(config=self.config, ffn_hidden_size=ffn_hidden_size) + self.local_experts.append(expert) + + def forward( + self, + permuted_local_hidden_states: torch.Tensor, + tokens_per_expert: torch.Tensor, + ): + """Forward step of the SequentialMLP.""" + if self.num_local_experts == 1: + output, output_bias = self.local_experts[0](permuted_local_hidden_states) + return output, output_bias + else: + tokens_per_expert = tokens_per_expert.tolist() + tokens_list = torch.split(permuted_local_hidden_states, tokens_per_expert) + + output_local_list = [] + output_bias_list = [] + + for expert, tokens in zip(self.local_experts, tokens_list): + output, output_bias = expert(tokens) + output_local_list.append(output) + if self.add_bias: + output_bias_list.append(output_bias.expand_as(output)) + + output_local = torch.cat(output_local_list, dim=0) + if self.add_bias: + output_bias_local = torch.cat(output_bias_list, dim=0) + else: + output_bias_local = None + + return output_local, output_bias_local + + +class ZayaBlock(nn.Module): + def __init__( + self, + config, + layer_idx: int, + mlp_expansion: int, + ffn_hidden_size: int, + first_mlp_layer: bool, + layer_n: int, + training: bool, + ): + super().__init__() + self.debug_level = 3 + self.config = config + self.layer_n = layer_n + self.training = training + self.hidden_dim = config.hidden_size + self.num_moe_experts = layer_idx + self.mlp_expansion = mlp_expansion + self.first_mlp_layer = first_mlp_layer + self.router = ZayaRouter( + config=self.config, + layer_n=layer_n, + num_moe_experts=self.num_moe_experts, + moe_router_topk=getattr(self.config, "moe_router_topk", 1), + mlp_expansion=mlp_expansion, + hidden_size=self.hidden_dim, + layer_number=layer_n, + ) + self.experts = SequentialMLP(self.num_moe_experts, self.config, ffn_hidden_size=ffn_hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: Optional[torch.Tensor] = None, + past_router_states: Optional[torch.Tensor] = None, + use_cache=False, + cca_mask: Optional[torch.Tensor] = None, + ): + route_prob, expert_choice, prev_router_hidden_states = self.router( + hidden_states, router_states=prev_router_hidden_states + ) + probs = route_prob + indices = expert_choice + batch_size, seq_length, emb_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(batch_size * seq_length, emb_dim) + indices_flat = indices.view(batch_size * seq_length) + sorted_indices, sort_order = torch.sort(indices_flat) + tokens_per_expert = torch.bincount(sorted_indices, minlength=self.router.num_experts) + sorted_hidden_states = hidden_states_flat[sort_order] + original_order = torch.argsort(sort_order) + + if self.config.zaya_use_mod: + expert_output, mlp_bias = self.experts( + sorted_hidden_states[: sum(tokens_per_expert[:-1])], + tokens_per_expert[:-1], + ) + expert_output = torch.cat( + [expert_output, sorted_hidden_states[sum(tokens_per_expert[:-1]) :]], + dim=0, + ) + if mlp_bias is not None: + mlp_bias = torch.cat( + [ + mlp_bias, + torch.zeros_like(sorted_hidden_states[sum(tokens_per_expert[:-1]) :]), + ], + dim=0, + ) + else: + expert_output, mlp_bias = self.experts(sorted_hidden_states, tokens_per_expert) + + expert_output = expert_output[original_order] + expert_output = expert_output.view(batch_size, seq_length, emb_dim) + # print(probs.shape,expert_output.shape) + probs = probs.view(batch_size, seq_length) + expert_output = expert_output * probs.unsqueeze(-1) + + if mlp_bias is not None: + mlp_bias = mlp_bias[original_order] + mlp_bias = mlp_bias.view(batch_size, seq_length, emb_dim) + + return expert_output, mlp_bias, prev_router_hidden_states + + +class ZayaDecoderMLPLayer(nn.Module): + def __init__( + self, + config: ZayaConfig, + layer_idx: int, + mlp_expansion: int, + ffn_hidden_size: int, + first_mlp_layer: bool, + layer_n: int, + training: bool, + ): + super().__init__() + self.debug_level = 2 + self.config = config + self.layer_n = layer_n + self.training = training + self.zaya_block = ZayaBlock( + config, + layer_idx, + mlp_expansion, + ffn_hidden_size, + first_mlp_layer, + layer_n, + self.training, + ) + self.input_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + + if self.config.scale_residual_merge: + self.res_scale = ResidualScaling(config, layer_n) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cca_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + prev_router_hidden_states: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + residual (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + past_key_values (`tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`torch.FloatTensor`): Positional embedding used. + prev_router_hidden_states (`torch.FloatTensor`): Activations from the previous router to do DWA. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + if self.config.scale_residual_merge: + residual, hidden_states = self.res_scale(residual, hidden_states) + if residual is None: + residual = hidden_states.to(torch.float32) if (self.config.residual_in_fp32) else hidden_states + else: + residual = hidden_states + residual + hidden_states = self.input_norm(residual.to(dtype=self.input_norm.weight.dtype)) + + if self.config.add_bias_linear: + hidden_states, bias_states, prev_router_hidden_states = self.zaya_block( + hidden_states, + prev_router_hidden_states, + past_key_values, + use_cache, + cca_mask, + ) + hidden_states = hidden_states + bias_states + else: + hidden_states, _, prev_router_hidden_states = self.zaya_block( + hidden_states, + prev_router_hidden_states, + past_key_values, + use_cache, + cca_mask, + ) + + outputs = (hidden_states,) + + return outputs, residual, prev_router_hidden_states + + +Zaya_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`ZayaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Zaya Model outputting raw hidden-states without any specific head on top.", + Zaya_START_DOCSTRING, +) +class ZayaPreTrainedModel(PreTrainedModel): + config_class = ZayaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ZayaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = False + # MoE models don't work with torch.compile (`torch.where(condition)` not + # supported) + _supports_static_cache = False + + +Zaya_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Zaya Model outputting raw hidden-states without any specific head on top.", + Zaya_START_DOCSTRING, +) +class ZayaModel(ZayaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer can be an attention layer ZayaDecoderATTLayer or an MLP layer ZayaDecoderMLPLayer. + Args: + config: ZayaConfig + """ + + def __init__(self, config: ZayaConfig): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self._attn_implementation = config._attn_implementation + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.padding_idx) + self.layers = [] + first_mlp_layer = True + + for layer_n in range(len(config.zaya_layers)): + if isinstance(config.zaya_layers[layer_n], int): + self.layers.append( + ZayaDecoderMLPLayer( + config, + config.zaya_layers[layer_n], + config.zaya_mlp_expansion[layer_n], + config.ffn_hidden_size_list[layer_n], + first_mlp_layer, + layer_n, + self.training, + ) + ) + first_mlp_layer = False + else: + self.layers.append(ZayaDecoderATTLayer(config, layer_n, self.training)) + self.layers = nn.ModuleList(self.layers) + + self.gradient_checkpointing = False + + if self.config.scale_residual_merge: + self.res_scale = ResidualScaling(config, len(config.zaya_layers)) + + self.final_norm = ZayaRMSNorm(self.config.hidden_size, eps=self.config.norm_epsilon) + self.rotary_emb = ZayaRotaryEmbedding(config=config) + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(Zaya_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, MoeModelOutputWithPast]: + _, seq_length = input_ids.shape + + if attention_mask is not None: + cca_mask = attention_mask.clone() + else: + cca_mask = None + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = ZayaDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + residual = None + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + prev_router_hidden_states = None + + for layer_n, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs, residual, prev_router_hidden_states = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + residual, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + prev_router_hidden_states, + cca_mask, + ) + else: + layer_outputs, residual, prev_router_hidden_states = decoder_layer( + hidden_states, + residual, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + prev_router_hidden_states=prev_router_hidden_states, + cca_mask=cca_mask, + ) + + hidden_states = layer_outputs[0] + + if isinstance(decoder_layer, ZayaDecoderATTLayer): + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.config.scale_residual_merge: + residual, hidden_states = self.res_scale(residual, hidden_states) + + if residual is None: + residual = hidden_states.to(torch.float32) if (self.config.residual_in_fp32) else hidden_states + else: + residual = hidden_states + residual + + hidden_states = self.final_norm(residual.to(dtype=self.final_norm.weight.dtype)) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from + # transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask + # with Phi3->Zaya + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Zaya. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method + # calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal + # mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from + # transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position + # with Mistral->Zaya + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: ZayaConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`ZayaConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted + # form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was + # trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.config = config + self.tie_word_embeddings = self.config.tie_word_embeddings # so the linter stops complaining + self.model = ZayaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) + self.post_init() + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from + # transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(Zaya_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, ZayaForCausalLM + >>> model = ZayaForCausalLM.from_pretrained("Zyphra/Zaya-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zaya-8B") + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + if ( + use_cache + and self.config.rope_scaling + and "original_max_position_embeddings" in self.config.rope_scaling + and cache_position is not None + and cache_position[0] == self.config.rope_scaling["original_max_position_embeddings"] + ): + logger.warning( + f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.rope_scaling['original_max_position_embeddings']}th token, as the KV cache needs to be recomputed." + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **loss_kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values if use_cache else None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from + # transformers.models.phi3.modeling_phi3.Phi3ForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwitten -- has a unique cache type, `ZayaDynamicCache` + if past_key_values is not None and not isinstance(past_key_values, ZayaDynamicCache): + raise ValueError( + f"Zaya uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}." + ) + empty_past_kv = past_key_values is None + + # Omit tokens covered by past_key_values + if not empty_past_kv: + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if inputs_embeds is not None or ( # Exception 1 + is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] + ): # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = ZayaDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st + # generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "logits_to_keep": logits_to_keep, + "cache_position": cache_position, + } + ) + + return model_inputs + + def _prepare_cache_for_generation( + self, + generation_config, + model_kwargs: dict, + generation_mode, + batch_size: int, + max_cache_length: int, + ): + if "past_key_values" not in model_kwargs: + model_kwargs["past_key_values"] = ZayaDynamicCache( + self.config, batch_size, dtype=self.dtype, device=self.device + ) + generation_config.cache_implementation = None + return super()._prepare_cache_for_generation( + generation_config=generation_config, + model_kwargs=model_kwargs, + generation_mode=generation_mode, + batch_size=batch_size, + max_cache_length=max_cache_length, + ) + + +__all__ = [ + "ZayaPreTrainedModel", + "ZayaModel", + "ZayaForCausalLM", +] diff --git a/tests/models/zaya/__init__.py b/tests/models/zaya/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/zaya/test_modeling_zaya.py b/tests/models/zaya/test_modeling_zaya.py new file mode 100644 index 000000000000..f5a08726e357 --- /dev/null +++ b/tests/models/zaya/test_modeling_zaya.py @@ -0,0 +1,172 @@ +# Copyright 2025 Zaya and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_torch + + +if is_torch_available(): + from transformers import ZayaForCausalLM, ZayaModel + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class ZayaModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = ZayaModel + + def __init__( + self, + parent, + cca=True, + cca_num_q_heads=[8, 0, 8, 0], + num_query_groups_list=[2, 0, 2, 0], + use_cache=False, + attention_bias=False, + lm_head_bias=False, + ffn_hidden_size_list=[0, 64, 0, 64], + activation_func="swiglu", + norm_epsilon=1e-05, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + tie_word_embeddings=True, + rope_theta=10000, + attention_dropout=0.0, + moe_router_topk=1, + zaya_layers=["a", 16, "a", 16], + normalization="RMSNorm", + zaya_mlp_expansion=[0, 4, 0, 4], + zaya_use_mod=True, + zaya_high_prec=True, + zaya_use_eda=True, + add_bias_linear=False, + gated_linear_unit=True, + scale_residual_merge=True, + fused_add_norm=False, + residual_in_fp32=False, + apply_rope_fusion=True, + bias_activation_fusion=True, + activation_func_fp8_input_store=False, + sliding_window=None, + rope_scaling=None, + rope_parameters=None, + partial_rotary_factor=0.5, + _attn_implementation="eager", + # CausalLMModelTester args + batch_size=4, + seq_length=128, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=64, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=16, + num_key_value_heads=2, + intermediate_size=64, + hidden_act="silu", + max_position_embeddings=16384, + initializer_range=0.02, + ): + super().__init__( + parent=parent, + batch_size=batch_size, + seq_length=seq_length, + is_training=is_training, + use_input_mask=use_input_mask, + use_token_type_ids=use_token_type_ids, + use_labels=use_labels, + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + ) + + self.cca = cca + self.cca_num_q_heads = cca_num_q_heads + self.num_query_groups_list = num_query_groups_list + self.use_cache = use_cache + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size_list = ffn_hidden_size_list + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.kv_channels = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.activation_func = activation_func + self.max_position_embeddings = max_position_embeddings + self.norm_epsilon = norm_epsilon + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.attention_dropout = attention_dropout + self.moe_router_topk = moe_router_topk + self.zaya_layers = zaya_layers + self.zaya_mlp_expansion = zaya_mlp_expansion + self.zaya_use_mod = zaya_use_mod + self.zaya_use_eda = zaya_use_eda + self.add_bias_linear = add_bias_linear + self.gated_linear_unit = gated_linear_unit + self.scale_residual_merge = scale_residual_merge + self.residual_in_fp32 = residual_in_fp32 + self.bias_activation_fusion = bias_activation_fusion + self.activation_func_fp8_input_store = activation_func_fp8_input_store + self.sliding_window = sliding_window + self.partial_rotary_factor = partial_rotary_factor + self.num_key_value_heads = num_key_value_heads + self._attn_implementation = _attn_implementation + self.rope_parameters = { + "rope_theta": rope_theta, + "rope_type": "linear", + "factor": 1, + } + + +@require_torch +class ZayaModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = ZayaModelTester + all_model_classes = (ZayaModel, ZayaForCausalLM) if is_torch_available() else () + pipeline_model_mapping = ( + {"feature-extraction": ZayaModel, "text-generation": ZayaForCausalLM} if is_torch_available() else {} + ) + + @unittest.skip("Zaya applies key/query norm which doesn't work with packing") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Zaya applies key/query norm which doesn't work with packing") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Zaya applies key/query norm which doesn't work with packing") + def test_model_rope_scaling_frequencies(self): + pass + + @unittest.skip("Zaya has moe, output can be different") + def test_model_outputs_equivalence(self, **kwargs): + pass + + # TODO: Add integration tests once we have a checkpoint on the Hub diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 34ebf37b2fd0..45fcf1c17e70 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -1509,6 +1509,9 @@ def check_docstrings(overwrite: bool = False, check_all: bool = False): hard_failures.append(name) continue if old_doc != new_doc: + import pdb + + pdb.set_trace() if overwrite: fix_docstring(obj, old_doc, new_doc) else: