From ce9f498ddb16ec0ca0efe5fdeec9fa8142f59507 Mon Sep 17 00:00:00 2001 From: Jialei Chen Date: Tue, 12 Aug 2025 19:03:15 +0000 Subject: [PATCH 1/2] ini --- .../configs/model/flex-qwen-1b.yaml | 34 + .../configs/model/remat/qwen.yaml | 7 + .../configs/model/sharding/qwen-fsdp.yaml | 17 + torchprime/torch_xla_models/flex/__init__.py | 5 + torchprime/torch_xla_models/flex/attention.py | 235 +++++++ .../torch_xla_models/flex/modeling_qwen.py | 603 ++++++++++++++++++ 6 files changed, 901 insertions(+) create mode 100644 torchprime/torch_xla_models/configs/model/flex-qwen-1b.yaml create mode 100644 torchprime/torch_xla_models/configs/model/remat/qwen.yaml create mode 100644 torchprime/torch_xla_models/configs/model/sharding/qwen-fsdp.yaml create mode 100644 torchprime/torch_xla_models/flex/__init__.py create mode 100644 torchprime/torch_xla_models/flex/attention.py create mode 100644 torchprime/torch_xla_models/flex/modeling_qwen.py diff --git a/torchprime/torch_xla_models/configs/model/flex-qwen-1b.yaml b/torchprime/torch_xla_models/configs/model/flex-qwen-1b.yaml new file mode 100644 index 00000000..9b242b6a --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/flex-qwen-1b.yaml @@ -0,0 +1,34 @@ +defaults: + - _self_ # refers to this config file + - sharding: qwen-fsdp # refers to sharding/qwen-fsdp.yaml + - remat: qwen # refers to remat/qwen.yaml + +model_class: flex.Qwen3ForCausalLM # Used to import the model from this class +attention_bias: false +attention_dropout: 0.0 +bos_token_id: 151643 +eos_token_id: 151645 +pad_token_id: 151643 +mask_token_id: 151669 +tokenizer_name: Qwen/Qwen3-1.7B +head_dim: 128 +hidden_act: silu +hidden_size: 2048 +initializer_range: 0.02 +intermediate_size: 6144 +max_position_embeddings: 40960 +max_window_layers: 28 +num_attention_heads: 16 +num_hidden_layers: 28 +num_key_value_heads: 8 +rms_norm_eps: 1e-06 +rope_scaling: null +rope_theta: 1000000 +sliding_window: null +tie_word_embeddings: true +torch_dtype: bfloat16 +use_cache: true +use_sliding_window: false +vocab_size: 151936 +# choose attention_kernel from: [flash_attention, splash_attention, null] +attention_kernel: flash_attention \ No newline at end of file diff --git a/torchprime/torch_xla_models/configs/model/remat/qwen.yaml b/torchprime/torch_xla_models/configs/model/remat/qwen.yaml new file mode 100644 index 00000000..0b8d1414 --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/remat/qwen.yaml @@ -0,0 +1,7 @@ +activation_checkpoint_layers: + - Qwen3DecoderLayer + +# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info. +optimization_barrier_layers: + - Qwen3DecoderLayer + \ No newline at end of file diff --git a/torchprime/torch_xla_models/configs/model/sharding/qwen-fsdp.yaml b/torchprime/torch_xla_models/configs/model/sharding/qwen-fsdp.yaml new file mode 100644 index 00000000..65d22df3 --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/sharding/qwen-fsdp.yaml @@ -0,0 +1,17 @@ +# Weights +model.embed_tokens.weight: [fsdp, tensor] +model.layers.*.self_attn.q_proj.weight: [fsdp, tensor] +model.layers.*.self_attn.k_proj.weight: [tensor, fsdp] +model.layers.*.self_attn.v_proj.weight: [tensor, fsdp] +model.layers.*.self_attn.o_proj.weight: [fsdp, tensor] +model.layers.*.mlp.gate_proj.weight: [fsdp, tensor] +model.layers.*.mlp.up_proj.weight: [fsdp, tensor] +model.layers.*.mlp.down_proj.weight: [tensor, fsdp] +model.layers.*.input_layernorm.weight: [fsdp] +model.layers.*.post_attention_layernorm.weight: [fsdp] +model.norm.weight: [fsdp] +lm_head.weight: [fsdp, tensor] + +# Activations +model.layers.*: [fsdp, null, tensor] +lm_head: [fsdp, null, tensor] \ No newline at end of file diff --git a/torchprime/torch_xla_models/flex/__init__.py b/torchprime/torch_xla_models/flex/__init__.py new file mode 100644 index 00000000..1b3ac7dc --- /dev/null +++ b/torchprime/torch_xla_models/flex/__init__.py @@ -0,0 +1,5 @@ +# from .modeling_llama import LlamaForCausalLM +from .modeling_qwen import Qwen3ForCausalLM + + +__all__ = ["Qwen3ForCausalLM"] \ No newline at end of file diff --git a/torchprime/torch_xla_models/flex/attention.py b/torchprime/torch_xla_models/flex/attention.py new file mode 100644 index 00000000..d525514f --- /dev/null +++ b/torchprime/torch_xla_models/flex/attention.py @@ -0,0 +1,235 @@ +import math +from typing import Any + +import torch +from torch import nn + +# Detect if we're on TPU (no CUDA available) vs GPU +IS_TPU = not torch.cuda.is_available() + +if IS_TPU: + # TPU environment - use original torch_xla imports + import torch_xla.debug.profiler as xp + import torch_xla.distributed.spmd as xs + from torch_xla.experimental.custom_kernel import FlashAttention, flash_attention + from torch_xla.experimental.splash_attention import ( + SplashAttentionConfig, + splash_attention, + ) +else: + # GPU environment - use PyTorch's native SDPA flash attention + from torch.nn.functional import scaled_dot_product_attention + from torch.nn.attention import SDPBackend, sdpa_kernel + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class AttentionModule(nn.Module): + def __init__(self, config, kernel_config: dict[str, Any] | None = None): + super().__init__() + self.config = config + self.kernel_config = kernel_config + + def _forward_tpu( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ): + """Original TPU/XLA implementation""" + + if self.config.attention_kernel != "splash_attention": + num_key_value_groups = ( + self.config.num_attention_heads // self.config.num_key_value_heads + ) + key_states = repeat_kv(key_states, num_key_value_groups) + value_states = repeat_kv(value_states, num_key_value_groups) + + bsz, num_heads, q_len, head_dim = query_states.size() + # TODO: q, k dim unintentionally changed after the apply_rotary_pos_emb. Use + # v's dim temporarily to bypass shape assertion failure. Remove the + # following line after resolving + # https://github.com/AI-Hypercomputer/torchprime/issues/195. + head_dim = value_states.shape[-1] + + kv_seq_len = key_states.shape[-2] + + # Non FA path doesn't deal with 2D sharding. + self.partition_spec = None + segment_ids_partition_spec = None + if xs.get_global_mesh() is not None: + self.partition_spec = (("data", "fsdp"), "tensor", None, None) + segment_ids_partition_spec = (("data", "fsdp"), None) + + match self.config.attention_kernel: + case "splash_attention": + raise NotImplementedError("Splash Attention is not supported yet") + # Integrated with PyTorch/XLA Pallas Splash Attention: + assert xs.get_global_mesh() is not None, ( + "Global mesh is required for Splash Attention" + ) + sa_config = SplashAttentionConfig( + mesh=str(xs.get_global_mesh()), + qkv_partition_spec=self.partition_spec, + segment_ids_partition_spec=segment_ids_partition_spec, + ) + if self.kernel_config is not None: + for key, value in self.kernel_config.items(): + if hasattr(sa_config, key): + setattr(sa_config, key, value) + query_states /= math.sqrt(head_dim) + attn_output = splash_attention( + query_states, key_states, value_states, sa_config.to_json() + ) + case "flash_attention": + # Integrated with PyTorch/XLA Pallas Flash Attention: + default_block_sizes = { + "block_q": 2048, + "block_k_major": 512, + "block_k": 512, + "block_b": 2, + "block_q_major_dkv": 2048, + "block_k_major_dkv": 512, + "block_q_dkv": 2048, + "block_k_dkv": 512, + "block_q_dq": 2048, + "block_k_dq": 256, + "block_k_major_dq": 512, + } + if self.kernel_config is not None: + default_block_sizes.update(self.kernel_config) + FlashAttention.DEFAULT_BLOCK_SIZES = default_block_sizes + + query_states /= math.sqrt(head_dim) + attn_output = flash_attention( + query_states, + key_states, + value_states, + causal=False, # weiran: causal=False for bi-directional attention + partition_spec=self.partition_spec, + ) + case "default" | None: + # Default attention implementation (no flash attention) + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(head_dim) + if attn_weights.size() != (bsz, num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.config.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + case _: + raise NotImplementedError(f"Attention kernel {self.config.attention_kernel} is not supported yet") + + if attn_output.size() != (bsz, num_heads, q_len, head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" + f" {attn_output.size()}" + ) + return attn_output + + def _forward_gpu( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ): + """GPU-optimized PyTorch implementation""" + if self.config.attention_kernel != "splash_attention": + num_key_value_groups = ( + self.config.num_attention_heads // self.config.num_key_value_heads + ) + key_states = repeat_kv(key_states, num_key_value_groups) + value_states = repeat_kv(value_states, num_key_value_groups) + + bsz, num_heads, q_len, head_dim = query_states.size() + # TODO: q, k dim unintentionally changed after the apply_rotary_pos_emb. Use + # v's dim temporarily to bypass shape assertion failure. Remove the + # following line after resolving + # https://github.com/AI-Hypercomputer/torchprime/issues/195. + head_dim = value_states.shape[-1] + + kv_seq_len = key_states.shape[-2] + + # Use SDPA with appropriate backend + + match self.config.attention_kernel: + case "splash_attention": + raise NotImplementedError("Splash Attention is not supported in GPU environment") + + case "flash_attention": + # Try to use flash attention backend, fallback to default if not available + try: + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + attn_output = scaled_dot_product_attention( + query_states, + key_states, + value_states, + dropout_p=self.config.attention_dropout if self.training else 0.0, + is_causal=False, # weiran: causal=False for bi-directional attention + ) + except (RuntimeError, NotImplementedError): + # Flash attention not available, use default backend + with sdpa_kernel(SDPBackend.MATH): + attn_output = scaled_dot_product_attention( + query_states, + key_states, + value_states, + dropout_p=self.config.attention_dropout if self.training else 0.0, + is_causal=False, # weiran: causal=False for bi-directional attention + ) + + case _: + # Default implementation - use math backend for compatibility + with sdpa_kernel(SDPBackend.MATH): + attn_output = scaled_dot_product_attention( + query_states, + key_states, + value_states, + dropout_p=self.config.attention_dropout if self.training else 0.0, + is_causal=False, # weiran: causal=False for bi-directional attention + ) + + if attn_output.size() != (bsz, num_heads, q_len, head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" + f" {attn_output.size()}" + ) + return attn_output + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ): + if IS_TPU: + return self._forward_tpu(query_states, key_states, value_states, attention_mask) + else: + return self._forward_gpu(query_states, key_states, value_states, attention_mask) \ No newline at end of file diff --git a/torchprime/torch_xla_models/flex/modeling_qwen.py b/torchprime/torch_xla_models/flex/modeling_qwen.py new file mode 100644 index 00000000..7ee22795 --- /dev/null +++ b/torchprime/torch_xla_models/flex/modeling_qwen.py @@ -0,0 +1,603 @@ +""" +PyTorch Qwen3 model. +Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py +""" +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from omegaconf import DictConfig +from torch import nn +from transformers.activations import ACT2FN +from transformers.utils import logging + +from torchprime.layers.sequential import HomogeneousSequential +from torchprime.rope.rope import RopeScaling, default_rope_frequencies +IS_TPU = not torch.cuda.is_available() +if IS_TPU: + from torchprime.torch_xla_models import offloading +from torchprime.torch_xla_models.flex.attention import AttentionModule + +logger = logging.get_logger(__name__) + +class Qwen3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + +class Qwen3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DictConfig, layer_idx: int | None = None): + super().__init__() + self.config = config + self.attention_block = AttentionModule(config) + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = getattr(config, "attention_dropout", 0.0) + # weiran: diffullama + self.is_causal = False + + self.q_proj = nn.Linear( + self.hidden_size, + self.num_heads * self.head_dim, + bias=getattr(config, "attention_bias", False) + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=getattr(config, "attention_bias", False) + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=getattr(config, "attention_bias", False) + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, "attention_bias", False) + ) + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=getattr(config, "rms_norm_eps", 1e-6)) + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=getattr(config, "rms_norm_eps", 1e-6)) + + # Handle sliding window - check if layer_types exists and if this layer should use sliding attention + if not config.use_sliding_window: + self.sliding_window = None + else: + raise NotImplementedError("Sliding window is not implemented for Qwen3") + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + ) -> torch.FloatTensor: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Apply q_norm and k_norm to the head dimension + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + # Apply normalization + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + # Transpose to get the right shape for attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attn_output = self.attention_block( + query_states, key_states, value_states, attention_mask + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output + + +class Qwen3RotaryEmbedding(nn.Module): + inv_freq: nn.Buffer + + def __init__( + self, + head_dim, + rope_theta, + scaling: RopeScaling | None = None, + ): + super().__init__() + if scaling is None: + inv_freq = default_rope_frequencies(head_dim, theta=rope_theta) + else: + raise NotImplementedError("Scaling is not implemented for Qwen3") + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( + 1, 2 + ) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3DecoderLayer(nn.Module): + def __init__(self, config: DictConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)) + self.post_attention_layernorm = Qwen3RMSNorm( + config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6) + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor]| None = None, # necessary, but kept here for BC + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + """ + # This gives the `hidden_states` tensor a name so that we can layer specify + # to offload this tensor to host RAM to save memory. This is not a standard + # torch API because there is no such feature in PyTorch. Instead, the name + # becomes node metadata during FX graph capture. + if IS_TPU: + hidden_states = offloading.offload_name(hidden_states, "decoder_input") + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen3Model(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3DecoderLayer`] + + Args: + config: DictConfig + """ + + def __init__(self, config: DictConfig): + super().__init__() + self.vocab_size = config.vocab_size + if "pad_token_id" not in config: + self.padding_idx = None + else: + self.padding_idx = config.pad_token_id + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) + # `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with + # `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html. + self.layers = HomogeneousSequential( + *[ + Qwen3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)) + + rope_scaling = config.get("rope_scaling", None) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.rope_theta = getattr(config, "rope_theta", 10000.0) + if rope_scaling is not None: + rope_scaling = RopeScaling(**rope_scaling) + self.rotary_emb = Qwen3RotaryEmbedding( + head_dim=head_dim, rope_theta=self.rope_theta, scaling=rope_scaling + ) + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: torch.FloatTensor | None = None, + ) -> torch.Tensor: + # convert input ids to embeddings + inputs_embeds = self.embed_tokens(input_ids) + + seq_length = inputs_embeds.size(1) + + # TODO(https://github.com/pytorch/xla/issues/8783): Pass position_ids as `long()` + # when `scan` can take non-differentiable inputs. + position_ids = ( + torch.arange(seq_length, device=inputs_embeds.device).unsqueeze(0).float() + ) + + # Create a causal attention mask + causal_mask = torch.triu( + torch.full((seq_length, seq_length), float("-inf"), device=inputs_embeds.device), + diagonal=1, + ) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) # Add batch and head dimension + + if attention_mask is not None: + causal_mask = causal_mask * attention_mask[:, None, None, :] + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + hidden_states = self.layers( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + ) + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class Qwen3ForCausalLM(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.model = Qwen3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.mask_token_id = config.mask_token_id + + # Initialize weights and apply final processing + self.apply(self._init_weights) + + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", 0.02) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def forward( + self, + input_ids: torch.LongTensor, + labels: torch.LongTensor | None = None, + attention_mask: torch.FloatTensor | None = None, + src_mask: torch.BoolTensor | None = None, + training_mode: str = "pretrain", + ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]: + if not self.training: + # haolin: during inference the masking is done when preprocessing the input, we don't need src_mask and noising + model_output = self.model(input_ids=input_ids, attention_mask=attention_mask) + hidden_states = model_output + logits = self.lm_head(hidden_states) # NOTE: we shift logits in generate() + # logits = logits.float()[..., :-1, :].contiguous() # NOTE: we shift logits in inference_utils at inference time + return logits, None + + if training_mode == "sft" and src_mask is None: + raise ValueError("SFT mode requires a non-null src_mask") + + # weiran: diffullama + sampling_eps = 1e-3 + mask_token_id = self.mask_token_id + loss_func = nn.CrossEntropyLoss(reduction="none") + batch_size, seq_len = input_ids.shape # input_ids: [batch_size, seq_len] + + # Create maskable_mask based on training mode and src_mask + # For SFT: src_mask is provided, maskable_mask = ~src_mask + # For pretrain: src_mask is None, maskable_mask = all True + if src_mask is not None: # SFT + maskable_mask = ~src_mask + else: # pretrain or midtrain + maskable_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device) + prefix_probability = getattr(self.config, "prefix_probability", 0) + truncate_probability = getattr(self.config, "truncate_probability", 0) + # Generate random decisions for all batch items + apply_prefix = torch.rand(batch_size, device=input_ids.device) < prefix_probability + # Only apply truncation to rows that are NOT prefixed + apply_truncate = torch.rand(batch_size, device=input_ids.device) < truncate_probability + apply_truncate = apply_truncate & ~apply_prefix + + if prefix_probability > 0: + maskable_mask = prefix_input_ids(input_ids, maskable_mask, apply_prefix) + if truncate_probability > 0: + input_ids = truncate_input_ids(input_ids, apply_truncate, self.config.pad_token_id) + maskable_mask = maskable_mask & (input_ids != self.config.pad_token_id) # NOTE: necessary? + + # add noise to input_ids + sigma = (1 - sampling_eps) * torch.rand(input_ids.shape[0], device=input_ids.device) + sampling_eps + dsigma = torch.reciprocal(sigma) + + # Sample mask block size + mask_block_sizes = getattr(self.config, "mask_block_sizes", None) + block_masking_probability = getattr(self.config, "block_masking_probability", 0) + if block_masking_probability > 0 and mask_block_sizes is not None: + mask_block_size = mask_block_sizes[torch.randint(0, len(mask_block_sizes), (1,)).item()] + else: + mask_block_size = 1 + + noisy_input_ids = transition( + input_ids, + sigma[:, None], + maskable_mask=maskable_mask, + mask_token_id=mask_token_id, + mask_block_size=mask_block_size + ) + loss_mask = noisy_input_ids == mask_token_id + + hidden_states = self.model(input_ids=noisy_input_ids, attention_mask=attention_mask) + # hidden_states = self.model(input_ids=input_ids, attention_mask=attention_mask) + logits = self.lm_head(hidden_states) + logits = logits.float() + # logits: [bs, seq_len, vocab_size] + # Shifted logits and labels + # logits: [bs, seq_len-1, vocab_size] + logits = logits[..., :-1, :].contiguous() + # weiran: if the shifted token is not masked in the original input, the loss is 0 + # loss_mask: [bs, seq_len-1] + loss_mask = loss_mask[..., 1:].contiguous() + target_ids = input_ids[..., 1:].contiguous() + # loss: [bs, seq_len-1] + loss = loss_func( + logits.reshape(-1, logits.shape[-1]), target_ids.reshape(-1) + ).reshape(target_ids.shape[0],-1) + loss = loss.masked_fill(~loss_mask, 0) + # weiran: divide by the number of tokens in the sequence instead of the number of masked tokens + # justification is dsigma already accounts for the number of masked tokens + # this is a hack to get something like per token loss + # https://github.com/ML-GSAI/SMDM/blob/main/pretrain/train_mdm_rl.py#L281-L283 + loss = (dsigma[:, None] * loss).sum() / (input_ids.shape[0] * input_ids.shape[1]) + return logits, loss + + +def transition( + x_0, sigma, maskable_mask, mask_token_id, mask_block_size: int = 1 +): + """Apply masking to input tokens. If mask_block_size > 1, use block masking for all rows.""" + + if mask_block_size == 1: + # Original behavior + # weiran: diffullama + move_indices = (torch.rand(*x_0.shape, device=x_0.device) < sigma) & maskable_mask + x_t = torch.where(move_indices, mask_token_id, x_0) + return x_t + + # Block masking for entire batch + return block_masking(x_0, sigma, maskable_mask, mask_token_id, mask_block_size) + + +def block_masking(x_0, sigma, maskable_mask, mask_token_id, mask_block_size): + """ + XLA-compatible block masking applied uniformly to all rows in the batch. + Uses efficient tensor operations to avoid dynamic loops. + """ + batch_size, seq_len = x_0.shape + + if seq_len < mask_block_size: + return x_0 + + # Calculate number of possible block positions + num_windows = seq_len - mask_block_size + 1 + + # Create all possible block positions: [num_windows, mask_block_size] + window_starts = torch.arange(num_windows, device=x_0.device) + block_offsets = torch.arange(mask_block_size, device=x_0.device) + all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0) + + # Check which blocks are fully maskable: [batch_size, num_windows] + maskable_blocks = maskable_mask.unsqueeze(1).expand(-1, num_windows, -1).gather( + 2, all_positions.unsqueeze(0).expand(batch_size, -1, -1) + ) + fully_maskable = maskable_blocks.all(dim=2) + + # Determine which blocks should be masked: (batch_size, num_windows) + effective_sigma = 1 - (1-sigma)**(1/mask_block_size) # NOTE: since we mask with blocks, we need to scale sigma by block size + should_mask = (torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma) & fully_maskable + + # Create final mask using simple broadcasting (fully XLA-compatible) + # For each position in the sequence, check if it's part of any masked block + position_indices = torch.arange(seq_len, device=x_0.device) # [seq_len] + + # Check for each position if it falls within any masked block + # position_indices: [seq_len] -> [1, 1, seq_len] + # all_positions: [num_windows, mask_block_size] -> [1, num_windows, mask_block_size] + # should_mask: [batch_size, num_windows] -> [batch_size, num_windows, 1] + + position_indices = position_indices.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len] + all_positions = all_positions.unsqueeze(0) # [1, num_windows, mask_block_size] + should_mask = should_mask.unsqueeze(2) # [batch_size, num_windows, 1] + + # Check if each position matches any of the positions in masked blocks + # [1, 1, seq_len] == [1, num_windows, mask_block_size] -> [1, num_windows, seq_len] + position_matches = (position_indices == all_positions.unsqueeze(3)).any(dim=2) # [1, num_windows, seq_len] + + # Apply should_mask to get final positions to mask + # [batch_size, num_windows, 1] & [1, num_windows, seq_len] -> [batch_size, num_windows, seq_len] + should_mask_positions = should_mask & position_matches + + # Reduce over windows: if any window masks this position, mask it + final_mask = should_mask_positions.any(dim=1) # [batch_size, seq_len] + + # Apply the mask + result = torch.where(final_mask, mask_token_id, x_0) + + return result + + +def prefix_input_ids(input_ids, maskable_mask, apply_prefix): + """Apply prefix to input_ids based on configured probability. Return a masksable mask such that the prefix is not masked.""" + batch_size, seq_len = input_ids.shape + # Generate random prefix lengths for all batch items + prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device) + # Create position indices: [1, seq_len] + position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) # [1, seq_len] + # Create prefix mask: True where position < prefix_length + prefix_mask = position_indices < prefix_lengths.unsqueeze(1) # [batch_size, seq_len] + # Apply prefix masking: set to False where we should apply prefix masking + maskable_mask = maskable_mask & ~(apply_prefix.unsqueeze(1) & prefix_mask) + return maskable_mask + + +def truncate_input_ids(input_ids, apply_truncate, pad_token_id): + """Truncate input_ids at random position and fill with pad token. Return the input_ids with suffix truncated and filled with pad token.""" + batch_size, seq_len = input_ids.shape + # Generate random truncation positions for all batch items + truncate_positions = torch.randint(1, seq_len, (batch_size,), device=input_ids.device) + # Create position indices: [1, seq_len] + position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) # [1, seq_len] + # Create truncate mask: True where position >= truncate_position + truncate_mask = position_indices >= truncate_positions.unsqueeze(1) # [batch_size, seq_len] + # Apply truncation: fill with pad token where we should truncate + input_ids = torch.where(apply_truncate.unsqueeze(1) & truncate_mask, pad_token_id, input_ids) + return input_ids \ No newline at end of file From 5b4867bd217dedade016319adf9f9d9834fe66ee Mon Sep 17 00:00:00 2001 From: Jialei Chen Date: Tue, 19 Aug 2025 21:51:53 +0000 Subject: [PATCH 2/2] format --- torchprime/torch_xla_models/flex/__init__.py | 3 +- torchprime/torch_xla_models/flex/attention.py | 104 +-- .../torch_xla_models/flex/modeling_qwen.py | 273 ++++---- torchprime/torch_xla_models/flex/qwen_hf.py | 600 ++++++++++++++++++ 4 files changed, 773 insertions(+), 207 deletions(-) create mode 100644 torchprime/torch_xla_models/flex/qwen_hf.py diff --git a/torchprime/torch_xla_models/flex/__init__.py b/torchprime/torch_xla_models/flex/__init__.py index 1b3ac7dc..facc2eb8 100644 --- a/torchprime/torch_xla_models/flex/__init__.py +++ b/torchprime/torch_xla_models/flex/__init__.py @@ -1,5 +1,4 @@ # from .modeling_llama import LlamaForCausalLM from .modeling_qwen import Qwen3ForCausalLM - -__all__ = ["Qwen3ForCausalLM"] \ No newline at end of file +__all__ = ["Qwen3ForCausalLM"] diff --git a/torchprime/torch_xla_models/flex/attention.py b/torchprime/torch_xla_models/flex/attention.py index d525514f..be71ad56 100644 --- a/torchprime/torch_xla_models/flex/attention.py +++ b/torchprime/torch_xla_models/flex/attention.py @@ -2,24 +2,13 @@ from typing import Any import torch +import torch_xla.distributed.spmd as xs from torch import nn - -# Detect if we're on TPU (no CUDA available) vs GPU -IS_TPU = not torch.cuda.is_available() - -if IS_TPU: - # TPU environment - use original torch_xla imports - import torch_xla.debug.profiler as xp - import torch_xla.distributed.spmd as xs - from torch_xla.experimental.custom_kernel import FlashAttention, flash_attention - from torch_xla.experimental.splash_attention import ( - SplashAttentionConfig, - splash_attention, - ) -else: - # GPU environment - use PyTorch's native SDPA flash attention - from torch.nn.functional import scaled_dot_product_attention - from torch.nn.attention import SDPBackend, sdpa_kernel +from torch_xla.experimental.custom_kernel import FlashAttention, flash_attention +from torch_xla.experimental.splash_attention import ( + SplashAttentionConfig, + splash_attention, +) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -118,7 +107,7 @@ def _forward_tpu( query_states, key_states, value_states, - causal=False, # weiran: causal=False for bi-directional attention + causal=False, # weiran: causal=False for bi-directional attention partition_spec=self.partition_spec, ) case "default" | None: @@ -143,77 +132,9 @@ def _forward_tpu( ) attn_output = torch.matmul(attn_weights, value_states) case _: - raise NotImplementedError(f"Attention kernel {self.config.attention_kernel} is not supported yet") - - if attn_output.size() != (bsz, num_heads, q_len, head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" - f" {attn_output.size()}" - ) - return attn_output - - def _forward_gpu( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - ): - """GPU-optimized PyTorch implementation""" - if self.config.attention_kernel != "splash_attention": - num_key_value_groups = ( - self.config.num_attention_heads // self.config.num_key_value_heads - ) - key_states = repeat_kv(key_states, num_key_value_groups) - value_states = repeat_kv(value_states, num_key_value_groups) - - bsz, num_heads, q_len, head_dim = query_states.size() - # TODO: q, k dim unintentionally changed after the apply_rotary_pos_emb. Use - # v's dim temporarily to bypass shape assertion failure. Remove the - # following line after resolving - # https://github.com/AI-Hypercomputer/torchprime/issues/195. - head_dim = value_states.shape[-1] - - kv_seq_len = key_states.shape[-2] - - # Use SDPA with appropriate backend - - match self.config.attention_kernel: - case "splash_attention": - raise NotImplementedError("Splash Attention is not supported in GPU environment") - - case "flash_attention": - # Try to use flash attention backend, fallback to default if not available - try: - with sdpa_kernel(SDPBackend.FLASH_ATTENTION): - attn_output = scaled_dot_product_attention( - query_states, - key_states, - value_states, - dropout_p=self.config.attention_dropout if self.training else 0.0, - is_causal=False, # weiran: causal=False for bi-directional attention - ) - except (RuntimeError, NotImplementedError): - # Flash attention not available, use default backend - with sdpa_kernel(SDPBackend.MATH): - attn_output = scaled_dot_product_attention( - query_states, - key_states, - value_states, - dropout_p=self.config.attention_dropout if self.training else 0.0, - is_causal=False, # weiran: causal=False for bi-directional attention - ) - - case _: - # Default implementation - use math backend for compatibility - with sdpa_kernel(SDPBackend.MATH): - attn_output = scaled_dot_product_attention( - query_states, - key_states, - value_states, - dropout_p=self.config.attention_dropout if self.training else 0.0, - is_causal=False, # weiran: causal=False for bi-directional attention - ) + raise NotImplementedError( + f"Attention kernel {self.config.attention_kernel} is not supported yet" + ) if attn_output.size() != (bsz, num_heads, q_len, head_dim): raise ValueError( @@ -229,7 +150,4 @@ def forward( value_states: torch.Tensor, attention_mask: torch.Tensor | None = None, ): - if IS_TPU: - return self._forward_tpu(query_states, key_states, value_states, attention_mask) - else: - return self._forward_gpu(query_states, key_states, value_states, attention_mask) \ No newline at end of file + return self._forward_tpu(query_states, key_states, value_states, attention_mask) diff --git a/torchprime/torch_xla_models/flex/modeling_qwen.py b/torchprime/torch_xla_models/flex/modeling_qwen.py index 7ee22795..77377714 100644 --- a/torchprime/torch_xla_models/flex/modeling_qwen.py +++ b/torchprime/torch_xla_models/flex/modeling_qwen.py @@ -2,10 +2,8 @@ PyTorch Qwen3 model. Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py """ -from typing import Callable, Optional, Tuple, Union import torch -import torch.nn.functional as F from omegaconf import DictConfig from torch import nn from transformers.activations import ACT2FN @@ -13,13 +11,12 @@ from torchprime.layers.sequential import HomogeneousSequential from torchprime.rope.rope import RopeScaling, default_rope_frequencies -IS_TPU = not torch.cuda.is_available() -if IS_TPU: - from torchprime.torch_xla_models import offloading +from torchprime.torch_xla_models import offloading from torchprime.torch_xla_models.flex.attention import AttentionModule logger = logging.get_logger(__name__) + class Qwen3RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -39,6 +36,7 @@ def forward(self, hidden_states): def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + class Qwen3MLP(nn.Module): def __init__(self, config): super().__init__() @@ -54,12 +52,14 @@ def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -86,6 +86,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -94,15 +95,18 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: torch.Tensor | None, scaling: float, dropout: float = 0.0, **kwargs, @@ -115,13 +119,18 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights + class Qwen3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -143,30 +152,31 @@ def __init__(self, config: DictConfig, layer_idx: int | None = None): self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = getattr(config, "attention_dropout", 0.0) - # weiran: diffullama self.is_causal = False - + self.q_proj = nn.Linear( - self.hidden_size, - self.num_heads * self.head_dim, - bias=getattr(config, "attention_bias", False) + self.hidden_size, + self.num_heads * self.head_dim, + bias=getattr(config, "attention_bias", False), ) self.k_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=getattr(config, "attention_bias", False) + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=getattr(config, "attention_bias", False), ) self.v_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=getattr(config, "attention_bias", False) + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=getattr(config, "attention_bias", False), ) self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, "attention_bias", False) + self.num_heads * self.head_dim, + self.hidden_size, + bias=getattr(config, "attention_bias", False), ) self.q_norm = Qwen3RMSNorm(self.head_dim, eps=getattr(config, "rms_norm_eps", 1e-6)) self.k_norm = Qwen3RMSNorm(self.head_dim, eps=getattr(config, "rms_norm_eps", 1e-6)) - + # Handle sliding window - check if layer_types exists and if this layer should use sliding attention if not config.use_sliding_window: self.sliding_window = None @@ -176,7 +186,7 @@ def __init__(self, config: DictConfig, layer_idx: int | None = None): def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, ) -> torch.FloatTensor: @@ -189,7 +199,9 @@ def forward( # Apply q_norm and k_norm to the head dimension query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) # Apply normalization query_states = self.q_norm(query_states) @@ -260,7 +272,9 @@ def __init__(self, config: DictConfig, layer_idx: int): self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) self.mlp = Qwen3MLP(config) - self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)) + self.input_layernorm = Qwen3RMSNorm( + config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6) + ) self.post_attention_layernorm = Qwen3RMSNorm( config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6) ) @@ -270,7 +284,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor]| None = None, # necessary, but kept here for BC + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC ) -> torch.Tensor: """ Args: @@ -278,13 +293,13 @@ def forward( attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - """ + """ # This gives the `hidden_states` tensor a name so that we can layer specify # to offload this tensor to host RAM to save memory. This is not a standard # torch API because there is no such feature in PyTorch. Instead, the name # becomes node metadata during FX graph capture. - if IS_TPU: - hidden_states = offloading.offload_name(hidden_states, "decoder_input") + + hidden_states = offloading.offload_name(hidden_states, "decoder_input") residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -303,7 +318,7 @@ def forward( hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - + return hidden_states @@ -322,7 +337,9 @@ def __init__(self, config: DictConfig): self.padding_idx = None else: self.padding_idx = config.pad_token_id - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=self.padding_idx + ) # `HomogeneousSequential` is similar to `nn.Sequential` but can be compiled with # `scan` described in https://pytorch.org/xla/release/r2.6/features/scan.html. self.layers = HomogeneousSequential( @@ -331,10 +348,14 @@ def __init__(self, config: DictConfig): for layer_idx in range(config.num_hidden_layers) ] ) - self.norm = Qwen3RMSNorm(config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)) + self.norm = Qwen3RMSNorm( + config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6) + ) rope_scaling = config.get("rope_scaling", None) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) self.rope_theta = getattr(config, "rope_theta", 10000.0) if rope_scaling is not None: rope_scaling = RopeScaling(**rope_scaling) @@ -382,7 +403,7 @@ def forward( ) hidden_states = self.norm(hidden_states) - + return hidden_states @@ -419,9 +440,9 @@ def forward( ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]: if not self.training: # haolin: during inference the masking is done when preprocessing the input, we don't need src_mask and noising - model_output = self.model(input_ids=input_ids, attention_mask=attention_mask) + model_output = self.model(input_ids=input_ids, attention_mask=attention_mask) hidden_states = model_output - logits = self.lm_head(hidden_states) # NOTE: we shift logits in generate() + logits = self.lm_head(hidden_states) # NOTE: we shift logits in generate() # logits = logits.float()[..., :-1, :].contiguous() # NOTE: we shift logits in inference_utils at inference time return logits, None @@ -432,38 +453,52 @@ def forward( sampling_eps = 1e-3 mask_token_id = self.mask_token_id loss_func = nn.CrossEntropyLoss(reduction="none") - batch_size, seq_len = input_ids.shape # input_ids: [batch_size, seq_len] - + batch_size, seq_len = input_ids.shape # input_ids: [batch_size, seq_len] + # Create maskable_mask based on training mode and src_mask # For SFT: src_mask is provided, maskable_mask = ~src_mask # For pretrain: src_mask is None, maskable_mask = all True - if src_mask is not None: # SFT + if src_mask is not None: # SFT maskable_mask = ~src_mask - else: # pretrain or midtrain - maskable_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device) + else: # pretrain or midtrain + maskable_mask = torch.ones_like( + input_ids, dtype=torch.bool, device=input_ids.device + ) prefix_probability = getattr(self.config, "prefix_probability", 0) truncate_probability = getattr(self.config, "truncate_probability", 0) # Generate random decisions for all batch items - apply_prefix = torch.rand(batch_size, device=input_ids.device) < prefix_probability + apply_prefix = ( + torch.rand(batch_size, device=input_ids.device) < prefix_probability + ) # Only apply truncation to rows that are NOT prefixed - apply_truncate = torch.rand(batch_size, device=input_ids.device) < truncate_probability + apply_truncate = ( + torch.rand(batch_size, device=input_ids.device) < truncate_probability + ) apply_truncate = apply_truncate & ~apply_prefix if prefix_probability > 0: maskable_mask = prefix_input_ids(input_ids, maskable_mask, apply_prefix) if truncate_probability > 0: - input_ids = truncate_input_ids(input_ids, apply_truncate, self.config.pad_token_id) - maskable_mask = maskable_mask & (input_ids != self.config.pad_token_id) # NOTE: necessary? + input_ids = truncate_input_ids( + input_ids, apply_truncate, self.config.pad_token_id + ) + maskable_mask = maskable_mask & ( + input_ids != self.config.pad_token_id + ) # NOTE: necessary? # add noise to input_ids - sigma = (1 - sampling_eps) * torch.rand(input_ids.shape[0], device=input_ids.device) + sampling_eps + sigma = (1 - sampling_eps) * torch.rand( + input_ids.shape[0], device=input_ids.device + ) + sampling_eps dsigma = torch.reciprocal(sigma) # Sample mask block size mask_block_sizes = getattr(self.config, "mask_block_sizes", None) block_masking_probability = getattr(self.config, "block_masking_probability", 0) if block_masking_probability > 0 and mask_block_sizes is not None: - mask_block_size = mask_block_sizes[torch.randint(0, len(mask_block_sizes), (1,)).item()] + mask_block_size = mask_block_sizes[ + torch.randint(0, len(mask_block_sizes), (1,)).item() + ] else: mask_block_size = 1 @@ -472,7 +507,7 @@ def forward( sigma[:, None], maskable_mask=maskable_mask, mask_token_id=mask_token_id, - mask_block_size=mask_block_size + mask_block_size=mask_block_size, ) loss_mask = noisy_input_ids == mask_token_id @@ -491,7 +526,7 @@ def forward( # loss: [bs, seq_len-1] loss = loss_func( logits.reshape(-1, logits.shape[-1]), target_ids.reshape(-1) - ).reshape(target_ids.shape[0],-1) + ).reshape(target_ids.shape[0], -1) loss = loss.masked_fill(~loss_mask, 0) # weiran: divide by the number of tokens in the sequence instead of the number of masked tokens # justification is dsigma already accounts for the number of masked tokens @@ -501,9 +536,7 @@ def forward( return logits, loss -def transition( - x_0, sigma, maskable_mask, mask_token_id, mask_block_size: int = 1 -): +def transition(x_0, sigma, maskable_mask, mask_token_id, mask_block_size: int = 1): """Apply masking to input tokens. If mask_block_size > 1, use block masking for all rows.""" if mask_block_size == 1: @@ -518,61 +551,69 @@ def transition( def block_masking(x_0, sigma, maskable_mask, mask_token_id, mask_block_size): - """ - XLA-compatible block masking applied uniformly to all rows in the batch. - Uses efficient tensor operations to avoid dynamic loops. - """ - batch_size, seq_len = x_0.shape - - if seq_len < mask_block_size: - return x_0 - - # Calculate number of possible block positions - num_windows = seq_len - mask_block_size + 1 - - # Create all possible block positions: [num_windows, mask_block_size] - window_starts = torch.arange(num_windows, device=x_0.device) - block_offsets = torch.arange(mask_block_size, device=x_0.device) - all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0) - - # Check which blocks are fully maskable: [batch_size, num_windows] - maskable_blocks = maskable_mask.unsqueeze(1).expand(-1, num_windows, -1).gather( - 2, all_positions.unsqueeze(0).expand(batch_size, -1, -1) - ) - fully_maskable = maskable_blocks.all(dim=2) - - # Determine which blocks should be masked: (batch_size, num_windows) - effective_sigma = 1 - (1-sigma)**(1/mask_block_size) # NOTE: since we mask with blocks, we need to scale sigma by block size - should_mask = (torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma) & fully_maskable - - # Create final mask using simple broadcasting (fully XLA-compatible) - # For each position in the sequence, check if it's part of any masked block - position_indices = torch.arange(seq_len, device=x_0.device) # [seq_len] - - # Check for each position if it falls within any masked block - # position_indices: [seq_len] -> [1, 1, seq_len] - # all_positions: [num_windows, mask_block_size] -> [1, num_windows, mask_block_size] - # should_mask: [batch_size, num_windows] -> [batch_size, num_windows, 1] - - position_indices = position_indices.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len] - all_positions = all_positions.unsqueeze(0) # [1, num_windows, mask_block_size] - should_mask = should_mask.unsqueeze(2) # [batch_size, num_windows, 1] - - # Check if each position matches any of the positions in masked blocks - # [1, 1, seq_len] == [1, num_windows, mask_block_size] -> [1, num_windows, seq_len] - position_matches = (position_indices == all_positions.unsqueeze(3)).any(dim=2) # [1, num_windows, seq_len] - - # Apply should_mask to get final positions to mask - # [batch_size, num_windows, 1] & [1, num_windows, seq_len] -> [batch_size, num_windows, seq_len] - should_mask_positions = should_mask & position_matches - - # Reduce over windows: if any window masks this position, mask it - final_mask = should_mask_positions.any(dim=1) # [batch_size, seq_len] - - # Apply the mask - result = torch.where(final_mask, mask_token_id, x_0) - - return result + """ + XLA-compatible block masking applied uniformly to all rows in the batch. + Uses efficient tensor operations to avoid dynamic loops. + """ + batch_size, seq_len = x_0.shape + + if seq_len < mask_block_size: + return x_0 + + # Calculate number of possible block positions + num_windows = seq_len - mask_block_size + 1 + + # Create all possible block positions: [num_windows, mask_block_size] + window_starts = torch.arange(num_windows, device=x_0.device) + block_offsets = torch.arange(mask_block_size, device=x_0.device) + all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0) + + # Check which blocks are fully maskable: [batch_size, num_windows] + maskable_blocks = ( + maskable_mask.unsqueeze(1) + .expand(-1, num_windows, -1) + .gather(2, all_positions.unsqueeze(0).expand(batch_size, -1, -1)) + ) + fully_maskable = maskable_blocks.all(dim=2) + + # Determine which blocks should be masked: (batch_size, num_windows) + effective_sigma = 1 - (1 - sigma) ** ( + 1 / mask_block_size + ) # NOTE: since we mask with blocks, we need to scale sigma by block size + should_mask = ( + torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma + ) & fully_maskable + + # Create final mask using simple broadcasting (fully XLA-compatible) + # For each position in the sequence, check if it's part of any masked block + position_indices = torch.arange(seq_len, device=x_0.device) # [seq_len] + + # Check for each position if it falls within any masked block + # position_indices: [seq_len] -> [1, 1, seq_len] + # all_positions: [num_windows, mask_block_size] -> [1, num_windows, mask_block_size] + # should_mask: [batch_size, num_windows] -> [batch_size, num_windows, 1] + + position_indices = position_indices.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len] + all_positions = all_positions.unsqueeze(0) # [1, num_windows, mask_block_size] + should_mask = should_mask.unsqueeze(2) # [batch_size, num_windows, 1] + + # Check if each position matches any of the positions in masked blocks + # [1, 1, seq_len] == [1, num_windows, mask_block_size] -> [1, num_windows, seq_len] + position_matches = (position_indices == all_positions.unsqueeze(3)).any( + dim=2 + ) # [1, num_windows, seq_len] + + # Apply should_mask to get final positions to mask + # [batch_size, num_windows, 1] & [1, num_windows, seq_len] -> [batch_size, num_windows, seq_len] + should_mask_positions = should_mask & position_matches + + # Reduce over windows: if any window masks this position, mask it + final_mask = should_mask_positions.any(dim=1) # [batch_size, seq_len] + + # Apply the mask + result = torch.where(final_mask, mask_token_id, x_0) + + return result def prefix_input_ids(input_ids, maskable_mask, apply_prefix): @@ -581,7 +622,9 @@ def prefix_input_ids(input_ids, maskable_mask, apply_prefix): # Generate random prefix lengths for all batch items prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device) # Create position indices: [1, seq_len] - position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) # [1, seq_len] + position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze( + 0 + ) # [1, seq_len] # Create prefix mask: True where position < prefix_length prefix_mask = position_indices < prefix_lengths.unsqueeze(1) # [batch_size, seq_len] # Apply prefix masking: set to False where we should apply prefix masking @@ -595,9 +638,15 @@ def truncate_input_ids(input_ids, apply_truncate, pad_token_id): # Generate random truncation positions for all batch items truncate_positions = torch.randint(1, seq_len, (batch_size,), device=input_ids.device) # Create position indices: [1, seq_len] - position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) # [1, seq_len] + position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze( + 0 + ) # [1, seq_len] # Create truncate mask: True where position >= truncate_position - truncate_mask = position_indices >= truncate_positions.unsqueeze(1) # [batch_size, seq_len] + truncate_mask = position_indices >= truncate_positions.unsqueeze( + 1 + ) # [batch_size, seq_len] # Apply truncation: fill with pad token where we should truncate - input_ids = torch.where(apply_truncate.unsqueeze(1) & truncate_mask, pad_token_id, input_ids) - return input_ids \ No newline at end of file + input_ids = torch.where( + apply_truncate.unsqueeze(1) & truncate_mask, pad_token_id, input_ids + ) + return input_ids diff --git a/torchprime/torch_xla_models/flex/qwen_hf.py b/torchprime/torch_xla_models/flex/qwen_hf.py new file mode 100644 index 00000000..2d46cb78 --- /dev/null +++ b/torchprime/torch_xla_models/flex/qwen_hf.py @@ -0,0 +1,600 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_qwen3 import Qwen3Config + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Qwen3RMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = Qwen3RMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose( + 1, 2 + ) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose( + 1, 2 + ) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.attention_type = config.layer_types[layer_idx] + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Qwen3PreTrainedModel(PreTrainedModel): + config: Qwen3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen3DecoderLayer, + "attentions": Qwen3Attention, + } + + +class Qwen3RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( + 1, 2 + ) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Qwen3Model(Qwen3PreTrainedModel): + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask( + **mask_kwargs + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +@auto_docstring +class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3ForCausalLM + + >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qwen3ForSequenceClassification( + GenericForSequenceClassification, Qwen3PreTrainedModel +): + pass + + +class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel): + pass + + +class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel): + base_model_prefix = ( + "transformer" # For BC, where `transformer` was used instead of `model` + ) + + +__all__ = [ + "Qwen3ForCausalLM", + "Qwen3ForQuestionAnswering", + "Qwen3PreTrainedModel", + "Qwen3Model", + "Qwen3ForSequenceClassification", + "Qwen3ForTokenClassification", +]