diff --git a/autoparallel/_testing/models/llama3.py b/autoparallel/_testing/models/llama3.py new file mode 100644 index 00000000..9d349e1a --- /dev/null +++ b/autoparallel/_testing/models/llama3.py @@ -0,0 +1,546 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.attention import SDPBackend, sdpa_kernel + + +def has_cuda_capability(major: int, minor: int) -> bool: + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + major, + minor, + ) + + +class ScaledDotProductAttention(torch.nn.Module): + backends: ClassVar[list[SDPBackend]] = [] + + def __init__(self, attn_mask_type: str) -> None: + super().__init__() + if attn_mask_type != "causal": + raise ValueError( + "TorchTitan with SDPA currently only supports causal mask." + ) + + ScaledDotProductAttention._init_backend() + + @classmethod + def _init_backend(cls) -> None: + if cls.backends: + return + + # Add CuDNN on B200 w/ highest priority + cls.backends = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + if has_cuda_capability(10, 0): + cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + assert self.backends, "SDPA Backends should not be empty." + with sdpa_kernel(self.backends, set_priority=True): + return F.scaled_dot_product_attention(q, k, v, is_causal=True) + + +def build_attention( + use_flex_attn: bool, attn_mask_type: str, fixed_block_size: Optional[int] = None +): + if use_flex_attn: + raise NotImplementedError() + # return FlexAttention(attn_mask_type, fixed_block_size) + else: + if fixed_block_size is not None: + raise ValueError( + "TorchTitan with SDPA currently does not support fixed_block_size." + ) + if attn_mask_type != "causal": + raise ValueError( + "TorchTitan with SDPA currently only supports causal mask." + ) + return ScaledDotProductAttention(attn_mask_type) + + +@dataclass +class TransformerModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = 64000 # -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + + max_seq_len: int = 2048 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + + use_flex_attn: bool = False + attn_mask_type: str = "causal" + eos_id: int = 0 + + def update_from_config(self, job_config, tokenizer) -> None: + self.vocab_size = tokenizer.n_words + self.max_seq_len = job_config.training.seq_len + self.eos_id = tokenizer.eos_id + + if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn: + raise ValueError( + "FlexAttention is not compatible with selective AC yet. " + "See https://github.com/pytorch/pytorch/issues/147879" + ) + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise ValueError( + "FlexAttention is not compatible with CP yet. " + "We are still working on this." 
+ ) + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + nparams = sum(p.numel() for p in model.parameters()) + nparams_embedding = sum( + sum(p.numel() for p in m.parameters()) + for m in model.children() + if isinstance(m, nn.Embedding) + ) + + l, h, q, t = ( + self.n_layers, + self.n_heads, + self.dim // self.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t + + return nparams, num_flops_per_token + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. 
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
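+        # After these views, xq has shape (bs, seqlen, n_local_heads, head_dim)
+        # and xk/xv have shape (bs, seqlen, n_local_kv_heads, head_dim).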
+        xq = xq.view(bs, seqlen, -1, self.head_dim)
+        xk = xk.view(bs, seqlen, -1, self.head_dim)
+        xv = xv.view(bs, seqlen, -1, self.head_dim)
+
+        # apply rotary position embeddings to the queries and keys
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+
+        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+
+        output = self.sdpa(xq, xk, xv)
+
+        output = output.transpose(
+            1, 2
+        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
+        output = output.view(bs, seqlen, -1)
+        return self.wo(output)
+
+
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class TransformerBlock(nn.Module):
+    """
+    TransformerBlock Module
+
+    Args:
+        layer_id (int): Identifier for the layer.
+        model_args (TransformerModelArgs): Model configuration arguments.
+
+    Attributes:
+        n_heads (int): Number of attention heads.
+        dim (int): Dimension size of the model.
+        head_dim (int): Dimension size of each attention head.
+        attention (Attention): Attention module.
+        feed_forward (FeedForward): FeedForward module.
+        layer_id (int): Identifier for the layer.
+        attention_norm (RMSNorm): Layer normalization for attention output.
+        ffn_norm (RMSNorm): Layer normalization for feedforward output.
+
+    """
+
+    def __init__(self, layer_id: int, model_args: TransformerModelArgs):
+        super().__init__()
+        self.n_heads = model_args.n_heads
+        self.dim = model_args.dim
+
+        self.attention = Attention(model_args)
+        self.feed_forward = FeedForward(
+            dim=model_args.dim,
+            hidden_dim=4 * model_args.dim,
+            multiple_of=model_args.multiple_of,
+            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
+        )
+        self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+
+        if model_args.depth_init:
+            self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
+        else:
+            self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: torch.Tensor,
+    ):
+        """
+        Perform a forward pass through the TransformerBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+        Returns:
+            torch.Tensor: Output tensor after applying attention and feedforward layers.
+
+        """
+        h = x + self.attention(self.attention_norm(x), freqs_cis)
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+    def init_weights(self):
+        # NOTE: the early return below makes this method a no-op; the
+        # initialization code after it is never executed.
+        return
+        for norm in (self.attention_norm, self.ffn_norm):
+            norm.reset_parameters()
+        self.attention.init_weights(self.weight_init_std)
+        self.feed_forward.init_weights(self.weight_init_std)
+
+
+class Transformer(nn.Module):
+    """
+    Transformer Module
+
+    Args:
+        model_args (TransformerModelArgs): Model configuration arguments.
+
+    Attributes:
+        model_args (TransformerModelArgs): Model configuration arguments.
+        vocab_size (int): Vocabulary size.
+        n_layers (int): Number of layers in the model.
+        tok_embeddings (nn.Embedding): Token embeddings.
+        layers (torch.nn.ModuleDict): Dict of Transformer blocks, keyed by layer ID.
+        norm (RMSNorm): Layer normalization for the model output.
+        output (nn.Linear): Linear layer for final output.
+        freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+    """
+
+    def __init__(self, model_args: TransformerModelArgs):
+        super().__init__()
+        self.model_args = model_args
+        self.vocab_size = model_args.vocab_size
+        self.n_layers = model_args.n_layers
+        self.eos_id = model_args.eos_id
+
+        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+
+        # TODO persistent should be set to false, since this buffer can be recomputed.
+        # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
+        # compile or pipeline-tracer will not correctly handle non-persistent buffers,
+        # so we need to fix that. (2) if we initialize pipeline-parallel models from
+        # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
+        # initialized by the checkpoint, or we need to add a separate initializer for
+        # just the non-persistent buffers that is called after loading checkpoints.
+        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+
+        self.layers = torch.nn.ModuleDict()
+        for layer_id in range(model_args.n_layers):
+            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
+        self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+
+    def init_weights(
+        self,
+        buffer_device: Optional[torch.device] = None,
+    ):
+        """
+        [Note: On ``init_weights`` vs. ``reset_parameters``]
+        Modules may define ``reset_parameters`` to initialize parameter values.
+ ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + buffer_device = buffer_device or self.freqs_cis.device # type: ignore + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() # type: ignore + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor, input_batch: Optional[torch.Tensor] = None): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + input_batch (torch.Tensor): The input batch read from the dataloader. + This will always be the input batch regardless of the pipeline stage. + This field is required for non-first PP stages to perform document + masking attention (to analyze the boundary of the document). + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + return output diff --git a/examples/example_llama3.py b/examples/example_llama3.py index 97dcc4b5..eeff9ffb 100644 --- a/examples/example_llama3.py +++ b/examples/example_llama3.py @@ -4,558 +4,15 @@ # LICENSE file in the root directory of this source tree. 
import time -from dataclasses import dataclass -from typing import ClassVar import torch -import torch.nn.functional as F -from torch import nn from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Partial, Replicate, Shard -from torch.nn.attention import SDPBackend, sdpa_kernel from torch.testing._internal.distributed.fake_pg import FakeStore +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs from autoparallel.api import AutoParallel - -def has_cuda_capability(major: int, minor: int) -> bool: - return torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( - major, - minor, - ) - - -class ScaledDotProductAttention(torch.nn.Module): - backends: ClassVar[list[SDPBackend]] = [] - - def __init__(self, attn_mask_type: str) -> None: - super().__init__() - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." - ) - - ScaledDotProductAttention._init_backend() - - @classmethod - def _init_backend(cls) -> None: - if cls.backends: - return - - # Add CuDNN on B200 w/ highest priority - cls.backends = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, - ] - if has_cuda_capability(10, 0): - cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) - - def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor - ) -> torch.Tensor: - assert self.backends, "SDPA Backends should not be empty." - with sdpa_kernel(self.backends, set_priority=True): - return F.scaled_dot_product_attention(q, k, v, is_causal=True) - - -def build_attention( - use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None -): - if use_flex_attn: - raise NotImplementedError() - # return FlexAttention(attn_mask_type, fixed_block_size) - else: - if fixed_block_size is not None: - raise ValueError( - "TorchTitan with SDPA currently does not support fixed_block_size." - ) - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." - ) - return ScaledDotProductAttention(attn_mask_type) - - -@dataclass -class TransformerModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: int | None = None - vocab_size: int = 64000 # -1 # defined later by tokenizer - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: float | None = None - norm_eps: float = 1e-5 - rope_theta: float = 10000 - - max_seq_len: int = 2048 - # If `True`, then each transformer block init uses its layer ID, and if - # `False`, each uses the total number of transformer blocks - depth_init: bool = True - - use_flex_attn: bool = False - attn_mask_type: str = "causal" - eos_id: int = 0 - - def update_from_config(self, job_config, tokenizer) -> None: - self.vocab_size = tokenizer.n_words - self.max_seq_len = job_config.training.seq_len - self.eos_id = tokenizer.eos_id - - if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn: - raise ValueError( - "FlexAttention is not compatible with selective AC yet. " - "See https://github.com/pytorch/pytorch/issues/147879" - ) - - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise ValueError( - "FlexAttention is not compatible with CP yet. " - "We are still working on this." 
- ) - - def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: - nparams = sum(p.numel() for p in model.parameters()) - nparams_embedding = sum( - sum(p.numel() for p in m.parameters()) - for m in model.children() - if isinstance(m, nn.Embedding) - ) - - l, h, q, t = ( - self.n_layers, - self.n_heads, - self.dim // self.n_heads, - seq_len, - ) - # Reasoning behind the factor of 12 for the self-attention part of the formula: - # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) - # 2. the flash attention does 1 more matmul recomputation in the backward - # but recomputation should not be counted in calculating MFU (+0) - # 3. each matmul performs 1 multiplication and 1 addition (*2) - # 4. we follow the convention and do not account for sparsity in causal attention - num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t - - return nparams, num_flops_per_token - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - freqs = torch.outer(t, freqs).float() - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """ - Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), - and the first seqlen elements will be sliced, but dim must match x. - - Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. - x (torch.Tensor): Target tensor for broadcasting compatibility. - - Returns: - torch.Tensor: Reshaped frequency tensor. - """ - ndim = x.ndim - assert ndim > 1 - seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. - - This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided - frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are - returned as real tensors. - - Args: - xq (torch.Tensor): Query tensor to apply rotary embeddings. 
- xk (torch.Tensor): Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. - - Returns: - tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. - """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - torch.unsqueeze(x, dim=3) - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -class Attention(nn.Module): - """ - Multi-head attention module. - - Args: - model_args (TransformerModelArgs): Model configuration arguments. - - Attributes: - n_kv_heads (int): Number of key and value heads. - n_heads (int): Number of query heads. - n_rep (int): Number of repetitions for local heads. - head_dim (int): Dimension size of each attention head. - wq (Linear): Linear transformation for queries. - wk (Linear): Linear transformation for keys. - wv (Linear): Linear transformation for values. - wo (Linear): Linear transformation for output. - - """ - - def __init__(self, model_args: TransformerModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.n_kv_heads = ( - model_args.n_heads - if model_args.n_kv_heads is None - else model_args.n_kv_heads - ) - self.n_rep = self.n_heads // self.n_kv_heads - self.head_dim = model_args.dim // model_args.n_heads - - self.wq = nn.Linear( - model_args.dim, model_args.n_heads * self.head_dim, bias=False - ) - self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear( - model_args.n_heads * self.head_dim, model_args.dim, bias=False - ) - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) - - def init_weights(self, init_std: float): - for linear in (self.wq, self.wk, self.wv): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - ): - """ - Forward pass of the attention module. - - Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed frequency tensor. - - Returns: - torch.Tensor: Output tensor after attention. - - """ - - bs, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual - # local heads from sizes of xq, xk, and xv as TP may have sharded them - # after the above linear ops. 
- xq = xq.view(bs, seqlen, -1, self.head_dim) - xk = xk.view(bs, seqlen, -1, self.head_dim) - xv = xv.view(bs, seqlen, -1, self.head_dim) - - # TODO: uncomment - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - - # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - - output = self.sdpa(xq, xk, xv) - - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) - output = output.view(bs, seqlen, -1) - return self.wo(output) - - -class FeedForward(nn.Module): - """ - FeedForward module - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. - - Attributes: - w1 (Linear): Linear transformation for the first layer. - w2 (Linear): Linear transformation for the second layer. - w3 (Linear): Linear transformation for the third layer. - - """ - - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: float | None, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) - for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) - - -class TransformerBlock(nn.Module): - """ - TransformerBlock Module - - Args: - layer_id (int): Identifier for the layer. - model_args (TransformerModelArgs): Model configuration arguments. - - Attributes: - n_heads (int): Number of attention heads. - dim (int): Dimension size of the model. - head_dim (int): Dimension size of each attention head. - attention (Attention): Attention module. - feed_forward (FeedForward): FeedForward module. - layer_id (int): Identifier for the layer. - attention_norm (RMSNorm): Layer normalization for attention output. - ffn_norm (RMSNorm): Layer normalization for feedforward output. 
- - """ - - def __init__(self, layer_id: int, model_args: TransformerModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.dim = model_args.dim - - self.attention = Attention(model_args) - self.feed_forward = FeedForward( - dim=model_args.dim, - hidden_dim=4 * model_args.dim, - multiple_of=model_args.multiple_of, - ffn_dim_multiplier=model_args.ffn_dim_multiplier, - ) - self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - - if model_args.depth_init: - self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 - else: - self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - ): - """ - Perform a forward pass through the TransformerBlock. - - Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. - - Returns: - torch.Tensor: Output tensor after applying attention and feedforward layers. - - """ - h = x + self.attention(self.attention_norm(x), freqs_cis) - out = h + self.feed_forward(self.ffn_norm(h)) - return out - - def init_weights(self): - return - for norm in (self.attention_norm, self.ffn_norm): - norm.reset_parameters() - self.attention.init_weights(self.weight_init_std) - self.feed_forward.init_weights(self.weight_init_std) - - -class Transformer(nn.Module): - """ - Transformer Module - - Args: - model_args (TransformerModelArgs): Model configuration arguments. - - Attributes: - model_args (TransformerModelArgs): Model configuration arguments. - vocab_size (int): Vocabulary size. - n_layers (int): Number of layers in the model. - tok_embeddings (ParallelEmbedding): Token embeddings. - layers (torch.nn.ModuleList): List of Transformer blocks. - norm (RMSNorm): Layer normalization for the model output. - output (ColumnParallelLinear): Linear layer for final output. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. - - """ - - def __init__(self, model_args: TransformerModelArgs): - super().__init__() - self.model_args = model_args - self.vocab_size = model_args.vocab_size - self.n_layers = model_args.n_layers - self.eos_id = model_args.eos_id - - self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) - - # TODO persistent should be set to false, since this buffer can be recomputed. - # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, - # compile or pipeline-tracer will not correctly handle non-persistent buffers, - # so we need to fix that. (2) if we initialize pipeline-parallel models from - # a seed checkpoint rather than calling init_weights, we need freqs_cis to be - # initialized by the checkpoint, or we need to add a separate initializer for - # just the non-persistent buffers that is called after loading checkpoints. - self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) - - self.layers = torch.nn.ModuleDict() - for layer_id in range(model_args.n_layers): - self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) - self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) - - def init_weights( - self, - buffer_device: torch.device | None = None, - ): - """ - [Note: On ``init_weights`` vs. ``reset_parameters``] - Modules may define ``reset_parameters`` to initialize parameter values. 
- ``reset_parameters`` is meant to only initialize directly owned - parameters/buffers, not those of their child modules, and it can be - used to give the initial values for these tensors. - Separately, users may want custom initialization for their modules, - different from that in ``reset_parameters``. For this, we define - ``init_weights``. We only call it in the constructor of this - ``Transformer`` root module to avoid reinitializing tensors. - """ - buffer_device = buffer_device or self.freqs_cis.device # type: ignore - with torch.device(buffer_device): - self.freqs_cis = self._precompute_freqs_cis() - if self.tok_embeddings is not None: - nn.init.normal_(self.tok_embeddings.weight) - for layer in self.layers.values(): - if layer is not None: - layer.init_weights() - if self.norm is not None: - self.norm.reset_parameters() - final_out_std = self.model_args.dim**-0.5 - cutoff_factor = 3 - if self.output is not None: - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=final_out_std, - a=-cutoff_factor * final_out_std, - b=cutoff_factor * final_out_std, - ) - - def _precompute_freqs_cis(self) -> torch.Tensor: - return precompute_freqs_cis( - self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # TODO: explain in docs/composability.md why we removed the 2x - # relaxing in our CP enablement PR - self.model_args.max_seq_len, - self.model_args.rope_theta, - ) - - def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None): - """ - Perform a forward pass through the Transformer model. - - Args: - tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. - If pipeline parallelism is enabled, this will be the input token indices - for the ranks on the first pipeline stage. This will be the activation of the - previous pipeline stage if the current rank is not on the first stage. - input_batch (torch.Tensor): The input batch read from the dataloader. - This will always be the input batch regardless of the pipeline stage. - This field is required for non-first PP stages to perform document - masking attention (to analyze the boundary of the document). - - Returns: - torch.Tensor: Output logits after applying the Transformer model. - - """ - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - - for layer in self.layers.values(): - h = layer(h, self.freqs_cis) - - h = self.norm(h) if self.norm else h - output = self.output(h) if self.output else h - return output - - -# ============================================================== -# AutoParallel code starts here -# ============================================================== - world_size = 64 fake_store = FakeStore()