From 8b97c9615726684ba1398431df0ca03f88639076 Mon Sep 17 00:00:00 2001
From: Yufeng Xu
Date: Thu, 9 Oct 2025 23:00:08 -0400
Subject: [PATCH] update requirements to be compatible with flash-attn-2

---
 README.md                                |  8 ++++--
 llmshearing/models/composer_llama.py     | 36 +++++++++++++++++-------
 llmshearing/models/metrics.py            | 12 +++++++-
 llmshearing/utils/test_composer_hf_eq.py |  7 ++++-
 4 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 18ac60e..f982dc1 100644
--- a/README.md
+++ b/README.md
@@ -54,14 +54,16 @@ This codebase is built based on MosaicML's amazing [Composer package](https://gi
 ## Install Requirements
 **Step 1**: To get started with this repository, you'll need to follow these installation steps. Before proceeding, make sure you have [Pytorch](https://pytorch.org/get-started/previous-versions/) and [Flash Attention](https://github.com/Dao-AILab/flash-attention) installed. You can do this via pip using the following commands:
 ```
-pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
-pip install flash-attn==1.0.3.post
+# pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
+# pip install flash-attn==1.0.3.post
+pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121 --index-url https://download.pytorch.org/whl/cu121
+pip install "flash-attn==2.3.2"
 ```
 Please note that Flash Attention version 2 is not currently supported and may require manual modifications to the model file.
+Update: the `flash-attn` version and the corresponding interface in the model have been updated; the codebase is now compatible with Flash Attention 2.
 
 **Step 2**: Then install the rest of the required packages:
 ```
-cd llmshearing
 pip install -r requirement.txt
 ```
 
diff --git a/llmshearing/models/composer_llama.py b/llmshearing/models/composer_llama.py
index 68421de..973a1db 100644
--- a/llmshearing/models/composer_llama.py
+++ b/llmshearing/models/composer_llama.py
@@ -787,6 +787,7 @@ def flash_attn_fn(
     try:
         from flash_attn import bert_padding  # type: ignore
         from flash_attn import flash_attn_interface  # type: ignore
+        from flash_attn import flash_attn_func, flash_attn_varlen_func  # for flash-attn-2
     except ImportError as e:
         raise e
 
@@ -815,18 +816,33 @@ def flash_attn_fn(
     dropout_p = dropout_p if training else 0.0
 
-    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
-        query_unpad,
-        key_unpad,
-        value_unpad,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        dropout_p,
+    # output_unpad = flash_attn_interface.flash_attn_unpadded_func(
+    #     query_unpad,
+    #     key_unpad,
+    #     value_unpad,
+    #     cu_seqlens_q,
+    #     cu_seqlens_k,
+    #     max_seqlen_q,
+    #     max_seqlen_k,
+    #     dropout_p,
+    #     softmax_scale=softmax_scale,
+    #     causal=is_causal,
+    #     return_attn_probs=needs_weights)
+
+    # the flash-attn-2 interface
+    output_unpad = flash_attn_varlen_func(
+        q=query_unpad,
+        k=key_unpad,
+        v=value_unpad,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        dropout_p=dropout_p,
         softmax_scale=softmax_scale,
         causal=is_causal,
-        return_attn_probs=needs_weights)
+        return_attn_probs=needs_weights,
+    )
 
     if head_z is not None:
         output_unpad = output_unpad * head_z  # 1 * h * 1
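For context on the interface switch above: `flash_attn_varlen_func` is the flash-attn-2 replacement for `flash_attn_unpadded_func` and operates on packed (unpadded) tokens plus cumulative sequence lengths. Below is a minimal, self-contained sketch of that call with illustrative shapes; it assumes a CUDA device, fp16 tensors, and `flash-attn` 2.x installed, and is an example rather than part of the patch.

```python
import torch
from flash_attn import flash_attn_varlen_func  # flash-attn 2.x

n_heads, head_dim = 8, 64
seq_lens = [5, 9]                    # two variable-length sequences packed together
total_tokens = sum(seq_lens)

# packed (unpadded) q/k/v: (total_tokens, n_heads, head_dim), fp16/bf16 on GPU
q = torch.randn(total_tokens, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# cumulative sequence lengths, int32, shape (batch + 1,)
cu_seqlens = torch.tensor([0, seq_lens[0], total_tokens], device="cuda", dtype=torch.int32)
max_seqlen = max(seq_lens)

out = flash_attn_varlen_func(
    q=q, k=k, v=v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    dropout_p=0.0,
    causal=True,                     # causal mask, as in the model's attention
)
print(out.shape)                     # torch.Size([14, 8, 64])
```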
diff --git a/llmshearing/models/metrics.py b/llmshearing/models/metrics.py
index 5ac36a1..ca9b96d 100644
--- a/llmshearing/models/metrics.py
+++ b/llmshearing/models/metrics.py
@@ -23,13 +23,23 @@ def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None:
         target = target.view(-1)
         logits = logits.view(target.shape[0], -1)
-        losses = self.loss_fn(logits, target)
+        # losses = self.loss_fn(logits, target)
 
         total_items = (target != self.ignore_index).sum()
+        if total_items.item() == 0:
+            return  # 👈 skip update to avoid NaN
+
+        losses = self.loss_fn(logits, target)
         self.total_items += total_items  #type: ignore (third-party)
 
         # accumulate loss over all batches
         self.sum_loss += losses.to(torch.float32)
+
+    # override the base class to avoid division by zero
+    def compute(self) -> torch.Tensor:
+        if self.total_items == 0:
+            return torch.tensor(0.0, dtype=torch.float32, device=self.sum_loss.device)
+        return self.sum_loss / self.total_items
 
 
 class DomainCount(Metric):
diff --git a/llmshearing/utils/test_composer_hf_eq.py b/llmshearing/utils/test_composer_hf_eq.py
index a354816..352bc85 100644
--- a/llmshearing/utils/test_composer_hf_eq.py
+++ b/llmshearing/utils/test_composer_hf_eq.py
@@ -25,7 +25,12 @@ def construct_example_cfg(model_size, path=None, add_l0_module=False):
         cfg = om.create({"name": "mosaic_llama_65b", "path": path,"init_device": "cpu", "d_model": 8192, "n_heads": 64, "n_layers": 80, "intermediate_size": 22016})
     # add default values
-    cfg = om.merge(cfg, om.create({"max_seq_len": 4096, "vocab_size": 32000, "init_std": 0.02, "attn_pdrop": 0.0, "resid_pdrop": 0.0, "emb_pdrop": 0.0, "attn_impl": "flash", "rms_norm_eps": 1e-5}))
+    # cfg = om.merge(cfg, om.create({"max_seq_len": 4096, "vocab_size": 32000, "init_std": 0.02, "attn_pdrop": 0.0, "resid_pdrop": 0.0, "emb_pdrop": 0.0, "attn_impl": "flash", "rms_norm_eps": 1e-5}))
+    # merge order reversed: use cfg to override the defaults, not the other way around
+    cfg = om.merge(
+        om.create({"max_seq_len": 4096, "vocab_size": 32000, "init_std": 0.02, "attn_pdrop": 0.0, "resid_pdrop": 0.0, "emb_pdrop": 0.0, "attn_impl": "flash", "rms_norm_eps": 1e-5}),
+        cfg
+    )
 
     if add_l0_module:
         cfg["l0_module"] = {"start_sparsity": 0, "target_sparsity": 0.6, "pruning_modules": ["head", "head_layer", "mlp", "intermediate", "hidden"], "lagrangian_warmup_steps": "320ba"}
     return cfg
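As a quick illustration of why the merge order in `construct_example_cfg` was reversed: `OmegaConf.merge` gives precedence to later arguments, so the defaults must come first if the model-specific `cfg` is meant to override them. A small sketch with illustrative values (only a subset of the real defaults):

```python
from omegaconf import OmegaConf as om

defaults = om.create({"max_seq_len": 4096, "rms_norm_eps": 1e-5})
cfg = om.create({"max_seq_len": 2048})          # model-specific value

merged = om.merge(defaults, cfg)                # new order: cfg overrides the defaults
print(merged.max_seq_len)                       # 2048
print(merged.rms_norm_eps)                      # 1e-05 (filled in from the defaults)

old = om.merge(cfg, defaults)                   # old order: the defaults silently win
print(old.max_seq_len)                          # 4096
```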