[OMNIML-2850] [3/n] Adds sparse attention calibration #538
Open: kaix-nv wants to merge 9 commits into `main` from `kaix/sparse_attention_calibration`.
Changes from all commits (9 commits, all by kaix-nv):
- 1883870: Add sparse attention integration to llm_eval
- 6486db4: Add sparse attention calibration for the decode phase
- 5fd37f0: Add hf unified checkpoint export for sparse attention
- 8908b1a: Address feedbacks
- d984cd8: update default threshold_trials
- 42a0499: Update sparse attention config
- 08bbc62: Move the data folder under example
- da96f1b: Implement Inverse Power calibration for sparse attention
- a5136e8: Switch to exponential model for fitting from inverse power
**New file (78 added lines): utilities for sparse attention integration with llm_eval.**

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for sparse attention integration with llm_eval."""

import modelopt.torch.sparsity.attention_sparsity as mtsa


def _extract_model(model_obj):
    """Extract actual model from wrapper (HFLM or EvalModel)."""
    if hasattr(model_obj, "gpt2"):
        return model_obj.gpt2
    elif hasattr(model_obj, "model"):
        return model_obj.model
    else:
        return model_obj


def sparsify_model(
    model,
    sparse_cfg: str,
    backend=None,
):
    """Apply sparse attention to model with optional RULER calibration.

    Args:
        model: Model wrapper (HFLM or EvalModel) or raw model
        sparse_cfg: Sparse attention config name or dict
        backend: Backend to use (optional, overrides config backend)

    Returns:
        The model with sparse attention applied

    Note:
        Calibration is automatically triggered if the config contains a 'calibration' field.
        The calibration will auto-generate RULER dataset from the model's tokenizer.
    """
    # Extract actual model
    net = _extract_model(model)

    # Resolve config
    if isinstance(sparse_cfg, str):
        # Get config from mtsa module (e.g., SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_DEFAULT)
        mtsa_cfg = getattr(mtsa, sparse_cfg, None)
        if mtsa_cfg is None:
            raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}.")
    else:
        mtsa_cfg = sparse_cfg

    # Override backend if specified
    if backend:
        if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg:
            modified_sparse_cfg = {}
            for pattern, cfg in mtsa_cfg["sparse_cfg"].items():
                modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg
                if isinstance(modified_cfg, dict):
                    modified_cfg["backend"] = backend
                modified_sparse_cfg[pattern] = modified_cfg
            mtsa_cfg = {"sparse_cfg": modified_sparse_cfg}

    # Apply sparsification
    print(f"\nApplying sparse attention with config: {sparse_cfg}")
    mtsa.sparsify(net, mtsa_cfg)
    print("Sparse attention applied successfully!")

    return model
```
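For orientation, a minimal usage sketch of `sparsify_model` follows. It is not part of the diff; the model name, config name, and `backend="pytorch"` are taken from the README added later in this PR, and passing a raw Hugging Face model is just one of the supported call patterns (wrappers such as HFLM or EvalModel are unwrapped by `_extract_model`):

```python
# Hypothetical usage sketch of sparsify_model (not part of this PR's diff).
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    attn_implementation="eager",  # required so the softmax patching can collect stats
    torch_dtype=torch.bfloat16,
)

# "SKIP_SOFTMAX_CALIB" is resolved via getattr(mtsa, ...); because that config
# contains a 'calibration' field, RULER-based calibration runs automatically.
model = sparsify_model(model, sparse_cfg="SKIP_SOFTMAX_CALIB", backend="pytorch")
```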
**New file (2 added lines):**

```
# Data directory for calibration
data
```
**New documentation file (165 added lines):**

# Attention Sparsity for HuggingFace Models

In this tutorial, we demonstrate how to use NVIDIA TensorRT Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation.

## Getting Started

### Quick Example

```python
import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

# Load your model
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    attn_implementation="eager",  # Required for sparse attention
    torch_dtype=torch.bfloat16,
)

# Apply sparse attention
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
```
> **Contributor comment on lines +9 to +22:** Add missing imports in the quick example. The code example is missing the imports for `torch` and `AutoModelForCausalLM`.
>
> Suggested fix:
>
> ```diff
> +import torch
> +from transformers import AutoModelForCausalLM
> +
>  import modelopt.torch.sparsity.attention_sparsity as mtsa
>  from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT
> ```
> [!Note]
> `attn_implementation="eager"` is required for sparse attention to work properly. Flash Attention 2 or SDPA would bypass the softmax patching needed for stats collection.
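Putting the reviewer's suggestion together with the quick example, a complete, self-contained version would look like this (same content as above, with only the missing imports added):

```python
# Quick example with the imports the review comment asks for.
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

# Load your model (eager attention is required for sparse attention).
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
)

# Apply sparse attention with the fixed-threshold config.
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
```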
## Configuration Options

Two pre-defined configurations are available:

### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT)

Uses a fixed threshold value. Simple but may not be optimal for all sequence lengths.

```python
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
```

### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB)

Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use.

```python
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB

model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB)
```

## Prerequisites

### Local Installation

For Hugging Face models, install Model Optimizer with `hf` dependencies using `pip` from [PyPI](https://pypi.org/project/nvidia-modelopt/) and install the requirements for the example:

```bash
pip install nvidia-modelopt[hf]
```
### Download RULER Calibration Data (Required for Calibration)

If using `SKIP_SOFTMAX_CALIB`, you need to download the RULER calibration dataset first:

```bash
bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh
```

This downloads the Paul Graham essays dataset used for generating calibration samples.

## Run Sparse Attention on HuggingFace Models

### Basic Usage (Without Calibration)

Apply sparse attention with a fixed threshold:

```bash
python examples/llm_sparsity/attention_sparsity/hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --sparse_attn skip_softmax
```

### With RULER Calibration

Apply sparse attention with calibrated thresholds for optimal sparsity:

```bash
python examples/llm_sparsity/attention_sparsity/hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --sparse_attn skip_softmax_calib
```

The calibration process:

1. Generates RULER calibration samples
2. Collects attention statistics during forward passes
3. Determines optimal threshold scale factor for target sparsity ratio (see the sketch below)
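To make step 3 concrete, here is a rough, illustrative sketch of the fitting idea, assuming a sweep over `threshold_trials` and the exponential threshold-vs-sequence-length model referenced in this PR's final commit. The numbers are made up and this is not the actual ModelOpt implementation:

```python
# Illustrative sketch only: pick the best threshold per sequence length from the
# sweep, then fit an exponential model so the threshold can adapt to other
# sequence lengths at inference time.
import numpy as np


def pick_threshold(sparsity_by_threshold: dict, target: float) -> float:
    """Return the trial threshold whose measured sparsity is closest to the target ratio."""
    return min(sparsity_by_threshold, key=lambda t: abs(sparsity_by_threshold[t] - target))


def fit_exponential(seq_lens, thresholds):
    """Fit threshold ~= a * exp(b * seq_len) via linear regression in log space."""
    b, log_a = np.polyfit(np.asarray(seq_lens, dtype=float), np.log(thresholds), 1)
    return float(np.exp(log_a)), float(b)


# Step 3 for one sequence length: threshold -> measured sparsity (made-up numbers).
trials = {1e-3: 0.38, 5e-3: 0.47, 1e-2: 0.55, 5e-2: 0.71}
print("chosen threshold:", pick_threshold(trials, target=0.5))

# Repeating the sweep at several sequence lengths gives points to fit (made-up numbers).
calibrated = {1024: 2e-2, 2048: 1e-2, 4096: 5e-3, 8192: 2e-3}
a, b = fit_exponential(list(calibrated), list(calibrated.values()))
print(f"dynamic threshold(seq_len) ~= {a:.2e} * exp({b:.2e} * seq_len)")
```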
### Command Line Arguments

| Argument | Default | Description |
|----------|---------|-------------|
| `--pyt_ckpt_path` | Required | HuggingFace model path or name |
| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` |
| `--backend` | `pytorch` | Backend: `pytorch` (only supported backend) |
| `--seq_len` | `2048` | Maximum sequence length for input prompts |
| `--export_dir` | `None` | Directory to export the sparsified model |

## Output Comparison

The script automatically compares outputs before and after applying sparse attention:

1. Loads a test sample from the NarrativeQA dataset
2. Generates text before sparse attention is applied
3. Applies sparse attention (with optional calibration)
4. Generates text after sparse attention is applied
5. Compares and displays both outputs
## Export Model

Export the sparsified model to a HuggingFace checkpoint:

```bash
python examples/llm_sparsity/attention_sparsity/hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --sparse_attn skip_softmax_calib \
    --export_dir ./exported_sparse_model
```

The exported model can be loaded and used with standard HuggingFace APIs.
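As a quick illustration of that last point, a minimal loading sketch (assuming the export command above produced `./exported_sparse_model` and the checkpoint loads like any standard HF model; this snippet is not part of the PR):

```python
# Hypothetical loading sketch; assumes ./exported_sparse_model exists and behaves
# like a standard HuggingFace checkpoint, as described above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./exported_sparse_model")
model = AutoModelForCausalLM.from_pretrained(
    "./exported_sparse_model",
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("The capital of France is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```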
## Custom Configuration

You can create custom sparse attention configurations:

```python
custom_config = {
    "sparse_cfg": {
        "calibration": {  # Optional: omit for fixed threshold
            "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5},  # Target 50% sparsity
            "samples": 128,  # Number of calibration samples
            "max_seqlen": 8192,  # Maximum sequence length
            # Optional: customize threshold trials for calibration
            "threshold_trials": [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1],
        },
        "*attn*": {  # Pattern to match attention modules
            "method": "flash_skip_softmax",
            "threshold": {"prefill": 1e-3, "decode": 1e-4},  # Phase-specific thresholds (ignored if calibration is used)
            "br": 128,  # Flash Attention block rows
            "bc": 128,  # Flash Attention block columns
            "backend": "pytorch",
            "collect_stats": True,
            "enable": True,
        },
        "default": {"enable": False},
    },
}

model = mtsa.sparsify(model, config=custom_config)
```

## References

- [TensorRT Model Optimizer Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/)
- [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER)
**Review comment:** same as https://github.com/NVIDIA/Model-Optimizer/pull/538/files#r2646356349 and avoid repeated attention modification
**Reply:** I've removed this check.