From 1883870acc1c2468d2f8f5d8aa42cdc88cd19470 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Sat, 13 Dec 2025 21:11:16 +0000 Subject: [PATCH 1/9] Add sparse attention integration to llm_eval Signed-off-by: Kai Xu --- .vscode/settings.json | 3 + examples/llm_eval/lm_eval_hf.py | 28 + examples/llm_eval/mmlu.py | 25 + examples/llm_eval/modeling.py | 5 + examples/llm_eval/sparse_attention_utils.py | 111 ++++ .../llm_sparsity/attention_sparsity/README.md | 161 +++++ .../llm_sparsity/attention_sparsity/hf_sa.py | 152 ++--- .../attention_sparsity/requirements.txt | 2 + .../llm_sparsity/weight_sparsity/finetune.py | 16 +- .../calibration/__init__.py | 26 + .../calibration/calibrate.py | 211 ++++++ .../calibration/calibrator.py | 312 +++++++++ .../attention_sparsity/calibration/dataset.py | 546 +++++++++++++++ .../calibration/download_ruler_data.sh | 50 ++ .../calibration/ruler_utils.py | 487 ++++++++++++++ .../sparsity/attention_sparsity/config.py | 197 +++++- .../sparsity/attention_sparsity/conversion.py | 43 ++ .../methods/flash_skip_softmax.py | 119 ++-- .../attention_sparsity/methods/registry.py | 47 +- .../attention_sparsity/model_sparsify.py | 52 +- .../attention_sparsity/sparse_attention.py | 46 +- .../attention_sparsity/stats_manager.py | 137 ++++ setup.py | 2 + .../torch/sparsity/sparse_attention_common.py | 10 +- tests/examples/llm_eval/test_llm_eval.py | 17 + .../test_attention_sparsity.py | 5 +- .../test_calibration_gpu.py | 388 +++++++++++ .../test_flash_skip_softmax.py | 57 +- .../test_sparse_attention_calibration.py | 623 ++++++++++++++++++ .../test_sparse_attention_config.py | 129 ++++ .../test_sparse_attention_conversion.py | 111 ++++ .../attention_sparsity/test_stats_manager.py | 334 ++++++++++ .../attention_sparsity/test_threshold_info.py | 270 ++++++++ 33 files changed, 4507 insertions(+), 215 deletions(-) create mode 100644 examples/llm_eval/sparse_attention_utils.py create mode 100644 examples/llm_sparsity/attention_sparsity/README.md create mode 100644 examples/llm_sparsity/attention_sparsity/requirements.txt create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py create mode 100755 modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/stats_manager.py create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 0e8465ad3..1cff4a791 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -45,4 +45,7 @@ ], "git.alwaysSignOff": true, "git.enableCommitSigning": true, + "cursorpyright.analysis.extraPaths": [ + "./tests/" + ], } diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 31103ff86..24dcb28f6 100755 --- 
a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -43,9 +43,11 @@ from lm_eval.api.model import T from lm_eval.models.huggingface import HFLM from quantization_utils import quantize_model +from sparse_attention_utils import sparsify_model import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T: @@ -60,9 +62,20 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | calib_size = arg_dict.pop("calib_size", 512) compress = arg_dict.pop("compress", False) + # Sparse attention arguments + sparse_cfg = arg_dict.pop("sparse_cfg", None) + additional_config = {} if additional_config is None else additional_config additional_config = {k: v for k, v in additional_config.items() if v is not None} + # Force eager attention if sparse attention is requested + if sparse_cfg: + additional_config["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." + ) + # Enable automatic save/load of modelopt state huggingface checkpointing mto.enable_huggingface_checkpointing() @@ -91,6 +104,15 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | auto_quantize_checkpoint=auto_quantize_checkpoint, ) + if sparse_cfg: + if is_attn_sparsified(model_obj.model): + warnings.warn("Skipping sparse attention: model already has sparse attention applied.") + else: + sparsify_model( + model=model_obj, + sparse_cfg=sparse_cfg, + ) + return model_obj @@ -152,6 +174,11 @@ def setup_parser_with_modelopt_args(): action="store_true", help="Compress the model after quantization", ) + parser.add_argument( + "--sparse_cfg", + type=str, + help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)", + ) return parser @@ -177,6 +204,7 @@ def setup_parser_with_modelopt_args(): "calib_batch_size": args.calib_batch_size, "calib_size": args.calib_size, "compress": args.compress, + "sparse_cfg": args.sparse_cfg, } ) diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index ca244052b..0bf47fcd3 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -48,6 +48,7 @@ from fire import Fire from modeling import EvalModel, select_model from quantization_utils import MAX_SEQ_LEN, get_tokenizer, quantize_model +from sparse_attention_utils import sparsify_model from tqdm import tqdm try: @@ -56,6 +57,7 @@ LLM = None # type: ignore[misc] import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -230,6 +232,7 @@ def main( auto_quantize_method: str = "gradient", auto_quantize_score_size: int = 128, auto_quantize_checkpoint: str | None = None, + sparse_cfg: str | None = None, **kwargs, ): random.seed(RAND_SEED) @@ -266,6 +269,14 @@ def main( max_batch_size=1, ) else: + # Force eager attention if sparse attention is requested + if sparse_cfg: + kwargs["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." 
+ ) + model = select_model( max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs ) @@ -289,6 +300,20 @@ def main( auto_quantize_checkpoint=auto_quantize_checkpoint, ) + # Apply sparse attention if requested + if sparse_cfg: + model.load() + + if is_attn_sparsified(model.model): + warnings.warn( + "Skipping sparse attention: model already has sparse attention applied." + ) + else: + sparsify_model( + model=model, + sparse_cfg=sparse_cfg, + ) + for subject in tqdm(subjects): dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[ :ntrain diff --git a/examples/llm_eval/modeling.py b/examples/llm_eval/modeling.py index 747b95d5b..d06d05560 100644 --- a/examples/llm_eval/modeling.py +++ b/examples/llm_eval/modeling.py @@ -179,6 +179,7 @@ class SeqToSeqModel(EvalModel): lora_path: str = "" device: str = "cuda" load_8bit: bool = False + attn_implementation: str | None = None def load(self): if self.model is None: @@ -188,6 +189,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args) print_gpu_utilization() if self.lora_path: @@ -241,6 +244,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForCausalLM.from_pretrained( self.model_path, trust_remote_code=True, **args ) diff --git a/examples/llm_eval/sparse_attention_utils.py b/examples/llm_eval/sparse_attention_utils.py new file mode 100644 index 000000000..dc7a1b14e --- /dev/null +++ b/examples/llm_eval/sparse_attention_utils.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for sparse attention integration with llm_eval.""" + +import modelopt.torch.sparsity.attention_sparsity as mtsa + +# Custom sparse attention configurations +CUSTOM_SPARSE_CONFIG = { + "SPARSE_CONSERVATIVE": { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 5e-4, "decode": 1e-5}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + }, + "default": {"enable": False}, + }, + }, + "SPARSE_AGGRESSIVE": { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 5e-3, "decode": 5e-4}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + }, + "default": {"enable": False}, + }, + }, +} + + +def _extract_model(model_obj): + """Extract actual model from wrapper (HFLM or EvalModel).""" + if hasattr(model_obj, "gpt2"): + return model_obj.gpt2 + elif hasattr(model_obj, "model"): + return model_obj.model + else: + return model_obj + + +def sparsify_model( + model, + sparse_cfg: str, + backend=None, +): + """Apply sparse attention to model with optional RULER calibration. + + Args: + model: Model wrapper (HFLM or EvalModel) or raw model + sparse_cfg: Sparse attention config name or dict + backend: Backend to use (optional, overrides config backend) + + Returns: + The model with sparse attention applied + + Note: + Calibration is automatically triggered if the config contains a 'calibration' field. + The calibration will auto-generate RULER dataset from the model's tokenizer. + """ + # Extract actual model + net = _extract_model(model) + + # Resolve config + if isinstance(sparse_cfg, str): + # Try custom configs first + mtsa_cfg = CUSTOM_SPARSE_CONFIG.get(sparse_cfg) + if mtsa_cfg is None: + # Try predefined configs + mtsa_cfg = getattr(mtsa, sparse_cfg, None) + if mtsa_cfg is None: + raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}") + else: + mtsa_cfg = sparse_cfg + + # Override backend if specified + if backend: + if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg: + modified_sparse_cfg = {} + for pattern, cfg in mtsa_cfg["sparse_cfg"].items(): + modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg + if isinstance(modified_cfg, dict): + modified_cfg["backend"] = backend + modified_sparse_cfg[pattern] = modified_cfg + mtsa_cfg = {"sparse_cfg": modified_sparse_cfg} + + # Apply sparsification + print(f"\nApplying sparse attention with config: {sparse_cfg}") + mtsa.sparsify(net, mtsa_cfg) + print("Sparse attention applied successfully!") + + return model diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md new file mode 100644 index 000000000..708947683 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -0,0 +1,161 @@ +# Attention Sparsity for HuggingFace Models + +In this tutorial, we demonstrate how to use NVIDIA TensorRT Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. 
+ +## Getting Started + +### Quick Example + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +# Load your model +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + attn_implementation="eager", # Required for sparse attention + torch_dtype=torch.bfloat16, +) + +# Apply sparse attention +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +> [!Note] +> `attn_implementation="eager"` is required for sparse attention to work properly. Flash Attention 2 or SDPA would bypass the softmax patching needed for stats collection. + +## Configuration Options + +Two pre-defined configurations are available: + +### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT) + +Uses a fixed threshold value. Simple but may not be optimal for all sequence lengths. + +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB) + +Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use. + +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB) +``` + +## Prerequisites + +### Install Requirements + +```bash +pip install -r requirements.txt +``` + +### Download RULER Calibration Data (Required for Calibration) + +If using `SKIP_SOFTMAX_CALIB`, you need to download the RULER calibration dataset first: + +```bash +bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh +``` + +This downloads the Paul Graham essays dataset used for generating calibration samples. + +## Run Sparse Attention on HuggingFace Models + +### Basic Usage (Without Calibration) + +Apply sparse attention with a fixed threshold: + +```bash +python examples/llm_sparsity/attention_sparsity/hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax +``` + +### With RULER Calibration + +Apply sparse attention with calibrated thresholds for optimal sparsity: + +```bash +python examples/llm_sparsity/attention_sparsity/hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax_calib +``` + +The calibration process: + +1. Generates RULER calibration samples +2. Collects attention statistics during forward passes +3. Determines optimal threshold scale factor for target sparsity ratio + +### Command Line Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--pyt_ckpt_path` | Required | HuggingFace model path or name | +| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` | +| `--backend` | `pytorch` | Backend: `pytorch` or `triton` | +| `--seq_len` | `2048` | Maximum sequence length for input prompts | +| `--export_dir` | `None` | Directory to export the sparsified model | + +## Output Comparison + +The script automatically compares outputs before and after applying sparse attention: + +1. Loads a test sample from the NarrativeQA dataset +2. Generates text before sparse attention is applied +3. Applies sparse attention (with optional calibration) +4. Generates text after sparse attention is applied +5. 
Compares and displays both outputs + +## Export Model + +Export the sparsified model to a HuggingFace checkpoint: + +```bash +python examples/llm_sparsity/attention_sparsity/hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax_calib \ + --export_dir ./exported_sparse_model +``` + +The exported model can be loaded and used with standard HuggingFace APIs. + +## Custom Configuration + +You can create custom sparse attention configurations: + +```python +custom_config = { + "sparse_cfg": { + "calibration": { # Optional: omit for fixed threshold + "target_sparse_ratio": 0.5, # Target 50% sparsity + "samples": 128, # Number of calibration samples + "max_seqlen": 8192, # Maximum sequence length + }, + "*attn*": { # Pattern to match attention modules + "method": "flash_skip_softmax", + "threshold": 1e-4, # Fixed threshold (ignored if calibration is used) + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + +model = mtsa.sparsify(model, config=custom_config) +``` + +## References + +- [TensorRT Model Optimizer Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) +- [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER) diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 11564a4ec..29a2b53aa 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -28,9 +28,10 @@ import modelopt.torch.opt as mto import modelopt.torch.sparsity.attention_sparsity as mtsa from modelopt.torch.export import export_hf_checkpoint -from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig -from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT -from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_CALIB, + SKIP_SOFTMAX_DEFAULT, +) from modelopt.torch.utils.memory_monitor import launch_memory_monitor RAND_SEED = 1234 @@ -38,9 +39,10 @@ # Enable HuggingFace checkpointing support mto.enable_huggingface_checkpointing() -# You can define custom configurations or use the default +# Sparse attention configuration choices SPARSE_ATTN_CFG_CHOICES = { "skip_softmax": SKIP_SOFTMAX_DEFAULT, + "skip_softmax_calib": SKIP_SOFTMAX_CALIB, } @@ -116,30 +118,23 @@ def truncate_text(text: str, tokenizer, max_length: int): return begin_text + " [...] " + end_text -def verify_outputs(model, tokenizer, args): - """Compare outputs between baseline and sparse attention models.""" - # Update seq_len to match calibration max_seqlen if calibration was used - base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {}) - if "calibration" in base_config and "max_seqlen" in base_config["calibration"]: - calib_max_seqlen = base_config["calibration"]["max_seqlen"] - if args.seq_len != calib_max_seqlen: - print( - f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} " - f"to match calibration config" - ) - args.seq_len = calib_max_seqlen +def generate_sample_output(model, tokenizer, args): + """Generate sample output for comparison. 
+ + Args: + model: The model to generate with + tokenizer: Tokenizer for encoding/decoding + args: Command line arguments - # Load and prepare a single test prompt - print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)") + Returns: + Tuple of (generated_text, input_prompt, input_ids) + """ + # Load test sample prompts = get_narrativeqa_samples(num_samples=1) prompt = prompts[0] # Prepare inputs truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len) - display_prompt = ( - truncated_prompt[:150] + "..." if len(truncated_prompt) > 150 else truncated_prompt - ) - inputs = tokenizer( truncated_prompt, return_tensors="pt", @@ -150,14 +145,7 @@ def verify_outputs(model, tokenizer, args): if torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} - print("\n" + "=" * 60) - print("BASELINE vs SPARSE ATTENTION COMPARISON") - print("=" * 60) - print(f"\nTest prompt: {display_prompt}") - print(f"Input tokens: {inputs['input_ids'].shape[1]}") - - # Helper function to generate text - def generate_text(model, inputs, args, tokenizer): + # Generate with torch.no_grad(): outputs = model.generate( **inputs, @@ -168,60 +156,9 @@ def generate_text(model, inputs, args, tokenizer): ) input_length = inputs["input_ids"].shape[1] generated_ids = outputs[0][input_length:] - return tokenizer.decode(generated_ids, skip_special_tokens=True) - - # Find all sparse attention modules - sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] - - # Generate baseline by temporarily disabling sparse attention - print("\n" + "-" * 60) - print("Generating baseline (sparse attention disabled)...") - for module in sparse_modules: - module.disable() - baseline_text = generate_text(model, inputs, args, tokenizer) + generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) - # Generate with sparse attention enabled - print("\nGenerating with sparse attention (calibrated thresholds)...") - for module in sparse_modules: - module.enable() - sparse_text = generate_text(model, inputs, args, tokenizer) - - # Display comparison - print("\n" + "-" * 60) - print("RESULTS:") - baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text - sparse_display = sparse_text[:300] + "..." 
if len(sparse_text) > 300 else sparse_text - - print(f"\nBaseline: {baseline_display}") - print(f"With Sparse: {sparse_display}") - - if baseline_text == sparse_text: - print("\nOutputs are identical") - else: - print("\nOutputs differ") - - -def sparsify_model(model, args): - """Apply sparse attention to the model with optional calibration.""" - print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}") - base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] - - # Create modified config with selected backend - modified_sparse_cfg = {} - for pattern, cfg in base_config["sparse_cfg"].items(): - modified_cfg = cfg.copy() - modified_cfg["backend"] = args.backend - modified_sparse_cfg[pattern] = modified_cfg - - # Create new config with modified settings - sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg) - - # Sparsify the model - model = mtsa.sparsify(model, config=sparse_config) - - print("Sparse attention applied successfully!") - - return model + return generated_text, truncated_prompt, inputs["input_ids"] def main(args): @@ -254,12 +191,40 @@ def main(args): model = model.cuda() print("Model moved to CUDA") - # Apply sparse attention to the model (with calibration if configured) - model = sparsify_model(model, args) + # Generate sample output BEFORE sparse attention + print("\nGenerating sample output before sparse attention...") + output_before, test_prompt, input_ids = generate_sample_output(model, tokenizer, args) + + # Apply sparse attention with optional calibration + print(f"\nApplying sparse attention: {args.sparse_attn}") + sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] + model = mtsa.sparsify(model, config=sparse_config) + print("Sparse attention applied successfully!") + + # Generate sample output AFTER sparse attention + print("\nGenerating sample output after sparse attention...") + output_after, _, _ = generate_sample_output(model, tokenizer, args) + + # Display comparison + print("\n" + "=" * 60) + print("OUTPUT COMPARISON (Before vs After Sparse Attention)") + print("=" * 60) + display_prompt = test_prompt[:150] + "..." if len(test_prompt) > 150 else test_prompt + print(f"\nTest prompt: {display_prompt}") + print(f"Input tokens: {input_ids.shape[1]}") + + output_before_display = ( + output_before[:300] + "..." if len(output_before) > 300 else output_before + ) + output_after_display = output_after[:300] + "..." 
if len(output_after) > 300 else output_after + + print(f"\nBefore sparse attention: {output_before_display}") + print(f"After sparse attention: {output_after_display}") - # Verify outputs if requested (compares baseline vs calibrated sparse model) - if args.verify_output: - verify_outputs(model, tokenizer, args) + if output_before == output_after: + print("\nOutputs are identical") + else: + print("\nOutputs differ") # Export if requested if args.export_dir: @@ -306,12 +271,6 @@ def main(args): default=2048, help="Maximum sequence length for input prompts (will be truncated if longer)", ) - parser.add_argument( - "--num_samples", - type=int, - default=3, - help="Number of samples to use from NarrativeQA dataset", - ) # Generation arguments parser.add_argument( @@ -321,11 +280,6 @@ def main(args): parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") # Operation arguments - parser.add_argument( - "--verify_output", - action="store_true", - help="Verify that sparse attention outputs match baseline", - ) parser.add_argument( "--export_dir", type=str, diff --git a/examples/llm_sparsity/attention_sparsity/requirements.txt b/examples/llm_sparsity/attention_sparsity/requirements.txt new file mode 100644 index 000000000..a3e0dfa17 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/requirements.txt @@ -0,0 +1,2 @@ +nltk +wonderwords diff --git a/examples/llm_sparsity/weight_sparsity/finetune.py b/examples/llm_sparsity/weight_sparsity/finetune.py index 2feefc0fa..711084668 100644 --- a/examples/llm_sparsity/weight_sparsity/finetune.py +++ b/examples/llm_sparsity/weight_sparsity/finetune.py @@ -1,5 +1,6 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py + +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py - -# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py new file mode 100644 index 000000000..3b616e8e3 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +from .calibrate import calibrate_sparse_attention +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + +__all__ = [ + "DynamicThresholdCalibrator", + "RulerDatasetBuilder", + "calibrate_sparse_attention", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py new file mode 100644 index 000000000..1b8f0e71b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration functions for sparse attention.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +from transformers import AutoTokenizer + +from ..config import CalibrationConfig +from ..conversion import print_sparse_attention_summary +from ..sparse_attention import SparseAttentionModule +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + + +def _extract_tokenizer_from_model(model: nn.Module) -> str: + """Extract tokenizer name/path from model config. + + Args: + model: Model to extract tokenizer from + + Returns: + Tokenizer name or path + + Raises: + ValueError: If tokenizer path cannot be determined from model + """ + # Extract tokenizer path from model config + tokenizer_path = getattr(getattr(model, "config", None), "_name_or_path", None) + + if not tokenizer_path: + raise ValueError("Could not load tokenizer from model.") + + return tokenizer_path + + +def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | None: + """Extract and validate calibration config from sparse_cfg. + + Args: + config: Sparse attention configuration dict + + Returns: + Validated CalibrationConfig instance, or None if calibration is not configured + + Raises: + ValueError: If calibration config has invalid type or contains invalid values + """ + sparse_cfg = config.get("sparse_cfg", {}) + + # Calibration is optional + if "calibration" not in sparse_cfg: + return None + + calib_dict = sparse_cfg["calibration"] + + # Validate calibration is a dict + if not isinstance(calib_dict, dict): + raise ValueError(f"Calibration config must be a dict, got {type(calib_dict).__name__}. 
") + + # Create and validate CalibrationConfig + return CalibrationConfig(**calib_dict) + + +def create_calibration_forward_loop( + calibration_data: list[dict[str, Any]], + tokenizer_name_or_path: str, + batch_size: int = 1, + chunk_size: int = 2048, +) -> Callable: + """Create forward loop for calibration. + + Args: + calibration_data: List of samples with 'input' and 'length' fields + tokenizer_name_or_path: HuggingFace tokenizer path + batch_size: Batch size (currently unused, always 1) + chunk_size: Chunk size for chunked prefill to avoid OOM. Set to -1 to disable. + + Returns: + Forward loop function that takes model as argument + """ + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + def forward_loop(model: nn.Module) -> None: + device = next(model.parameters()).device + + for sample in calibration_data: + inputs = tokenizer( + sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"] + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + input_ids = inputs["input_ids"].to(device) + seq_len = input_ids.shape[1] + + with torch.no_grad(): + if chunk_size > 0 and seq_len > chunk_size: + # Chunked prefill to avoid OOM with long sequences + past_key_values = None + for start_idx in range(0, seq_len, chunk_size): + end_idx = min(start_idx + chunk_size, seq_len) + chunk_input_ids = input_ids[:, start_idx:end_idx] + + outputs = model( + chunk_input_ids, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # Clean up KV cache + del past_key_values + torch.cuda.empty_cache() + else: + # Full prefill without chunking + model(input_ids, use_cache=False) + + return forward_loop + + +def calibrate_sparse_attention( + model: nn.Module, + config: dict[str, Any], + forward_loop: Callable | None = None, +) -> dict[str, Any]: + """Calibrate sparse attention parameters for optimal sparsity. + + Args: + model: Model with sparse attention modules + config: Sparse attention configuration dict + forward_loop: Callable that forwards calibration data through model. + If None, auto-generates RULER dataset. 
+ + Returns: + Dictionary with calibration results + """ + # Extract and validate calibration config + calib_config = _extract_calibration_config(config) + + # Skip calibration if not configured + if calib_config is None: + return {} + + # Generate forward_loop if not provided + if not forward_loop: + tokenizer = _extract_tokenizer_from_model(model) + builder = RulerDatasetBuilder( + samples=calib_config.samples, + max_seqlen=calib_config.max_seqlen, + tokenizer_name_or_path=tokenizer, + num_length_bins=calib_config.num_length_bins, + max_length_filter=int(calib_config.max_seqlen * 1.5), + ) + calibration_data = builder.build_calibration_dataset() + print(f"Generated {len(calibration_data)} calibration samples") + forward_loop = create_calibration_forward_loop( + calibration_data, tokenizer, chunk_size=calib_config.chunk_size + ) + + # Get sparse attention modules + sparse_modules = [ + (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) + ] + + if not sparse_modules: + print("No sparse attention modules found for calibration") + return {} + + print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") + + # Run calibration + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=calib_config.target_sparse_ratio, + threshold_trials=calib_config.threshold_trials, + ) + calibration_result = calibrator.calibrate(model, forward_loop) + + # Print calibration statistics (regardless of success/failure for debugging) + print("\nCalibration complete!") + print_sparse_attention_summary(model) + + if "scale_factor" not in calibration_result: + warnings.warn("Calibration did not produce valid results") + return {} + + # Apply calibrated scale factor to all modules + scale_factor = calibration_result["scale_factor"] + print(f"\nApplying calibrated scale factor={scale_factor:.6f} to {len(sparse_modules)} modules") + + for module_name, module in sparse_modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + + return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}} diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py new file mode 100644 index 000000000..39abdbca8 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Calibration framework for sparse attention methods.""" + +import warnings +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm + +from ..sparse_attention import SparseAttentionModule +from ..stats_manager import SparseAttentionStatsManager + + +class DynamicThresholdCalibrator: + """Dynamic threshold calibrator using length-based linear relationship. + + Implements calibration algorithm: + 1. Find hyperparameter 'a' where threshold λ = a / context_length + 2. Use dataset with different lengths and test multiple thresholds + 3. For each sample, find optimal threshold closest to target sparsity + 4. Use linear regression to fit: threshold = a * (1/length) + """ + + @dataclass + class SampleSparsity: + """Sparsity results for a single calibration sample.""" + + length: int + threshold_sparsities: dict[float, float] + + def __init__( + self, + target_sparse_ratio: float = 0.5, + threshold_trials: list[float] | None = None, + ): + """Initialize dynamic threshold calibrator. + + Args: + target_sparse_ratio: Target sparsity ratio (0.0 to 1.0) + threshold_trials: List of thresholds to try during calibration + + Note: + Calibration only supports prefill phase (seq_len > 1). + Decode phase uses the same calibrated threshold. + """ + self.target_sparse_ratio = target_sparse_ratio + + # Default threshold trials if not provided + self.threshold_trials = threshold_trials or [ + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 5e-2, + 1e-1, + 5e-1, + ] + + # Statistics tracking + self.sparsity_results = [] + + def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: + """Find optimal 'a' parameter for length-based threshold. + + Algorithm: + 1. Test all threshold trials by running forward_loop multiple times + 2. For each sample, find optimal threshold closest to target sparsity + 3. 
Use regression to find 'a' in: threshold = a / length + + Args: + model: The model with sparse attention modules + forward_loop: Callable that takes model and forwards calibration data + """ + # Extract attention modules + attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + if not attention_modules: + raise ValueError("No sparse attention modules found for calibration") + + print("Starting dynamic threshold calibration") + print(f"Target sparsity: {self.target_sparse_ratio}") + print(f"Threshold trials: {len(self.threshold_trials)}") + + # Stage 1: Collect sparsity for all sample-threshold pairs + print("\nStage 1: Collecting sparsity data...") + + # Run first threshold to discover samples and initialize results + self._set_threshold(attention_modules, self.threshold_trials[0]) + self._enable_calibration_mode(attention_modules) + with torch.no_grad(): + forward_loop(model) + per_sample_stats = self._extract_calibration_stats(attention_modules) + self._disable_calibration_mode(attention_modules) + + # Initialize sparsity_results with sample info + self.sparsity_results = [ + self.SampleSparsity( + length=stat["sample_length"], + threshold_sparsities={self.threshold_trials[0]: stat["sparsity"]}, + ) + for stat in per_sample_stats + ] + + # Collect remaining thresholds + for threshold in tqdm(self.threshold_trials[1:], desc="Testing thresholds"): + self._set_threshold(attention_modules, threshold) + self._enable_calibration_mode(attention_modules) + with torch.no_grad(): + forward_loop(model) + per_sample_stats = self._extract_calibration_stats(attention_modules) + self._disable_calibration_mode(attention_modules) + + for sample_idx, sample_stat in enumerate(per_sample_stats): + self.sparsity_results[sample_idx].threshold_sparsities[threshold] = sample_stat[ + "sparsity" + ] + + if not self.sparsity_results: + warnings.warn("No valid sparsity measurements collected during calibration") + return {} + + print(f"Collected statistics for {len(self.sparsity_results)} samples") + + # Stage 2: Find optimal threshold for each sample and compute 'a' + print( + f"\nStage 2: Finding 'a' parameter for target sparsity {self.target_sparse_ratio:.2f}" + ) + + # Find optimal threshold for each sample + optimal_pairs = [] + for sample_result in self.sparsity_results: + # Find threshold closest to target sparsity + best_threshold, achieved_sparsity = min( + sample_result.threshold_sparsities.items(), + key=lambda item: abs(item[1] - self.target_sparse_ratio), + ) + + optimal_pairs.append( + { + "length": sample_result.length, + "optimal_threshold": best_threshold, + "achieved_sparsity": achieved_sparsity, + "target_sparsity": self.target_sparse_ratio, + } + ) + + if not optimal_pairs: + warnings.warn( + f"No optimal threshold pairs found for target sparsity {self.target_sparse_ratio}. " + f"Collected {len(self.sparsity_results)} samples but none achieved target sparsity." 
+ ) + return {} + + # Linear regression: threshold = a * (1/length) + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + # X = 1/length, Y = threshold + x = 1.0 / lengths + y = thresholds + + # Least squares: scale_factor = sum(x*y) / sum(x^2) + scale_factor = np.sum(x * y) / np.sum(x**2) + + # Calculate statistics + scale_factors_per_sample = y * lengths + scale_factor_std = np.std(scale_factors_per_sample) + + # Calculate R-squared for quality metric + y_pred = scale_factor * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Calculate average achieved sparsity + avg_achieved_sparsity = np.mean([p["achieved_sparsity"] for p in optimal_pairs]) + + print("\nCalibration Results:") + print(f" Threshold scale factor: {scale_factor:.6f} (std: {scale_factor_std:.6f})") + print(f" R-squared: {r_squared:.4f}") + print( + f" Average achieved sparsity: {avg_achieved_sparsity:.2%} (target: {self.target_sparse_ratio:.2%})" + ) + print(f"\nExample thresholds with λ = {scale_factor:.6f} / length:") + for length in [1024, 2048, 4096, 8192, 16384]: + print(f" Length {length:5d}: threshold = {scale_factor / length:.2e}") + + # Apply the calibrated scale factor to modules + self._apply_length_based_calibration(attention_modules, scale_factor) + + return { + "scale_factor": scale_factor, + "scale_factor_std": scale_factor_std, + "r_squared": r_squared, + "num_samples": len(optimal_pairs), + "target_sparsity": self.target_sparse_ratio, + "avg_achieved_sparsity": avg_achieved_sparsity, + "optimal_pairs": optimal_pairs, + "calibration_type": "length_based_dynamic", + } + + def _apply_length_based_calibration(self, modules: list[nn.Module], scale_factor: float): + """Apply calibrated threshold scale factor to modules. + + Args: + modules: List of attention modules + scale_factor: Calibrated scale factor for λ = scale_factor / length + """ + for module in modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + + def _enable_calibration_mode(self, modules: list[nn.Module]): + """Enable calibration mode on sparse attention modules.""" + for idx, module in enumerate(modules): + # Create stats manager if needed + if not module._stats_manager: + module._stats_manager = SparseAttentionStatsManager( + module_name=f"sparse_attn_{idx}", enabled=True + ) + else: + # Re-enable if disabled + module._stats_manager.enabled = True + + # Enable calibration mode with fresh stats + module._stats_manager.set_calibration_mode(enabled=True, reset_history=True) + module._sparse_method_instance.set_calibration_mode(True) + + def _disable_calibration_mode(self, modules: list[nn.Module]): + """Disable calibration mode (but keep stats enabled if collect_stats=True).""" + for module in modules: + if module._stats_manager: + module._stats_manager.set_calibration_mode(enabled=False) + + module._sparse_method_instance.set_calibration_mode(False) + + def _extract_calibration_stats(self, modules: list[nn.Module]) -> list[dict]: + """Extract per-sample calibration statistics from modules. 
+ + Args: + modules: List of attention modules + + Returns: + List of per-sample statistics across all modules + """ + # Collect from all stats managers + all_per_sample_stats = [] + + for module in modules: + # Skip modules without stats manager + if not hasattr(module, "_stats_manager") or module._stats_manager is None: + continue + + manager_stats = module._stats_manager.get_calibration_stats() + if manager_stats: + all_per_sample_stats.append(manager_stats) + + if not all_per_sample_stats: + return [] + + # Aggregate across modules by sample index + num_samples = len(all_per_sample_stats[0]) + aggregated_stats = [] + + for sample_idx in range(num_samples): + sparsities = [] + sample_length = 0 + + for module_stats in all_per_sample_stats: + if sample_idx < len(module_stats): + sample_stat = module_stats[sample_idx] + sparsities.append(sample_stat.get("sparsity", 0.0)) + if not sample_length and "sample_length" in sample_stat: + sample_length = sample_stat["sample_length"] + + avg_sparsity = float(np.mean(sparsities)) if sparsities else 0.0 + + aggregated_stats.append( + { + "sparsity": avg_sparsity, + "sample_length": sample_length, + } + ) + + return aggregated_stats + + def _set_threshold(self, modules: list[nn.Module], threshold: float): + """Set threshold on sparse attention modules.""" + for module in modules: + module._sparse_method_instance.threshold = threshold diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py new file mode 100644 index 000000000..7603b4e1d --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -0,0 +1,546 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RULER dataset builder for sparse attention calibration.""" + +import random +import string +from dataclasses import dataclass +from typing import Any + +from tqdm import tqdm +from transformers import AutoTokenizer + +from . import ruler_utils + + +def _generate_target_lengths( + max_seqlen: int, num_length_bins: int = 4, min_seqlen: int = 1024 +) -> list[int]: + """Generate target lengths as descending powers of 2. 
+ + Args: + max_seqlen: Maximum sequence length + num_length_bins: Maximum number of length bins to generate + min_seqlen: Minimum sequence length threshold + + Returns: + List of target lengths in descending order + + Examples: + >>> _generate_target_lengths(32768, 4) + [32768, 16384, 8192, 4096] + >>> _generate_target_lengths(2048, 4) + [2048, 1024] + """ + target_lengths = [] + current = max_seqlen + + for _ in range(num_length_bins): + if current < min_seqlen: + break + target_lengths.append(current) + current = current // 2 + + return target_lengths + + +@dataclass +class RulerTask: + """Configuration for a RULER task.""" + + name: str + task_type: str # niah, variable_tracking, freq_words_extraction, qa + tokens_to_generate: int + template: str + answer_prefix: str + args: dict[str, Any] + + +# Task configurations based on RULER benchmark +RULER_TASKS = { + "niah_multikey_2": RulerTask( + name="niah_multikey_2", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "words", + "type_needle_v": "numbers", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "niah_multikey_3": RulerTask( + name="niah_multikey_3", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "uuids", + "type_needle_v": "uuids", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "vt": RulerTask( + name="vt", + task_type="variable_tracking", + tokens_to_generate=30, + template=( + "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n" + "{context}\n" + "Question: Find all variables that are assigned the value {query} in the text above." + ), + answer_prefix=( + " Answer: According to the chain(s) of variable assignment in the text above, " + "{num_v} variables are assgined the value {query}, they are: " + ), + args={"num_chains": 1, "num_hops": 4}, + ), + "fwe": RulerTask( + name="fwe", + task_type="freq_words_extraction", + tokens_to_generate=50, + template=( + "Read the following coded text and track the frequency of each coded word. " + "Find the three most frequently appeared coded words. {context}\n" + "Question: Do not provide any explanation. Please ignore the dots '....'. " + "What are the three most frequently appeared words in the above coded text?" + ), + answer_prefix=( + " Answer: According to the coded text above, " + "the three most frequently appeared words are:" + ), + args={"alpha": 2.0}, + ), + "qa_1": RulerTask( + name="qa_1", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. 
" + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "squad"}, + ), + "qa_2": RulerTask( + name="qa_2", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "hotpotqa"}, + ), +} + + +class RulerDatasetBuilder: + """Builder for RULER calibration datasets.""" + + def __init__( + self, + samples: int, + max_seqlen: int, + tokenizer_name_or_path: str | object, + num_length_bins: int = 4, + max_length_filter: int = 65536, + seed: int = 42, + ): + """Initialize RULER dataset builder. + + Args: + samples: Total number of samples to generate (distributed evenly across length bins) + max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2) + tokenizer_name_or_path: HuggingFace tokenizer path or tokenizer object + seed: Random seed for reproducibility + num_length_bins: Number of length bins to generate (default: 4) + max_length_filter: Maximum sequence length to keep (default: 65536) + + Note: + Length bins are auto-generated as descending powers of 2: + [max_seqlen, max_seqlen/2, max_seqlen/4, ...] + Generation stops when num_length_bins is reached or length < 1024. + Subtasks are set to all the difficult tasks defined in RULER_TASKS. + """ + # Validate inputs + if samples <= 0: + raise ValueError(f"samples must be positive, got {samples}") + if max_seqlen < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {max_seqlen}") + + # Store parameters + self.total_samples = samples + self.max_seqlen = max_seqlen + self.num_length_bins = num_length_bins + self.subtasks = list(RULER_TASKS.keys()) + self.tokenizer_name_or_path = tokenizer_name_or_path + self.seed = seed + self.max_length_filter = max_length_filter + + # Generate target lengths and validate + self.target_lengths = _generate_target_lengths(max_seqlen, num_length_bins, min_seqlen=1024) + if not self.target_lengths: + raise ValueError(f"No valid target lengths generated from max_seqlen={max_seqlen}") + + # Distribute samples evenly across lengths + self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths) + + # Initialize tokenizer + if isinstance(tokenizer_name_or_path, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + else: + self.tokenizer = tokenizer_name_or_path + random.seed(seed) + + def build_calibration_dataset(self) -> list[dict[str, Any]]: + """Build the complete calibration dataset. 
+ + Returns: + List of calibration samples with 'input' and 'length' fields + """ + all_samples = [] + + # Generate calibration samples + for num_samples, target_length in tqdm( + zip(self.samples_per_length, self.target_lengths), + desc="Generating RULER calibration samples", + total=len(self.target_lengths), + ): + samples_per_task = max(num_samples // len(self.subtasks), 1) + + # Generate equal samples for each task + for task_name in self.subtasks: + for sample_idx in range(samples_per_task): + sample = self._generate_sample(task_name, target_length, sample_idx) + if sample and sample["length"] <= self.max_length_filter: + all_samples.append(sample) + + random.shuffle(all_samples) + return all_samples + + def _generate_sample( + self, task_name: str, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a single RULER sample. + + Args: + task_name: Name of the RULER task + target_length: Target sequence length in tokens + sample_idx: Index of the sample (for uniqueness) + + Returns: + Dict with 'input', 'length', and metadata fields + """ + task = RULER_TASKS[task_name] + + if task.task_type == "niah": + return self._generate_niah_sample(task, target_length, sample_idx) + elif task.task_type == "variable_tracking": + return self._generate_vt_sample(task, target_length, sample_idx) + elif task.task_type == "freq_words_extraction": + return self._generate_fwe_sample(task, target_length, sample_idx) + elif task.task_type == "qa": + return self._generate_qa_sample(task, target_length, sample_idx) + else: + raise ValueError(f"Unknown task type: {task.task_type}") + + def _generate_niah_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a needle-in-haystack sample.""" + args = task.args + + # Find optimal haystack size for target length + optimal_haystack = ruler_utils.find_optimal_haystack_size( + tokenizer=self.tokenizer, + max_seq_length=target_length, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), + ) + + # Generate sample using official RULER implementation + sample = ruler_utils.generate_niah_sample( + num_haystack=optimal_haystack, + tokenizer=self.tokenizer, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), + random_seed=self.seed + sample_idx, + ) + + # Add task metadata + sample["task"] = task.name + sample["target_length"] = target_length + sample["sample_idx"] = sample_idx + + return sample + + def _generate_vt_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a variable tracking sample.""" + args = task.args + num_chains = args["num_chains"] + num_hops = args["num_hops"] + + # Generate variable chains + variables = [] + chains = [] + for _ in range(num_chains): + chain = [self._generate_random_variable() for _ in range(num_hops + 1)] + variables.extend(chain) + 
chains.append(chain) + + # Generate assignments + assignments = [ + f"VAR {chain[i]} = {chain[i + 1]}" for chain in chains for i in range(len(chain) - 1) + ] + + # Create context with padding + context = self._pad_context_with_text( + "\n".join(assignments), target_length, "variable tracking context" + ) + + # Select a query value + query_value = random.choice([chain[-1] for chain in chains]) + + # Format template + template = task.template.format(context=context, query=query_value) + + # Count variables with the query value + num_v = sum(1 for chain in chains if chain[-1] == query_value) + + # Add answer prefix + full_input = template + task.answer_prefix.format(num_v=num_v, query=query_value) + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_fwe_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a frequency word extraction sample.""" + # Generate coded words with frequencies + num_unique_words = 50 + coded_words = [self._generate_coded_word() for _ in range(num_unique_words)] + + # Assign frequencies (make top 3 clearly more frequent) + frequencies = {} + for i, word in enumerate(coded_words): + if i < 3: + frequencies[word] = random.randint(20, 30) # High frequency + else: + frequencies[word] = random.randint(1, 10) # Low frequency + + # Generate the coded text + word_list = [] + for word, freq in frequencies.items(): + word_list.extend([word] * freq) + random.shuffle(word_list) + + # Add dots for separation + coded_text = " .... ".join(word_list) + + # Pad to target length + context = self._pad_context_with_text(coded_text, target_length, "coded text padding") + + # Format template + template = task.template.format(context=context) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_qa_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a QA sample.""" + # Generate synthetic documents + num_docs = 5 + documents = [] + + # Create a simple QA pair + answer = self._generate_random_phrase() + question = f"What is the special code mentioned in document {random.randint(1, num_docs)}?" + + for i in range(num_docs): + doc_text = self._generate_document_text(200) # Base document + if i == 2: # Insert answer in one document + doc_text += f" The special code is {answer}. 
" + documents.append(f"Document {i + 1}:\n{doc_text}\n") + + # Combine documents + context_base = "\n".join(documents) + + # Pad to target length + context = self._pad_context_with_text( + context_base, target_length, "additional document text" + ) + + # Format template + template = task.template.format(context=context, query=question) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _pad_context_with_text( + self, base_context: str, target_length: int, padding_type: str + ) -> str: + """Pad context to approach target length.""" + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + while len(tokens) < target_length * 0.7: # Leave room for template + if padding_type == "variable tracking context": + padding = ( + f" VAR {self._generate_random_variable()} = {self._generate_random_variable()}." + ) + elif padding_type == "coded text padding": + padding = f" .... {self._generate_coded_word()} .... " + else: + padding = " " + self._generate_essay_text(50) + + base_context += padding + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + if len(tokens) > target_length * 0.9: + # Truncate if too long + base_context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) + + return base_context + + def _generate_random_word(self) -> str: + """Generate a random word.""" + return "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 10))) + + def _generate_random_variable(self) -> str: + """Generate a random variable name.""" + return "".join(random.choices(string.ascii_uppercase, k=1)) + "".join( + random.choices(string.digits, k=3) + ) + + def _generate_coded_word(self) -> str: + """Generate a coded word.""" + return "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) + + def _generate_random_phrase(self) -> str: + """Generate a random phrase.""" + words = [self._generate_random_word() for _ in range(random.randint(2, 4))] + return " ".join(words) + + def _generate_essay_text(self, num_words: int) -> str: + """Generate essay-like text.""" + topics = [ + "technology", + "science", + "nature", + "history", + "culture", + "education", + "health", + "economics", + "politics", + "philosophy", + "art", + "literature", + ] + + sentences = [] + words_generated = 0 + + while words_generated < num_words: + topic = random.choice(topics) + word1 = self._generate_random_word() + word2 = self._generate_random_word() + word3 = self._generate_random_word() + sentence = f"The {topic} of {word1} is {word2} and {word3}. " + sentences.append(sentence) + words_generated += len(sentence.split()) + + return " ".join(sentences) + + def _generate_document_text(self, num_words: int) -> str: + """Generate document-like text.""" + return self._generate_essay_text(num_words) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh new file mode 100755 index 000000000..54797f2a5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Download RULER calibration data for attention sparsity. + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DATA_DIR="${SCRIPT_DIR}/data" +ESSAYS_DIR="${DATA_DIR}/essays" +URLS_FILE="${DATA_DIR}/PaulGrahamEssays_URLs.txt" +URLS_URL="https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt" + +mkdir -p "${ESSAYS_DIR}" + +# Download URL list if not exists +if [ ! -f "${URLS_FILE}" ]; then + echo "Downloading URL list..." + curl -fsSL "${URLS_URL}" -o "${URLS_FILE}" +fi + +# Download essays from GitHub URLs +echo -n "Downloading essays" +count=0 +while IFS= read -r url || [ -n "$url" ]; do + if [[ "${url}" == https://github.com*.txt ]]; then + filename=$(basename "${url}") + filepath="${ESSAYS_DIR}/${filename}" + if [ ! -f "${filepath}" ]; then + raw_url="${url/github.com/raw.githubusercontent.com}" + raw_url="${raw_url/\/raw\//\/}" + curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "." + count=$((count + 1)) + fi + fi +done < "${URLS_FILE}" +echo " done" + +echo "Downloaded ${count} essays to ${ESSAYS_DIR}" diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py new file mode 100644 index 000000000..70d4da81b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied and Adapted from https://github.com/NVIDIA/RULER +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Official RULER dataset generation utilities adapted for Model Optimizer. 
+ +This module contains core logic from the RULER benchmark (https://github.com/NVIDIA/RULER) +adapted to work as a library for calibration purposes. The generation logic closely follows +the official RULER implementation to ensure dataset consistency. + +Key adaptations from official RULER: +- Converted from CLI scripts to library functions +- Works with HuggingFace tokenizers directly +- Removed file I/O, returns data structures +- Simplified for calibration use case (primarily NIAH tasks) +""" + +import logging +import random +import re +import uuid +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +# Needle/Haystack template from official RULER +NEEDLE_TEMPLATE = "One of the special magic {type_needle_v} for {key} is: {value}." + +# Depth positions for needle insertion (from official RULER) +DEPTHS = [ + 0, + 2, + 5, + 7, + 10, + 12, + 15, + 18, + 20, + 23, + 25, + 28, + 30, + 33, + 35, + 38, + 40, + 43, + 45, + 48, + 50, + 53, + 55, + 58, + 60, + 62, + 65, + 67, + 70, + 72, + 75, + 77, + 80, + 82, + 85, + 87, + 90, + 92, + 95, + 97, + 100, +] + +# Data directory for RULER calibration files (downloaded via download_ruler_data.sh) +DATA_DIR = Path(__file__).parent / "data" +RULER_URLS_FILE = DATA_DIR / "PaulGrahamEssays_URLs.txt" +ESSAYS_DIR = DATA_DIR / "essays" + + +def _get_data_dir() -> Path: + """Get data directory for RULER data. + + Returns: + Path to data directory under calibration/ (created if doesn't exist) + """ + data_dir = Path(__file__).parent / "data" + data_dir.mkdir(parents=True, exist_ok=True) + return data_dir + + +def _load_paul_graham_essays_from_files() -> str: + """Load Paul Graham essays from local files. + + Reads essay .txt files from the data/essays directory. + Files must be downloaded first using download_ruler_data.sh. + + Returns: + Combined essay text + + Raises: + RuntimeError: If essays directory doesn't exist or is empty + """ + if not ESSAYS_DIR.exists(): + raise RuntimeError( + f"Essays directory not found at {ESSAYS_DIR}.\n" + "Please run the download script first:\n" + " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + ) + + essay_files = list(ESSAYS_DIR.glob("*.txt")) + if not essay_files: + raise RuntimeError( + f"No essay files found in {ESSAYS_DIR}.\n" + "Please run the download script first:\n" + " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + ) + + logger.info(f"Loading {len(essay_files)} Paul Graham essays from local files...") + + all_essays = [] + for filepath in essay_files: + text = filepath.read_text() + all_essays.append(text) + + combined_text = " ".join(all_essays) + logger.info(f"Loaded {len(all_essays)} essays successfully") + + return combined_text + + +def _load_paul_graham_essays() -> str: + """Load Paul Graham essays from local files. + + Essay files must be downloaded first using download_ruler_data.sh. + + Returns: + Essay text as string + """ + essay_text = _load_paul_graham_essays_from_files() + return re.sub(r"\s+", " ", essay_text) + + +def _load_word_lists(): + """Load word lists for random word generation. 
+ + Returns: + List of words (adj-noun combinations) + """ + import wonderwords + + # Load wonderwords lists (same as official RULER) + nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") + adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") + words = [f"{adj}-{noun}" for adj in adjs for noun in nouns] + words = sorted(set(words)) + return words + + +# Global word list (loaded once) +_WORD_LIST = None + + +def generate_random_number(num_digits=7) -> str: + """Generate random number (from official RULER).""" + lower_bound = 10 ** (num_digits - 1) + upper_bound = 10**num_digits - 1 + return str(random.randint(lower_bound, upper_bound)) + + +def generate_random_word() -> str: + """Generate random word (from official RULER).""" + global _WORD_LIST + if _WORD_LIST is None: + _WORD_LIST = _load_word_lists() + return random.choice(_WORD_LIST) + + +def generate_random_uuid() -> str: + """Generate random UUID (from official RULER).""" + return str(uuid.UUID(int=random.getrandbits(128), version=4)) + + +def generate_random(type_needle: str) -> str: + """Generate random needle value based on type (from official RULER). + + Args: + type_needle: Type of needle ('numbers', 'words', 'uuids') + + Returns: + Random value as string + """ + if type_needle == "numbers": + return generate_random_number() + elif type_needle == "words": + return generate_random_word() + elif type_needle == "uuids": + return generate_random_uuid() + else: + raise ValueError(f"Unknown needle type: {type_needle}") + + +def generate_niah_sample( + num_haystack: int, + tokenizer, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + type_needle_k: str = "words", + type_needle_v: str = "numbers", + num_needle_k: int = 1, + num_needle_v: int = 1, + num_needle_q: int = 1, + random_seed: int = 42, +) -> dict[str, Any]: + """Generate a single NIAH (Needle in a Haystack) sample. + + This function implements the core generation logic from official RULER's niah.py, + adapted to work as a library function. 
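+
+    A hedged usage sketch (the template, answer prefix, and tokenizer below are
+    illustrative, not fixed values; ``type_haystack="noise"`` is chosen so the
+    example does not depend on the downloaded essay corpus):
+
+    .. code-block:: python
+
+        from transformers import AutoTokenizer
+
+        sample = generate_niah_sample(
+            num_haystack=50,
+            tokenizer=AutoTokenizer.from_pretrained("gpt2"),
+            template=(
+                "Some special magic {type_needle_v} are hidden in the text below. "
+                "{context} "
+                "What are all the special magic {type_needle_v} for {query}?"
+            ),
+            answer_prefix=" The special magic {type_needle_v} for {query} are:",
+            type_haystack="noise",
+        )
+        # 'outputs' holds the needle values; 'length' already includes tokens_to_generate.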
+ + Args: + num_haystack: Number of haystack items/words + tokenizer: HuggingFace tokenizer (AutoTokenizer instance) + template: NIAH question template + answer_prefix: Answer prefix template + tokens_to_generate: Expected number of generation tokens + type_haystack: Type of haystack ('essay', 'noise', 'needle') + type_needle_k: Type of needle keys ('numbers', 'words', 'uuids') + type_needle_v: Type of needle values ('numbers', 'words', 'uuids') + num_needle_k: Number of needle keys + num_needle_v: Number of needle values per key + num_needle_q: Number of needles to query + random_seed: Random seed for this sample + + Returns: + Dictionary with 'input', 'outputs', 'length' keys + """ + import nltk + from nltk.tokenize import sent_tokenize + + try: + nltk.data.find("tokenizers/punkt") + except LookupError: + nltk.download("punkt", quiet=True) + nltk.download("punkt_tab", quiet=True) + + if random_seed is not None: + random.seed(random_seed) + + # Ensure num_needle_k >= num_needle_q + num_needle_k = max(num_needle_k, num_needle_q) + + # Generate needles (keys and values) + keys, values, needles = [], [], [] + for _ in range(num_needle_k): + keys.append(generate_random(type_needle_k)) + value = [] + for _ in range(num_needle_v): + value.append(generate_random(type_needle_v)) + needles.append( + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=keys[-1], + value=value[-1], + ) + ) + values.append(value) + + random.shuffle(needles) + + # Generate context based on haystack type + if type_haystack == "essay": + # Load essay corpus + essay_text = _load_paul_graham_essays() + haystack = essay_text.split(" ") + + # Create text from haystack + if num_haystack <= len(haystack): + text = " ".join(haystack[:num_haystack]) + else: + # Repeat haystack as needed + repeats = (num_haystack + len(haystack) - 1) // len(haystack) + text = " ".join((haystack * repeats)[:num_haystack]) + + # Insert needles at various depths + document_sents = sent_tokenize(text.strip()) + insertion_positions = [ + 0, + *sorted( + int(len(document_sents) * (depth / 100)) + for depth in random.sample(DEPTHS, len(needles)) + ), + len(document_sents), + ] + + document_sents_list = [] + for i in range(1, len(insertion_positions)): + last_pos = insertion_positions[i - 1] + next_pos = insertion_positions[i] + document_sents_list.append(" ".join(document_sents[last_pos:next_pos])) + if i - 1 < len(needles): + document_sents_list.append(needles[i - 1]) + + context = " ".join(document_sents_list) + + if type_haystack == "noise": + haystack_sent = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 
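+        # Repeat the filler sentence num_haystack times, then splice the needles
+        # back in at randomly chosen positions so they are scattered through the noise.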
+ sentences = [haystack_sent] * num_haystack + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + elif type_haystack == "needle": + sentences = [ + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=generate_random(type_needle_k), + value=generate_random(type_needle_v), + ) + for _ in range(num_haystack) + ] + + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + # Generate query and answer + indices = random.sample(range(num_needle_k), num_needle_q) + queries = [keys[i] for i in indices] + answers = [a for i in indices for a in values[i]] + query = ", ".join(queries[:-1]) + ", and " + queries[-1] if len(queries) > 1 else queries[0] + + # Format template (adjust for singular vs plural) + type_needle_v_display = type_needle_v + formatted_template = template + if num_needle_q * num_needle_v == 1: + formatted_template = formatted_template.replace("Some", "A") + formatted_template = formatted_template.replace("are all", "is") + formatted_template = formatted_template.replace("are", "is") + formatted_template = formatted_template.replace("answers", "answer") + type_needle_v_display = type_needle_v[:-1] # remove "s" + + input_text = formatted_template.format( + type_needle_v=type_needle_v_display, + context=context, + query=query, + ) + + # Add answer prefix + formatted_answer_prefix = answer_prefix.format( + type_needle_v=type_needle_v_display, + query=query, + ) + input_text = input_text + formatted_answer_prefix + + # Calculate actual length + if hasattr(tokenizer, "encode"): + # HuggingFace tokenizer + tokens = tokenizer.encode(input_text, add_special_tokens=False) + length = len(tokens) + tokens_to_generate + else: + # Fallback + length = len(input_text.split()) + tokens_to_generate + + return { + "input": input_text, + "outputs": answers, + "length": length, + } + + +def find_optimal_haystack_size( + tokenizer, + max_seq_length: int, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + **kwargs, +) -> int: + """Find optimal haystack size using binary search (from official RULER). 
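+
+    In short, one probe sample is generated to estimate tokens per haystack item, then
+    a binary search finds the largest ``num_haystack`` whose generated sample still fits
+    within ``max_seq_length`` tokens. A hedged usage sketch (tokenizer and template are
+    illustrative):
+
+    .. code-block:: python
+
+        from transformers import AutoTokenizer
+
+        n = find_optimal_haystack_size(
+            tokenizer=AutoTokenizer.from_pretrained("gpt2"),
+            max_seq_length=4096,
+            template="{context} What are the special magic {type_needle_v} for {query}?",
+            answer_prefix=" Answer:",
+            type_haystack="noise",
+        )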
+ + Args: + tokenizer: HuggingFace tokenizer + max_seq_length: Maximum sequence length + tokens_to_generate: Expected generation tokens + type_haystack: Type of haystack + template: NIAH question template + answer_prefix: Answer prefix template + **kwargs: Additional arguments for generate_niah_sample + + Returns: + Optimal number of haystack items + """ + # Determine incremental step based on haystack type + if type_haystack == "essay": + incremental = 500 + elif type_haystack in ["noise", "needle"]: + incremental = 25 + else: + incremental = 100 + + if max_seq_length < 4096 and type_haystack != "essay": + incremental = 5 + + # Estimate tokens per haystack item + sample = generate_niah_sample( + incremental, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + + if hasattr(tokenizer, "encode"): + sample_tokens = len(tokenizer.encode(sample["input"], add_special_tokens=False)) + else: + sample_tokens = len(sample["input"].split()) + + tokens_per_haystack = sample_tokens / incremental + estimated_max = int((max_seq_length / tokens_per_haystack) * 3) + + # Binary search for optimal size + lower_bound = incremental + upper_bound = max(estimated_max, incremental * 2) + optimal_num_haystack = None + + logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") + logger.info(f"Binary search bounds: {lower_bound} to {upper_bound}") + + while lower_bound <= upper_bound: + mid = (lower_bound + upper_bound) // 2 + sample = generate_niah_sample( + mid, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + total_tokens = sample["length"] + + logger.debug(f"Testing haystack size: {mid}, tokens: {total_tokens}/{max_seq_length}") + + if total_tokens <= max_seq_length: + optimal_num_haystack = mid + lower_bound = mid + 1 + else: + upper_bound = mid - 1 + + final_size = optimal_num_haystack if optimal_num_haystack is not None else incremental + logger.info(f"Optimal haystack size: {final_size}") + + return final_size diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index e72dacc94..8271dd4a2 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -77,6 +77,12 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass for monitoring.", + ) + is_causal: bool = ModeloptField( default=True, title="Causal attention flag.", @@ -87,16 +93,6 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) - calibration: dict | None = ModeloptField( - default=None, - title="Calibration configuration", - description=( - "Calibration settings for this pattern. " - "If provided, enables automatic threshold calibration. " - "Only one pattern should have calibration enabled." 
- ), - ) - @field_validator("method") @classmethod def validate_method(cls, v): @@ -150,24 +146,113 @@ def validate_threshold(cls, v): return v -# Pre-defined Sparse Attention Configuration -# Default configuration with block-wise sparsity optimized for Flash Attention -SKIP_SOFTMAX_DEFAULT = { - "sparse_cfg": { - "*attn*": { - "method": "flash_skip_softmax", - "threshold": { - "prefill": 1e-3, # More aggressive during prefill - "decode": 1e-4, # Conservative during decode - }, - "br": 128, # Flash Attention block rows - "bc": 128, # Flash Attention block columns - "backend": "pytorch", # Only pytorch backend supported - "enable": True, - }, - "default": {"enable": False}, - }, -} +class CalibrationConfig(ModeloptBaseConfig): + """Configuration for automatic threshold calibration using RULER dataset. + + Calibration learns a dynamic threshold λ = scale_factor / sequence_length that + achieves target sparsity. Only supports prefill phase (seq_len > 1). + """ + + target_sparse_ratio: float = ModeloptField( + default=0.5, + title="Target sparsity ratio", + description="Target ratio of sparse attention blocks (0.0 to 1.0).", + ) + + samples: int = ModeloptField( + default=24, + title="Calibration samples", + description="Total number of RULER samples for calibration (distributed across length bins).", + ) + + max_seqlen: int = ModeloptField( + default=32768, + title="Maximum sequence length", + description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", + ) + + num_length_bins: int = ModeloptField( + default=4, + title="Number of length bins", + description="Number of length bins to generate (hidden parameter, default: 4).", + ) + + chunk_size: int = ModeloptField( + default=2048, + title="Chunk size for prefill", + description=( + "Chunk size for chunked prefill to avoid OOM with long sequences. " + "When sequence length exceeds chunk_size, prefill is done in chunks using KV cache. " + "Set to -1 to disable chunking (full prefill)." + ), + ) + + threshold_trials: list[float] | None = ModeloptField( + default=None, + title="Threshold trials", + description=( + "List of threshold values to test during calibration. 
" + "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" + ), + ) + + @field_validator("threshold_trials") + @classmethod + def validate_threshold_trials(cls, v): + """Validate threshold_trials are in valid range.""" + if v is not None: + if not isinstance(v, list): + raise ValueError(f"threshold_trials must be a list, got {type(v)}") + if len(v) == 0: + raise ValueError("threshold_trials must not be empty") + for threshold in v: + if not isinstance(threshold, (int, float)): + raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") + if threshold <= 0 or threshold >= 1: + raise ValueError( + f"All threshold_trials must be in range (0, 1), got {threshold}" + ) + return v + + @field_validator("target_sparse_ratio") + @classmethod + def validate_target_sparse_ratio(cls, v): + """Validate target sparsity ratio is between 0 and 1.""" + if not 0.0 <= v <= 1.0: + raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") + return v + + @field_validator("samples") + @classmethod + def validate_samples(cls, v): + """Validate samples is positive.""" + if v <= 0: + raise ValueError(f"samples must be positive, got {v}") + return v + + @field_validator("max_seqlen") + @classmethod + def validate_max_seqlen(cls, v): + """Validate max_seqlen is at least 1024.""" + if v < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {v}") + return v + + @field_validator("num_length_bins") + @classmethod + def validate_num_length_bins(cls, v): + """Validate num_length_bins is positive.""" + if v <= 0: + raise ValueError(f"num_length_bins must be positive, got {v}") + return v + + @field_validator("chunk_size") + @classmethod + def validate_chunk_size(cls, v): + """Validate chunk_size is positive or -1 (disabled).""" + if v != -1 and v <= 0: + raise ValueError(f"chunk_size must be positive or -1 (disabled), got {v}") + return v class SparseAttentionConfig(ModeloptBaseConfig): @@ -184,8 +269,9 @@ class SparseAttentionConfig(ModeloptBaseConfig): "default": {"enable": False}, }, title="Sparse attention configuration", - description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, " - "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.", + description="Pattern-based configuration for sparse attention. 
Keys are patterns to match module names " + "(or 'calibration' for global calibration settings), values are configuration dicts with parameters like " + "'threshold', 'enable', etc.", validate_default=True, ) @@ -198,15 +284,17 @@ class SparseAttentionConfig(ModeloptBaseConfig): class FlashSkipSoftmaxConfig(SparseAttentionConfig): """Configuration for Flash Attention-aware softmax skip sparse attention.""" + # Override sparse_cfg with flash_skip_softmax specific defaults # Override sparse_cfg with flash_skip_softmax specific defaults sparse_cfg: SparseAttentionCfgType = ModeloptField( default={ "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "threshold": {"prefill": 1e-3, "decode": 1e-5}, "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, # Enable statistics collection "enable": True, }, "default": {"enable": False}, @@ -218,8 +306,55 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): ) +# Pre-defined Sparse Attention Configuration +# Default configuration with block-wise sparsity optimized for Flash Attention +SKIP_SOFTMAX_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": { + "prefill": 1e-3, # More aggressive during prefill + "decode": 1e-4, # Conservative during decode + }, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +# Configuration with RULER calibration +# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# The calibrated threshold adapts to sequence length for optimal sparsity +SKIP_SOFTMAX_CALIB = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.75, + "samples": 16, + "max_seqlen": 16384, + }, + "*attn*": { + "method": "flash_skip_softmax", + "br": 128, + "bc": 128, + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + __all__ = [ + "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "CalibrationConfig", + "FlashSkipSoftmaxConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", "SparseAttentionCfgType", diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index ad137e9ee..aa3eb7c29 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -226,6 +226,8 @@ def update_sparse_attention_metadata( if isinstance(module, SparseAttentionModule): module_name = get_unwrapped_name(name, model) + # Save the method configuration that was used + # _method_config already contains the validated config dict # Save the method configuration that was used # _method_config already contains the validated config dict module_state = { @@ -299,3 +301,44 @@ def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Cal if matched: module.enable() + + +def _format_threshold(info: dict) -> str: + """Format threshold info for display.""" + t = info.get("type") + if t == "dynamic": + return f"λ={info.get('scale_factor', 0):.2f}" + if t == "static": + v = info.get("value") + if isinstance(v, dict): + return f"threshold={v}" + return f"threshold={v:.2e}" if isinstance(v, float) else 
f"threshold={v}" + return "threshold=N/A" + + +def print_sparse_attention_summary(model: nn.Module): + """Print summary of sparse attention modules in the model. + + Args: + model: Model with sparse attention applied + """ + sparse_modules = [ + (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) + ] + + if not sparse_modules: + print("No sparse attention modules found") + return + + enabled = sum(1 for _, m in sparse_modules if m.is_enabled) + print(f"Sparse attention: {enabled}/{len(sparse_modules)} modules enabled") + + # Group by (method, threshold) + groups: dict[tuple[str, str], int] = {} + for _, module in sparse_modules: + method = getattr(module, "_method", "unknown") + threshold = _format_threshold(module.get_threshold_info()) + groups[(method, threshold)] = groups.get((method, threshold), 0) + 1 + + for (method, threshold), count in sorted(groups.items()): + print(f" {method}: {count} layers, {threshold}") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 8801bafb0..0e4512c98 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -20,6 +20,7 @@ """ import math +from typing import Any import numpy as np import torch @@ -44,7 +45,7 @@ def __init__(self, method_config: dict | None = None): """ config = method_config or {} - # Extract configuration (defaults handled by Pydantic) + # Extract configuration self.threshold_config = config["threshold"] self.br = config["br"] self.bc = config["bc"] @@ -52,9 +53,11 @@ def __init__(self, method_config: dict | None = None): self.is_causal = config["is_causal"] # Optional parameters not in Pydantic config - self.enable_correction_factor = config.get("enable_correction_factor", True) self.phase = config.get("phase", None) + # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold + self._calibration_mode = False + # Initialize threshold if isinstance(self.threshold_config, dict): self.threshold = self.threshold_config.get( @@ -63,6 +66,10 @@ def __init__(self, method_config: dict | None = None): else: self.threshold = self.threshold_config + def set_calibration_mode(self, enabled: bool): + """Set calibration mode to prevent _update_threshold from modifying the threshold.""" + self._calibration_mode = enabled + def _update_threshold(self, phase: str): """Update threshold based on phase.""" if isinstance(self.threshold_config, dict): @@ -184,18 +191,15 @@ def calc_correction_factor_and_p( element_mask = element_mask[:, :, :seq_q, :seq_k] # Step 8: Calculate sparsity statistics - # Count kept blocks (averaged across batch and heads) - kept_blocks = block_mask.sum().item() / (batch_size * num_heads) - - # Total valid blocks (lower triangle only for causal attention) - # Note: Causal mask pre-applied by attention module, so block_mask naturally - # has zeros in upper triangle. We only count lower triangle for denominator. 
- total_blocks = ( - num_block_rows * (num_block_rows + 1) // 2 # Causal: N(N+1)/2 - if self.is_causal - else num_block_rows * num_block_cols # Non-causal: N*N - ) - sparsity = 1 - (kept_blocks / total_blocks) + # density = sum(mask) / numel(mask) * N / (N+1) for causal + if self.is_causal: + density = float(block_mask.sum() / block_mask.numel()) * ( + num_block_rows / (num_block_rows + 1) + ) + else: + density = float(block_mask.sum() / block_mask.numel()) + sparsity = 1 - density + total_blocks = num_block_rows * num_block_cols else: # decode blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( attn_weights, 1, self.bc @@ -232,14 +236,14 @@ def calc_correction_factor_and_p( element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) element_mask = element_mask[:, :, :seq_q, :seq_k] - # Step 7: Calculate statistics - kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + # Step 7: Calculate sparsity statistics + density = float(block_mask.sum() / block_mask.numel()) + sparsity = 1 - density total_blocks = num_block_cols - sparsity = 1 - (kept_blocks / total_blocks) # Create stats dictionary stats = { - "correction_factor": correction_factor if self.enable_correction_factor else 1.0, + "correction_factor": correction_factor, "sparsity": sparsity, "phase": phase, "total_blocks": total_blocks, @@ -249,27 +253,18 @@ def calc_correction_factor_and_p( return element_mask, stats - def apply_sparsity( + def calculate_sparsity( self, - query: torch.Tensor | None = None, - key: torch.Tensor | None = None, - value: torch.Tensor | None = None, - attention_scores: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: - """Apply Flash Attention-aware block-wise sparsity. + attention_scores: torch.Tensor, + ) -> tuple[torch.Tensor, dict]: + """Calculate sparsity mask and statistics for Flash Attention. Args: - query: Query tensor (unused, for API compatibility) - key: Key tensor (unused, for API compatibility) - value: Value tensor (unused, for API compatibility) attention_scores: Attention scores tensor with shape [batch, heads, seq_q, seq_k] Returns: - Tuple with potentially modified attention_scores + Tuple of (sparse_mask, stats) where sparse_mask is boolean mask """ - # Attention scores must be provided for sparse attention - assert attention_scores is not None, "attention_scores must be provided for apply_sparsity" - # Attention scores are always 4D: [batch, heads, seq_q, seq_k] assert len(attention_scores.shape) == 4, ( f"Expected 4D attention scores, got shape {attention_scores.shape}" @@ -278,20 +273,66 @@ def apply_sparsity( # Infer phase from tensor shape phase = self._infer_phase(attention_scores) - # Update threshold for the detected phase - self._update_threshold(phase) + # Update threshold for the detected phase (skip during calibration) + if not self._calibration_mode: + self._update_threshold(phase) - # Apply block-wise sparsity + # Calculate block-wise sparsity mask and stats sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) - # Store stats for module to collect (doesn't persist across calls) + # Store stats for module to collect self._last_stats = stats - # Apply mask to create sparse scores + return sparse_mask, stats + + def apply_sparsity( + self, + attention_scores: torch.Tensor, + sparse_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Apply sparsity mask to attention scores. 
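+
+        A minimal sketch of the intended two-step flow (the config values and tensor
+        shape are illustrative; the required config keys follow this method's ``__init__``):
+
+        .. code-block:: python
+
+            import torch
+
+            from modelopt.torch.sparsity.attention_sparsity.methods import get_sparse_method
+
+            method = get_sparse_method("flash_skip_softmax")(
+                {"threshold": 1e-3, "br": 64, "bc": 64, "backend": "pytorch", "is_causal": True}
+            )
+            scores = torch.randn(1, 8, 256, 256)  # [batch, heads, seq_q, seq_k]
+            mask, stats = method.calculate_sparsity(scores)
+            sparse_scores = method.apply_sparsity(scores, mask)
+            probs = torch.softmax(sparse_scores, dim=-1)  # masked positions contribute ~0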
+ + Args: + attention_scores: Attention scores tensor [batch, heads, seq_q, seq_k] + sparse_mask: Optional pre-computed boolean mask. If None, calculates internally. + + Returns: + Masked attention scores with sparse elements set to dtype minimum + """ + if sparse_mask is None: + sparse_mask, _ = self.calculate_sparsity(attention_scores) + + # Apply mask: set masked positions to minimum value (becomes 0 after softmax) mask_value = torch.finfo(attention_scores.dtype).min - sparse_scores = attention_scores.masked_fill(~sparse_mask, mask_value) + return attention_scores.masked_fill(~sparse_mask, mask_value) + + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for this method. - return query, key, value, sparse_scores + Returns: + Dictionary with threshold configuration and calibration info. + """ + threshold_scale_factor = getattr(self, "threshold_scale_factor", None) + + if threshold_scale_factor is not None: + # Calibrated dynamic threshold + return { + "type": "dynamic", + "scale_factor": threshold_scale_factor, + "formula": "λ / length", + "example_lengths": { + 1024: threshold_scale_factor / 1024, + 2048: threshold_scale_factor / 2048, + 4096: threshold_scale_factor / 4096, + 8192: threshold_scale_factor / 8192, + }, + } + else: + # Static threshold (single value or phase-specific dict) + return { + "type": "static", + "value": self.threshold_config, + } @property def name(self) -> str: diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index df7b5853b..996095127 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -18,6 +18,7 @@ import re import warnings from abc import ABC, abstractmethod +from typing import Any import torch @@ -25,25 +26,49 @@ class SparseAttentionMethod(ABC): """Base class for sparse attention methods.""" + @abstractmethod + def calculate_sparsity( + self, + attention_scores: torch.Tensor, + ) -> tuple[torch.Tensor, dict]: + """Calculate sparsity mask and statistics without applying. + + Args: + attention_scores: Pre-softmax attention scores [batch, heads, seq_q, seq_k] + + Returns: + Tuple of (sparse_mask, stats_dict) where: + - sparse_mask: Boolean tensor indicating which elements to keep + - stats_dict: Dictionary with sparsity statistics + """ + @abstractmethod def apply_sparsity( self, - query: torch.Tensor | None = None, - key: torch.Tensor | None = None, - value: torch.Tensor | None = None, - attention_scores: torch.Tensor | None = None, - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: - """Apply sparsity to attention computation. + attention_scores: torch.Tensor, + sparse_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Apply sparsity mask to attention scores. Args: - query: Query tensor - key: Key tensor - value: Value tensor - attention_scores: Pre-computed attention scores + attention_scores: Pre-softmax attention scores [batch, heads, seq_q, seq_k] + sparse_mask: Optional pre-computed mask. If None, calculates internally. + + Returns: + Masked attention scores with sparse elements set to -inf + """ + + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for display/debugging. Returns: - Tuple of (query, key, value, attention_scores) with sparsity applied + Dictionary with threshold information. 
Should include:
+            - 'type': 'static', 'dynamic', or 'none'
+            - 'value': threshold value (for static)
+            - 'scale_factor': scale factor (for dynamic)
+            - Other method-specific info
         """
+        return {"type": "none", "value": None}
 
     @property
     @abstractmethod
diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
index 88434e746..b6b1e809f 100644
--- a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
+++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
@@ -22,10 +22,12 @@
 from modelopt.torch.opt.conversion import apply_mode
 from modelopt.torch.opt.searcher import ForwardLoop
 
+from .calibration import calibrate_sparse_attention
 from .config import SparseAttentionConfig
 from .mode import SparseAttentionModeRegistry
 
 __all__ = [
+    "calibrate",
     "sparsify",
 ]
 
@@ -58,12 +60,36 @@ def sparsify(
     .. code-block::python
 
         config = {
-            "method": "flash_skip_softmax",
             "sparse_cfg": {
+                # Phase-aware thresholds with backend selection
                 "*attention*": {
+                    "method": "flash_skip_softmax",
                     "threshold": {"prefill": 1e-3, "decode": 1e-5},
+                    "backend": "pytorch",  # Only pytorch backend supported
+                    "enable": True,
+                },
+                # Disable for specific layers
+                "*layer.0*": {"enable": False},
+                # Default settings
+                "default": {"enable": False},
+            },
+        }
+
+    For automatic threshold calibration using RULER dataset:
+
+    .. code-block::python
+
+        config = {
+            "sparse_cfg": {
+                "*attention*": {
+                    "method": "flash_skip_softmax",
                     "backend": "pytorch",
                     "enable": True,
+                    "calibration": {  # Enables automatic threshold calibration
+                        "target_sparse_ratio": 0.5,
+                        "samples": 48,
+                        "max_seqlen": 8192,
+                    },
                 },
                 "default": {"enable": False},
             },
         }
 
@@ -110,7 +136,7 @@ def forward_loop(model) -> float:
 
         from transformers import AutoModelForCausalLM
 
-        model = AutoModelForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             model_path,
             attn_implementation="eager",  # Required for sparse attention
             torch_dtype=torch.bfloat16,
@@ -126,4 +152,26 @@ def forward_loop(model) -> float:
 
         model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry
     )
 
+    # Calibrate the sparsity ratio of the attention modules
+    return calibrate(model, config, forward_loop=forward_loop)
+
+
+def calibrate(
+    model: torch.nn.Module,
+    config: dict[str, Any] | SparseAttentionConfig,
+    forward_loop: ForwardLoop | None = None,
+) -> torch.nn.Module:
+    """Calibrates sparse attention thresholds based on target sparsity.
+
+    Args:
+        model: Model with sparse attention modules
+        config: Sparse attention configuration with calibration settings
+        forward_loop: Optional callable that forwards calibration data through the model.
+            If provided, uses this for calibration data.
+            If None, will auto-generate RULER dataset for calibration.
+
+    Returns:
+        The calibrated model with optimized sparse attention thresholds.
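+
+    A hedged usage sketch (``model`` and ``my_forward_loop`` are placeholders;
+    ``SKIP_SOFTMAX_CALIB`` is the predefined calibration config, and ``sparsify``
+    already invokes this step internally):
+
+    .. code-block:: python
+
+        from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB
+
+        model = sparsify(model, SKIP_SOFTMAX_CALIB)
+        # Or re-run calibration explicitly, e.g. with custom calibration data:
+        model = calibrate(model, SKIP_SOFTMAX_CALIB, forward_loop=my_forward_loop)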
+ """ + calibrate_sparse_attention(model, config, forward_loop=forward_loop) return model diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 16b08bf19..29a73218b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -15,6 +15,8 @@ """Extensible sparse attention module.""" +from typing import Any + import torch import torch.nn.functional as F @@ -23,6 +25,7 @@ from .config import SparseAttentionAttributeConfig from .methods import get_sparse_method +from .stats_manager import SparseAttentionStatsManager class SparseAttentionModule(DynamicModule): @@ -103,6 +106,17 @@ def set_from_attribute_config( # Initialize sparse method instance self._init_sparse_method() + # Create stats manager based on config + if self._method_config.get("collect_stats", False): + self._stats_manager = SparseAttentionStatsManager( + module_name="sparse_attention", enabled=True + ) + else: + self._stats_manager = None + + # Initialize stats storage for collecting stats from sparse_softmax + self._last_stats: dict | None = None + def _init_sparse_method(self): """Initialize the sparse method instance.""" method_class = get_sparse_method(self._method) @@ -129,11 +143,22 @@ def get_stats(self) -> dict: Returns: Dictionary with sparsity statistics including 'average_sparsity' if available. - Returns empty dict (statistics collection will be added in calibration PR). + Returns empty dict if stats manager is not enabled. """ - # TODO: Statistics collection will be added in calibration PR + if self._stats_manager is not None and self._stats_manager.enabled: + return self._stats_manager.get_summary() return {} + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information from the sparse method instance. + + Returns: + Dictionary with threshold information from the sparse method. 
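+
+            For example (values illustrative), a statically configured module may report
+            ``{"type": "static", "value": {"prefill": 1e-3, "decode": 1e-5}}``, a calibrated
+            module reports ``"type": "dynamic"`` together with its ``scale_factor``, and a
+            module without a method instance reports ``{"type": "none", "value": None}``.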
+ """ + if hasattr(self, "_sparse_method_instance") and self._sparse_method_instance is not None: + return self._sparse_method_instance.get_threshold_info() + return {"type": "none", "value": None} + def _setup(self): """Setup called by DynamicModule.""" # Apply default configuration if not yet configured @@ -157,6 +182,11 @@ def forward(self, *args, **kwargs): with context: result = super().forward(*args, **kwargs) + # Collect stats if manager is available + if self._stats_manager is not None and self._last_stats is not None: + self._stats_manager.collect(self._last_stats) + self._last_stats = None # Clear after collection + return result def _get_sparse_context(self): @@ -172,14 +202,12 @@ def _create_sparse_softmax(self): original_softmax = F.softmax def sparse_softmax(input, dim=-1, *args, **kwargs): - # Let the method handle the sparsification - _, _, _, sparse_input = self._sparse_method_instance.apply_sparsity( - None, None, None, input - ) + # Calculate sparsity mask and collect statistics + sparse_mask, stats = self._sparse_method_instance.calculate_sparsity(input) + + # Store stats for collection + self._last_stats = stats - # Use sparse input if modified, otherwise use original - if sparse_input is not None: - return original_softmax(sparse_input, dim, *args, **kwargs) return original_softmax(input, dim, *args, **kwargs) return sparse_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py new file mode 100644 index 000000000..9fc57a0b1 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Statistics manager for sparse attention modules.""" + + +class SparseAttentionStatsManager: + """Centralized statistics manager for sparse attention. + + This class is the single source of truth for all statistics collection + in sparse attention modules. It handles both runtime aggregation and + per-sample calibration statistics. + + Design principles: + - Single responsibility: only stats management + - No computation: receives pre-computed stats from methods + - Optional: can be None if stats collection disabled + - Zero overhead when disabled + """ + + def __init__(self, module_name: str, enabled: bool = True): + """Initialize stats manager. 
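+
+        A minimal usage sketch (the module name and block counts are illustrative):
+
+        .. code-block:: python
+
+            mgr = SparseAttentionStatsManager("layers.0.self_attn", enabled=True)
+            mgr.collect(
+                {"sparsity": 0.6, "phase": "prefill", "total_blocks": 36, "sparse_blocks": 21}
+            )
+            summary = mgr.get_summary()  # summary["average_sparsity"] == 21 / 36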
+ + Args: + module_name: Name of the module this manager is attached to + enabled: Whether stats collection is enabled + """ + self.module_name = module_name + self.enabled = enabled + self.calibration_mode = False + + # Aggregated stats (running totals across all forward passes) + self.aggregated_stats: dict = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + + # Per-sample stats (only populated during calibration) + self.per_sample_stats: list[dict] = [] + + def collect(self, stats: dict): + """Collect statistics from a single forward pass. + + Args: + stats: Dictionary containing statistics from method computation. + Expected keys: sparsity, phase, total_blocks, sparse_blocks, + sample_length (optional) + """ + if not self.enabled: + return + + # Update aggregated stats + self.aggregated_stats["total_calls"] += 1 + self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0) + self.aggregated_stats["sparse_blocks"] += stats.get("sparse_blocks", 0) + + phase = stats.get("phase", "unknown") + if phase in self.aggregated_stats["phase_counts"]: + self.aggregated_stats["phase_counts"][phase] += 1 + + # In calibration mode, store per-sample stats + if self.calibration_mode: + self.per_sample_stats.append( + { + "module": self.module_name, + "sparsity": stats.get("sparsity", 0.0), + "sample_length": stats.get("sample_length", 0), + "phase": phase, + } + ) + + def get_summary(self) -> dict: + """Get aggregated statistics summary. + + Returns: + Dictionary with module name, total calls, average sparsity, + and phase distribution. + """ + total_blocks = self.aggregated_stats["total_blocks"] + if total_blocks > 0: + avg_sparsity = self.aggregated_stats["sparse_blocks"] / total_blocks + else: + avg_sparsity = 0.0 + + return { + "module": self.module_name, + "total_calls": self.aggregated_stats["total_calls"], + "average_sparsity": avg_sparsity, + "phase_distribution": self.aggregated_stats["phase_counts"].copy(), + } + + def set_calibration_mode(self, enabled: bool, reset_history: bool = True): + """Enable or disable calibration mode. + + In calibration mode, per-sample statistics are stored for detailed + analysis. Otherwise, only aggregated stats are maintained. + + Args: + enabled: Whether to enable calibration mode + reset_history: Whether to clear per_sample_stats when enabling + """ + self.calibration_mode = enabled + if enabled and reset_history: + self.per_sample_stats = [] + + def reset(self): + """Reset all statistics to initial state.""" + self.aggregated_stats = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + self.per_sample_stats = [] + + def get_calibration_stats(self) -> list[dict]: + """Get per-sample calibration statistics. + + Returns: + List of per-sample statistics dictionaries. + Empty list if not in calibration mode. 
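+
+            Each entry has the shape (values illustrative)::
+
+                {"module": "layers.0.self_attn", "sparsity": 0.55,
+                 "sample_length": 4096, "phase": "prefill"}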
+        """
+        return self.per_sample_stats
diff --git a/setup.py b/setup.py
index 242505302..a87a7e93e 100644
--- a/setup.py
+++ b/setup.py
@@ -83,6 +83,8 @@
         "torch-geometric",
         "tox>4.18",
         "tox-current-env>=0.0.12",
+        "nltk",
+        "wonderwords",
     ],
     # docs
     "dev-docs": [
diff --git a/tests/_test_utils/torch/sparsity/sparse_attention_common.py b/tests/_test_utils/torch/sparsity/sparse_attention_common.py
index 7724908b0..5ed079966 100644
--- a/tests/_test_utils/torch/sparsity/sparse_attention_common.py
+++ b/tests/_test_utils/torch/sparsity/sparse_attention_common.py
@@ -153,13 +153,15 @@ def forward_loop(model):
     with torch.no_grad():
         for batch in calib_data:
             output = model(batch)
-            assert not torch.isnan(output).any(), "NaN in output"
-            assert output is not None, "Output is None"
+            assert output is not None, f"Output is None for batch shape {batch.shape}"
+            assert not torch.isnan(output).any(), (
+                f"NaN detected in output for batch shape {batch.shape}"
+            )
 
     return model
 
 
-def save_restore_test(model_cls, device, sparse_config):
+def save_restore_test(model_cls, device, sparse_config, atol=1e-6):
     """Test save and restore of sparse attention state.
 
     Args:
@@ -190,6 +192,6 @@
     output_sparse = model_sparse(test_input)
     output_restored = model_restored(test_input)
 
-    assert torch.allclose(output_sparse, output_restored, atol=1e-6), (
+    assert torch.allclose(output_sparse, output_restored, atol=atol), (
         "Restored model output doesn't match original"
     )
diff --git a/tests/examples/llm_eval/test_llm_eval.py b/tests/examples/llm_eval/test_llm_eval.py
index 0abf78b53..88d29dedc 100644
--- a/tests/examples/llm_eval/test_llm_eval.py
+++ b/tests/examples/llm_eval/test_llm_eval.py
@@ -36,3 +36,20 @@ def test_llama_eval_fp8():
     finally:
         # Force kill llm-serve if it's still running
         subprocess.run(["pkill", "-f", "llm-serve"], check=False)
+
+
+def test_llama_eval_sparse_attention(tiny_llama_path):
+    """Test sparse attention with llm_eval integration."""
+    try:
+        # Test with default sparse attention config (no quantization)
+        run_llm_ptq_command(
+            model=tiny_llama_path,
+            quant="none",  # No quantization, only sparse attention
+            tasks="lm_eval",
+            lm_eval_tasks="hellaswag",
+            lm_eval_limit=0.05,  # Small limit for fast test
+            sparse_cfg="SKIP_SOFTMAX_DEFAULT",
+            batch=4,
+        )
+    finally:
+        subprocess.run(["pkill", "-f", "llm-serve"], check=False)
diff --git a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py
index b70dfab35..9f1cb8125 100644
--- a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py
+++ b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py
@@ -34,7 +34,6 @@ def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax",
         }
     )
     kwargs.setdefault("seq_len", 128)
-    kwargs.setdefault("num_samples", 1)
    kwargs.setdefault("max_new_tokens", 16)
 
     cmd_parts = extend_cmd_parts(["python", "hf_sa.py"], **kwargs)
@@ -43,8 +42,10 @@
 
 @pytest.mark.parametrize("method", ["skip_softmax"])
 def test_attention_sparsity(tiny_llama_path, tmp_path, method):
-    """Test sparse attention with TinyLlama."""
+    """Test sparse attention with TinyLlama (with and without calibration)."""
     run_attention_sparsity_command(
         model=tiny_llama_path,
         method=method,
+        seq_len=128,
+        max_new_tokens=10,
     )
diff --git
a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py new file mode 100644 index 000000000..913dc24a0 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for sparse attention calibration.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import SimpleTransformerEncoderLayer + +import modelopt.torch.opt as mto +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import RulerDatasetBuilder +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# Skip all tests if no GPU available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") + + +class TestRulerDatasetBuilderGPU: + """Test RULER dataset generation with real tokenizers on GPU.""" + + def test_ruler_generation_with_real_tokenizer(self): + """Test RULER generation with GPT2 tokenizer.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 samples (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 6 samples (1 per task) + assert len(dataset) == 6 + + # All samples should have valid structure + for sample in dataset: + assert "input" in sample + assert "length" in sample + assert sample["length"] > 0 + + def test_generated_length_accuracy(self): + """Test that generated token counts are accurate.""" + builder = RulerDatasetBuilder( + samples=3, + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check that lengths are within reasonable range of target + for sample in dataset: + # RULER aims for 70-90% of target for context + assert 700 < sample["length"] < 1400 + + def test_multiple_subtasks(self): + """Test generation with multiple RULER subtasks.""" + builder = RulerDatasetBuilder( + samples=12, # Need at least 6 for 1 per task, use 12 for 2 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check task distribution (should have multiple tasks from RULER_TASKS) + tasks_found = {s["task"] for s in dataset} + assert len(tasks_found) >= 2 # At least 2 different tasks + + def test_large_context_lengths(self): + """Test with larger context lengths.""" + builder = RulerDatasetBuilder( + samples=24, # 4 lengths * 6 tasks = need 24 for 1 per (length, task) + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() 
+ + assert len(dataset) == 24 + + # Verify we have different lengths + lengths = [s["length"] for s in dataset] + # Should have variety of lengths across the bins + assert len(set(lengths)) > 1 # At least 2 different target lengths used + + +class TestCalibrationGPU: + """Test calibration with real models on GPU.""" + + @pytest.fixture + def simple_model(self): + """Create simple attention model for testing.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibration_simple_model(self, simple_model): + """Test calibration with simple attention model.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + # Simple forward loop for calibration + pass + + # Apply sparse attention with calibration + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules exist + sparse_modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(sparse_modules) > 0 + + # Verify calibration was applied + for module in sparse_modules: + method = module._sparse_method_instance + # Check if calibrated threshold scale factor is set + if hasattr(method, "threshold_scale_factor") and method.threshold_scale_factor: + assert method.threshold_scale_factor > 0 + + def test_calibration_pytorch_backend(self, simple_model): + """Test calibration with pytorch backend.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Check backend is set correctly + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert hasattr(method, "backend") + assert method.backend == "pytorch" + + def test_simplified_calibration(self, simple_model): + """Test simplified calibration (prefill phase only).""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Should complete without errors + assert sparse_model is not None + + def test_calibration_persistence(self, simple_model): + """Test save and restore of calibrated model.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Save modelopt state + modelopt_state = mto.modelopt_state(sparse_model) + + # Create new model and restore + model_restored = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + + restored = 
mto.restore_from_modelopt_state(model_restored, modelopt_state) + + # Check that sparse attention is restored + has_sparse = any(isinstance(m, SparseAttentionModule) for m in restored.modules()) + assert has_sparse + + +class TestCalibrationEndToEnd: + """Integration tests with inference.""" + + @pytest.fixture + def simple_model_setup(self): + """Setup simple model.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibrated_model_inference(self, simple_model_setup): + """Test inference with calibrated model.""" + model = simple_model_setup + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Test inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + sparse_model.eval() + with torch.no_grad(): + output = sparse_model(test_input) + + # Check output is valid + assert output is not None + assert not torch.isnan(output).any() + + def test_calibrated_vs_fixed_threshold(self, simple_model_setup): + """Compare calibrated vs fixed threshold models.""" + # Config with calibration + config_calibrated = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + # Config with fixed threshold (no calibration) + config_fixed = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + } + }, + } + + def forward_loop(model): + pass + + # Test both can be created + model_calibrated = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_calibrated, + forward_loop=forward_loop, + ) + + model_fixed = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_fixed, + ) + + # Both should work for inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + with torch.no_grad(): + output_calibrated = model_calibrated(test_input) + output_fixed = model_fixed(test_input) + + assert output_calibrated is not None + assert output_fixed is not None + + def test_memory_usage(self, simple_model_setup): + """Test that calibration doesn't cause memory issues.""" + model = simple_model_setup + + # Clear cache before test + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate + sparsify(model, config, forward_loop=forward_loop) + + # Check memory didn't explode + final_memory = torch.cuda.memory_allocated() + memory_increase = final_memory - initial_memory + + # Memory should be reasonable (not more than 2GB increase) + assert memory_increase < 2 * 1024**3 # 2GB + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py index b487d8639..d9bbee157 100644 --- 
a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py @@ -248,8 +248,8 @@ def test_causal_vs_noncausal(self): assert stats_causal["total_blocks"] == 3 assert stats_noncausal["total_blocks"] == 4 - def test_apply_sparsity_assertions(self): - """Test apply_sparsity input validation.""" + def test_calculate_sparsity_assertions(self): + """Test calculate_sparsity input validation.""" method = FlashSkipSoftmax( { "threshold": 1e-3, @@ -260,13 +260,56 @@ def test_apply_sparsity_assertions(self): } ) - # Test: attention_scores required - with pytest.raises(AssertionError, match="attention_scores must be provided"): - method.apply_sparsity() - # Test: 4D shape required with pytest.raises(AssertionError, match="Expected 4D"): - method.apply_sparsity(attention_scores=torch.randn(2, 64, 64)) # 3D + method.calculate_sparsity(attention_scores=torch.randn(2, 64, 64)) # 3D + + def test_apply_sparsity_with_mask(self): + """Test apply_sparsity with pre-computed mask.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(2, 4, 128, 256) + + # Calculate sparsity first + sparse_mask, stats = method.calculate_sparsity(attn) + + # Apply sparsity with pre-computed mask + sparse_attn = method.apply_sparsity(attn, sparse_mask) + + # Verify output shape matches input + assert sparse_attn.shape == attn.shape + + # Verify masked positions have min value + mask_value = torch.finfo(attn.dtype).min + assert (sparse_attn[~sparse_mask] == mask_value).all() + + def test_apply_sparsity_without_mask(self): + """Test apply_sparsity calculates mask internally when None.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(2, 4, 128, 256) + + # Apply sparsity without pre-computed mask + sparse_attn = method.apply_sparsity(attn) + + # Verify output shape matches input + assert sparse_attn.shape == attn.shape def test_name_property(self): """Test method name property.""" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py new file mode 100644 index 000000000..4558ca22b --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -0,0 +1,623 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for sparse attention calibration.""" + +import pytest + +pytest.importorskip("transformers") + +import numpy as np +from _test_utils.torch_sparsity.sparse_attention_common import ( + SimpleAttentionModel, + SimpleTransformerEncoder, +) +from pydantic import ValidationError + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import ( + DynamicThresholdCalibrator, + RulerDatasetBuilder, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.calibrate import ( + _extract_calibration_config, + calibrate_sparse_attention, + create_calibration_forward_loop, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.dataset import _generate_target_lengths +from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestLengthGeneration: + """Test automatic target length generation.""" + + def test_generate_target_lengths_default(self): + """Test default 4 bins generation.""" + lengths = _generate_target_lengths(32768, num_length_bins=4) + assert lengths == [32768, 16384, 8192, 4096] + + def test_generate_target_lengths_stops_at_minimum(self): + """Test generation stops at minimum threshold.""" + lengths = _generate_target_lengths(2048, num_length_bins=4) + assert lengths == [2048, 1024] # Stops at 1024 + + def test_generate_target_lengths_fewer_bins(self): + """Test with fewer bins.""" + lengths = _generate_target_lengths(16384, num_length_bins=2) + assert lengths == [16384, 8192] + + def test_generate_target_lengths_more_bins(self): + """Test with more bins.""" + lengths = _generate_target_lengths(65536, num_length_bins=6) + assert lengths == [65536, 32768, 16384, 8192, 4096, 2048] + + def test_generate_target_lengths_exactly_minimum(self): + """Test when max_seqlen equals minimum.""" + lengths = _generate_target_lengths(1024, num_length_bins=4) + assert lengths == [1024] + + +class TestRulerDatasetBuilder: + """Test RULER dataset generation without requiring real tokenizers.""" + + def test_builder_initialization(self): + """Test that builder initializes correctly.""" + builder = RulerDatasetBuilder( + samples=12, + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + assert builder.total_samples == 12 + assert builder.max_seqlen == 2048 + assert builder.target_lengths == [2048, 1024] + assert builder.samples_per_length == [6, 6] # Evenly distributed + assert len(builder.subtasks) == 6 # All RULER_TASKS + assert builder.seed == 42 + + def test_builder_initialization_invalid_config(self): + """Test that builder raises error for invalid inputs.""" + # Test invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + RulerDatasetBuilder( + samples=0, + max_seqlen=2048, + tokenizer_name_or_path="gpt2", + ) + + # Test max_seqlen below minimum + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + RulerDatasetBuilder( + samples=4, + max_seqlen=512, # Below minimum + tokenizer_name_or_path="gpt2", + ) + + def test_dataset_generation_minimal(self): + """Test generating small dataset.""" + builder = RulerDatasetBuilder( + samples=12, # 6 tasks x 2 lengths = need 12 for 1 per task per length + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 12 samples (6 tasks x 1 sample per task 
x 2 lengths) + assert len(dataset) == 12 + assert all(isinstance(sample, dict) for sample in dataset) + + def test_dataset_structure(self): + """Test that dataset has correct structure.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + sample = dataset[0] + + # Check required fields + assert "input" in sample + assert "length" in sample + assert "task" in sample + assert "target_length" in sample + + # Check field types + assert isinstance(sample["input"], str) + assert isinstance(sample["length"], int) + assert isinstance(sample["task"], str) + assert sample["length"] > 0 + + def test_sample_distribution(self): + """Test that samples are distributed across lengths and subtasks.""" + builder = RulerDatasetBuilder( + samples=24, # 6 tasks x 2 lengths x 2 samples = 24 + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should have 24 samples (12 per length, 2 per task) + assert len(dataset) == 24 + + # Check task distribution (should have variety from all RULER_TASKS) + tasks = [s["task"] for s in dataset] + # Verify we have all 6 tasks represented + assert len(set(tasks)) == 6 + + def test_length_targeting(self): + """Test that generated lengths are close to targets.""" + builder = RulerDatasetBuilder( + samples=6, # 1 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Lengths should be within reasonable range of target + # RULER aims for 70-90% of target length for context + for sample in dataset: + assert 700 < sample["length"] < 1400 # Reasonable range around 1024 + + def test_uneven_sample_distribution(self): + """Test that samples are distributed evenly (remainder dropped).""" + builder = RulerDatasetBuilder( + samples=50, # 50 samples across 4 lengths + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + # Even distribution: 50//4 = 12 per length + assert builder.total_samples == 50 + assert builder.target_lengths == [8192, 4096, 2048, 1024] + assert builder.samples_per_length == [12, 12, 12, 12] + assert sum(builder.samples_per_length) == 48 # 2 samples dropped (remainder) + + # Actual generated samples: 12//6=2 per task, 4 lengths, 6 tasks + # Total: 2 x 6 x 4 = 48 + dataset = builder.build_calibration_dataset() + assert len(dataset) == 48 + + +class TestDynamicThresholdCalibrator: + """Test calibration algorithm correctness.""" + + def test_calibrator_initialization(self): + """Test that calibrator initializes correctly.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + threshold_trials=[1e-4, 1e-3, 1e-2], + ) + + assert calibrator.target_sparse_ratio == 0.5 + assert len(calibrator.threshold_trials) == 3 + + def test_calibrator_default_threshold_trials(self): + """Test that calibrator has default threshold trials.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + ) + + # Should have default threshold trials + assert calibrator.threshold_trials is not None + assert len(calibrator.threshold_trials) == 12 + # Check they are positive and in valid range + trials = calibrator.threshold_trials + assert all(0 < t < 1 for t in trials) + + def test_regression_calculation_synthetic(self): + """Test 'a' parameter calculation with 
synthetic data.""" + # Create synthetic optimal pairs + # If threshold = a / length, then with perfect data: + # length=1000, threshold=10 => a=10000 + # length=2000, threshold=5 => a=10000 + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0, "achieved_sparsity": 0.5}, + {"length": 2000, "optimal_threshold": 5.0, "achieved_sparsity": 0.5}, + {"length": 4000, "optimal_threshold": 2.5, "achieved_sparsity": 0.5}, + ] + + # Manual regression calculation + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + # Calculate 'a' using least squares + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should be close to 10000 + assert 9500 < a_parameter < 10500 + + # Test individual 'a' values + a_per_sample = y * lengths + assert np.allclose(a_per_sample, 10000, rtol=0.05) + + def test_multiple_samples_different_lengths(self): + """Test regression with varied lengths.""" + # More realistic scenario with some variance + optimal_pairs = [ + {"length": 500, "optimal_threshold": 20.0, "achieved_sparsity": 0.5}, + {"length": 1000, "optimal_threshold": 10.5, "achieved_sparsity": 0.51}, + {"length": 2000, "optimal_threshold": 5.2, "achieved_sparsity": 0.49}, + {"length": 4000, "optimal_threshold": 2.4, "achieved_sparsity": 0.50}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should still be around 10000 with some tolerance for variance + assert 9000 < a_parameter < 11000 + + def test_r_squared_calculation(self): + """Test R-squared calculation for regression quality.""" + # Perfect fit data + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0}, + {"length": 2000, "optimal_threshold": 5.0}, + {"length": 4000, "optimal_threshold": 2.5}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Calculate R-squared + y_pred = a_parameter * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Perfect fit should have R^2 close to 1 + assert r_squared > 0.99 + + +class TestCalibrationIntegration: + """Test end-to-end calibration without GPU.""" + + def test_calibration_disabled(self): + """Test that no calibration occurs when disabled.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # No forward_loop needed when calibration disabled + sparse_model = sparsify(model, config) + + # Check that sparse attention is applied but not calibrated + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + # Check that no calibration is set + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert not getattr(method, "threshold_scale_factor", None) + + def test_sparsify_with_calibration_requires_forward_loop(self): + """Test that calibration requires forward_loop or proper model config.""" + model = SimpleAttentionModel(hidden_size=64, 
num_heads=4) + + config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + }, + }, + } + + # Without forward_loop and without model.config._name_or_path, should raise ValueError + with pytest.raises(ValueError, match="Could not load tokenizer"): + sparsify(model, config, forward_loop=None) + + def test_multiple_sparse_modules(self): + """Test that calibration handles multiple attention layers.""" + model = SimpleTransformerEncoder() + + config = { + "sparse_cfg": {"*attn*": {"threshold": 1e-3, "br": 64, "bc": 64, "enable": True}}, + } + + sparse_model = sparsify(model, config) + + # Count sparse attention modules + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + + # Should have 2 sparse attention modules + assert sparse_count == 2 + + def test_calibration_config_validation(self): + """Test CalibrationConfig validation.""" + # Valid config + config = CalibrationConfig( + target_sparse_ratio=0.5, + samples=48, + max_seqlen=32768, + ) + assert config.target_sparse_ratio == 0.5 + assert config.samples == 48 + assert config.max_seqlen == 32768 + + # Invalid target_sparse_ratio (> 1.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=1.5, samples=48, max_seqlen=32768) + + # Invalid target_sparse_ratio (< 0.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=-0.1, samples=48, max_seqlen=32768) + + # Invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + CalibrationConfig(target_sparse_ratio=0.5, samples=0, max_seqlen=32768) + + # Invalid max_seqlen + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + CalibrationConfig(target_sparse_ratio=0.5, samples=48, max_seqlen=512) + + def test_threshold_trials_validation(self): + """Test threshold_trials validation.""" + # Valid custom threshold_trials + config = CalibrationConfig( + target_sparse_ratio=0.5, + threshold_trials=[1e-5, 1e-4, 1e-3, 1e-2], + ) + assert config.threshold_trials == [1e-5, 1e-4, 1e-3, 1e-2] + + # None (use defaults) + config_default = CalibrationConfig(target_sparse_ratio=0.5) + assert config_default.threshold_trials is None + + # Invalid: empty list + with pytest.raises(ValueError, match="threshold_trials must not be empty"): + CalibrationConfig(threshold_trials=[]) + + # Invalid: threshold out of range (>= 1.0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 1.0]) + + # Invalid: threshold out of range (<= 0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 0]) + + # Invalid: not a list (Pydantic raises ValidationError, not ValueError) + with pytest.raises(ValidationError, match="Input should be a valid list"): + CalibrationConfig(threshold_trials=1e-4) + + +class TestDynamicThresholdCalibratorMethods: + """Test individual methods of DynamicThresholdCalibrator.""" + + def test_set_threshold(self): + """Test _set_threshold method.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + # Get sparse 
modules + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(modules) > 0 + + # Create calibrator and set threshold + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + calibrator._set_threshold(modules, 0.05) + + # Verify threshold was set + for module in modules: + assert module._sparse_method_instance.threshold == 0.05 + + def test_enable_disable_calibration_mode(self): + """Test _enable_calibration_mode and _disable_calibration_mode.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + + # Enable calibration mode + calibrator._enable_calibration_mode(modules) + + for module in modules: + assert module._stats_manager is not None + assert module._stats_manager.enabled is True + assert module._stats_manager.calibration_mode is True + assert module._sparse_method_instance._calibration_mode is True + + # Disable calibration mode + calibrator._disable_calibration_mode(modules) + + for module in modules: + assert module._stats_manager.calibration_mode is False + assert module._sparse_method_instance._calibration_mode is False + + def test_extract_calibration_stats_no_stats(self): + """Test _extract_calibration_stats when no stats collected.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + + # Extract stats without running any forward passes + stats = calibrator._extract_calibration_stats(modules) + + # Should return empty list + assert stats == [] + + def test_calibrator_with_single_sample(self): + """Test calibrator edge case with only one sample.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + threshold_trials=[0.001, 0.01, 0.1], + ) + + # Even with one sample, regression should work + assert calibrator.target_sparse_ratio == 0.5 + assert len(calibrator.threshold_trials) == 3 + + +class TestCalibrateFunction: + """Test calibrate_sparse_attention function.""" + + def test_calibrate_no_config(self): + """Test calibration when config has no calibration section.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + # Config without calibration + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # Should return empty dict when no calibration config + result = calibrate_sparse_attention(model, config) + + assert result == {} + + def test_extract_calibration_config(self): + """Test _extract_calibration_config function.""" + # Config with calibration + config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.3, + "samples": 12, + "max_seqlen": 2048, + }, + "*attn*": { + "method": "flash_skip_softmax", + }, + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is not None + assert 
calib_config.target_sparse_ratio == 0.3 + assert calib_config.samples == 12 + assert calib_config.max_seqlen == 2048 + + def test_extract_calibration_config_none(self): + """Test _extract_calibration_config when no calibration.""" + # Config without calibration + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + } + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is None + + def test_create_calibration_forward_loop(self): + """Test create_calibration_forward_loop function.""" + calibration_data = [ + {"input": "This is a test sample.", "length": 512}, + {"input": "Another test sample.", "length": 1024}, + ] + + forward_loop = create_calibration_forward_loop( + calibration_data=calibration_data, + tokenizer_name_or_path="gpt2", + ) + + # Should return a callable + assert callable(forward_loop) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py new file mode 100644 index 000000000..1824825f9 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test sparse attention configuration validation.""" + +import pytest +from pydantic import ValidationError + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_DEFAULT, + FlashSkipSoftmaxConfig, + SparseAttentionAttributeConfig, + SparseAttentionConfig, +) + + +class TestSparseAttentionAttributeConfig: + """Test SparseAttentionAttributeConfig validators.""" + + def test_valid_config(self): + """Test creating valid config.""" + config = SparseAttentionAttributeConfig( + method="flash_skip_softmax", + threshold=1e-4, + br=128, + bc=128, + enable=True, + ) + assert config.method == "flash_skip_softmax" + assert config.threshold == 1e-4 + assert config.br == 128 + assert config.bc == 128 + + def test_method_validation(self): + """Test method must be string.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + SparseAttentionAttributeConfig(method=123) + + def test_block_size_validation_negative(self): + """Test block sizes must be positive.""" + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(br=-1) + + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(bc=0) + + def test_block_size_validation_large(self): + """Test that large block sizes are accepted.""" + # Large block sizes are allowed (warning removed for simplicity) + config = SparseAttentionAttributeConfig(br=2048) + assert config.br == 2048 + + def test_threshold_validation_range(self): + """Test threshold must be in range (0, 1).""" + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=-0.1) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.5) + + def test_threshold_validation_dict(self): + """Test threshold dict validation.""" + # Valid phase-aware threshold + config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) + assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} + + # Invalid phase key + with pytest.raises(ValidationError, match="Invalid threshold phases"): + SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) + + # Invalid threshold value in dict (negative) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) + + # Invalid threshold value in dict (>= 1.0) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) + + def test_threshold_validation_type(self): + """Test threshold type validation.""" + with pytest.raises(ValidationError, match="Input should be a valid"): + SparseAttentionAttributeConfig(threshold="invalid") + + +class TestSparseAttentionConfig: + """Test SparseAttentionConfig.""" + + def test_default_config(self): + """Test default configuration.""" + config = SparseAttentionConfig() + assert "sparse_cfg" in config.model_dump() + # Check default pattern has method + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" + + def test_predefined_config(self): + """Test pre-defined configuration.""" + assert 
"sparse_cfg" in SKIP_SOFTMAX_DEFAULT + assert "method" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"]["*attn*"] + assert "*attn*" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"] + + +class TestFlashSkipSoftmaxConfig: + """Test FlashSkipSoftmaxConfig.""" + + def test_default_values(self): + """Test default values for flash_skip_softmax config.""" + config = FlashSkipSoftmaxConfig() + assert "*attention*" in config.sparse_cfg + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 6fcad9bb8..de954ae97 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.conversion import ( disable_sparse_attention, enable_sparse_attention, + print_sparse_attention_summary, ) from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule @@ -170,6 +171,19 @@ def test_disable_enable_functions(self): if isinstance(module, SparseAttentionModule): assert module.is_enabled + def test_print_sparse_attention_summary(self, capsys): + """Test print_sparse_attention_summary function.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Print summary + print_sparse_attention_summary(model) + + # Capture output + captured = capsys.readouterr() + assert "Total sparse attention modules:" in captured.out + assert "Enabled:" in captured.out + def test_restore_sparse_attention_model(self): """Test save/restore via modelopt_state.""" # Create and sparsify original model @@ -192,3 +206,100 @@ def test_restore_sparse_attention_model(self): if isinstance(module, SparseAttentionModule): assert hasattr(module, "_method") assert module._method == "flash_skip_softmax" + + +class TestSparseAttentionModuleMethods: + """Test SparseAttentionModule methods.""" + + def test_get_stats_with_stats_manager(self): + """Test get_stats() when stats manager exists and is enabled.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "collect_stats": True, # Enable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + sparse_module = module + break + + assert sparse_module is not None + assert sparse_module._stats_manager is not None + + # Get stats (should return summary) + stats = sparse_module.get_stats() + + assert isinstance(stats, dict) + assert "module" in stats + assert "total_calls" in stats + assert "average_sparsity" in stats + + def test_get_stats_without_stats_manager(self): + """Test get_stats() when stats manager is None.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "collect_stats": False, # Disable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Stats manager should be None + assert 
module._stats_manager is None + + # get_stats should return empty dict + stats = module.get_stats() + assert stats == {} + break + + def test_get_threshold_info(self): + """Test get_threshold_info() method.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.005, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module and test threshold info + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + info = module.get_threshold_info() + + assert isinstance(info, dict) + assert "type" in info + assert info["type"] == "static" + assert info["value"] == 0.005 + break diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py new file mode 100644 index 000000000..02188e97a --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for SparseAttentionStatsManager.""" + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager + + +class TestStatsManagerInitialization: + """Test stats manager initialization.""" + + def test_initialization_defaults(self): + """Test default initialization.""" + manager = SparseAttentionStatsManager(module_name="test_module") + + assert manager.module_name == "test_module" + assert manager.enabled is True + assert manager.calibration_mode is False + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + + def test_initialization_disabled(self): + """Test initialization with disabled stats.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=False) + + assert manager.enabled is False + assert manager.calibration_mode is False + + def test_initialization_custom_name(self): + """Test initialization with custom module name.""" + manager = SparseAttentionStatsManager(module_name="custom.attention.module") + + assert manager.module_name == "custom.attention.module" + + +class TestStatsCollection: + """Test statistics collection functionality.""" + + def test_collect_stats_enabled(self): + """Test collecting stats when enabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 1 + assert manager.aggregated_stats["total_blocks"] == 100 + assert 
manager.aggregated_stats["sparse_blocks"] == 50 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 1 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 + + def test_collect_stats_disabled(self): + """Test that collect() is no-op when disabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=False) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + + manager.collect(stats) + + # Should remain at initial values + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + + def test_collect_multiple_calls(self): + """Test accumulation over multiple collect calls.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect multiple times + for i in range(5): + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 5 + assert manager.aggregated_stats["total_blocks"] == 500 + assert manager.aggregated_stats["sparse_blocks"] == 250 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 5 + + def test_collect_different_phases(self): + """Test phase counting.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect prefill stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + # Collect decode stats + manager.collect({"phase": "decode", "total_blocks": 10, "sparse_blocks": 5}) + + assert manager.aggregated_stats["phase_counts"]["prefill"] == 2 + assert manager.aggregated_stats["phase_counts"]["decode"] == 1 + assert manager.aggregated_stats["phase_counts"]["unknown"] == 0 + + +class TestCalibrationMode: + """Test calibration mode functionality.""" + + def test_calibration_mode_per_sample_collection(self): + """Test that calibration mode stores per-sample stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Enable calibration mode + manager.set_calibration_mode(enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + # Should store in per_sample_stats + assert len(manager.per_sample_stats) == 1 + assert manager.per_sample_stats[0]["module"] == "test" + assert manager.per_sample_stats[0]["sparsity"] == 0.5 + assert manager.per_sample_stats[0]["sample_length"] == 1024 + assert manager.per_sample_stats[0]["phase"] == "prefill" + + def test_calibration_mode_off(self): + """Test that per-sample stats are not collected when calibration mode is off.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + # Calibration mode is off by default + + stats = {"sparsity": 0.5, "phase": "prefill", "total_blocks": 100, "sparse_blocks": 50} + + manager.collect(stats) + + # Should NOT store in per_sample_stats + assert len(manager.per_sample_stats) == 0 + + # But should still aggregate + assert manager.aggregated_stats["total_calls"] == 1 + + def test_set_calibration_mode_with_reset(self): + """Test set_calibration_mode with reset_history=True.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats in calibration mode + manager.set_calibration_mode(enabled=True) + 
manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Re-enable with reset + manager.set_calibration_mode(enabled=True, reset_history=True) + assert len(manager.per_sample_stats) == 0 # Should be cleared + + def test_set_calibration_mode_without_reset(self): + """Test set_calibration_mode with reset_history=False.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats + manager.set_calibration_mode(enabled=True) + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Disable without reset + manager.set_calibration_mode(enabled=False, reset_history=False) + assert len(manager.per_sample_stats) == 1 # Should be preserved + + +class TestGetSummary: + """Test get_summary() functionality.""" + + def test_get_summary_with_data(self): + """Test get_summary returns correct averages.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=True) + + # Collect stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 30}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + summary = manager.get_summary() + + assert summary["module"] == "test_module" + assert summary["total_calls"] == 2 + # Average sparsity: (30+50) / (100+100) = 80/200 = 0.4 + assert summary["average_sparsity"] == 0.4 + assert summary["phase_distribution"]["prefill"] == 2 + + def test_get_summary_no_data(self): + """Test get_summary with no collected data.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + summary = manager.get_summary() + + assert summary["module"] == "test" + assert summary["total_calls"] == 0 + assert summary["average_sparsity"] == 0.0 + assert summary["phase_distribution"]["prefill"] == 0 + + def test_get_summary_zero_blocks(self): + """Test get_summary when total_blocks is zero.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect stats with zero blocks + manager.collect({"phase": "prefill", "total_blocks": 0, "sparse_blocks": 0}) + + summary = manager.get_summary() + + assert summary["average_sparsity"] == 0.0 # Should handle division by zero + + +class TestGetCalibrationStats: + """Test get_calibration_stats() functionality.""" + + def test_get_calibration_stats(self): + """Test retrieving per-sample calibration stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect multiple samples + for i in range(3): + manager.collect( + { + "sparsity": 0.3 + i * 0.1, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 30, + "sample_length": 1024 + i * 512, + } + ) + + calib_stats = manager.get_calibration_stats() + + assert len(calib_stats) == 3 + assert calib_stats[0]["sparsity"] == 0.3 + assert calib_stats[1]["sparsity"] == 0.4 + assert calib_stats[2]["sparsity"] == 0.5 + + def test_get_calibration_stats_empty(self): + """Test get_calibration_stats when no calibration data.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + calib_stats = manager.get_calibration_stats() + + assert calib_stats == [] + + +class TestReset: + """Test reset functionality.""" + + def test_reset(self): + """Test reset() clears all statistics.""" + manager = 
SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect some stats + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + manager.collect( + { + "sparsity": 0.3, + "phase": "decode", + "total_blocks": 10, + "sparse_blocks": 3, + "sample_length": 128, + } + ) + + # Verify stats exist + assert manager.aggregated_stats["total_calls"] == 2 + assert len(manager.per_sample_stats) == 2 + + # Reset + manager.reset() + + # All stats should be cleared + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + assert manager.aggregated_stats["phase_counts"]["prefill"] == 0 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py new file mode 100644 index 000000000..ac9f46a54 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for threshold calibration functionality.""" + +import pytest + +pytest.importorskip("transformers") + +from _test_utils.torch_sparsity.sparse_attention_common import SimpleAttentionModel + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestFlashSkipSoftmaxThresholdInfo: + """Test FlashSkipSoftmax.get_threshold_info() method.""" + + def test_static_threshold(self): + """Test threshold info for static threshold.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + assert info["type"] == "static" + assert info["value"] == 0.001 + + def test_phased_threshold(self): + """Test threshold info for phase-specific thresholds.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + assert info["type"] == "static_phased" + assert "thresholds" in info + assert info["thresholds"]["prefill"] == 0.001 + assert info["thresholds"]["decode"] == 0.0001 + assert "current" in info + + def test_dynamic_calibrated_threshold(self): + """Test threshold info for calibrated dynamic threshold.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Simulate calibration setting scale factor + method.threshold_scale_factor = 437.5 + + info = method.get_threshold_info() + + assert info["type"] == "dynamic" + assert info["scale_factor"] == 437.5 + assert info["formula"] == "λ / length" + assert "example_lengths" in info + assert abs(info["example_lengths"][1024] - 437.5 / 1024) < 1e-6 + assert abs(info["example_lengths"][2048] - 437.5 / 2048) < 1e-6 + + def test_threshold_info_structure(self): + """Test that threshold info has expected structure.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + # Should always have 'type' key + assert "type" in info + assert isinstance(info, dict) + + +class TestSparseAttentionModuleThresholdInfo: + """Test SparseAttentionModule.get_threshold_info() delegation.""" + + def test_module_delegates_to_method(self): + """Test that module correctly delegates to sparse method instance.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.005, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find sparse attention module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + sparse_module = module + break + + assert sparse_module is not None + + # Test get_threshold_info + info = sparse_module.get_threshold_info() + + assert info["type"] == "static" + assert info["value"] == 0.005 + + def test_module_with_calibrated_threshold(self): + """Test module reports calibrated threshold correctly.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + 
"*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module and set calibrated threshold + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + module._sparse_method_instance.threshold_scale_factor = 500.0 + break + + # Get threshold info + info = module.get_threshold_info() + + assert info["type"] == "dynamic" + assert info["scale_factor"] == 500.0 + + def test_module_without_method_instance(self): + """Test get_threshold_info when sparse method instance doesn't exist.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Remove sparse method instance to test fallback + delattr(module, "_sparse_method_instance") + + info = module.get_threshold_info() + + assert info["type"] == "none" + assert info["value"] is None + break + + +class TestPrintSparseAttentionSummaryIntegration: + """Test integration with print_sparse_attention_summary.""" + + def test_summary_displays_static_threshold(self, capsys): + """Test that print function displays static thresholds.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + print_sparse_attention_summary, + ) + + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + print_sparse_attention_summary(sparse_model) + + captured = capsys.readouterr() + assert "Static (1.00e-03)" in captured.out + assert "flash_skip_softmax" in captured.out + + def test_summary_displays_dynamic_threshold(self, capsys): + """Test that print function displays dynamic thresholds.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + print_sparse_attention_summary, + ) + + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Set calibrated threshold + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + module._sparse_method_instance.threshold_scale_factor = 437.5 + + print_sparse_attention_summary(sparse_model) + + captured = capsys.readouterr() + assert "Dynamic (λ=437.500000)" in captured.out + assert "flash_skip_softmax" in captured.out From 6486db4ea2d52d4c6745eb66da2b214e2feaa76b Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 15 Dec 2025 07:41:10 +0000 Subject: [PATCH 2/9] Add sparse attention calibration for the decode phase Signed-off-by: Kai Xu --- .../calibration/calibrate.py | 207 +++++++++++++++--- .../calibration/calibrator.py | 82 +++---- .../sparsity/attention_sparsity/config.py | 53 ++++- .../sparsity/attention_sparsity/conversion.py | 5 +- .../methods/flash_skip_softmax.py | 41 ++-- .../attention_sparsity/methods/registry.py | 3 + .../attention_sparsity/sparse_attention.py | 5 + .../attention_sparsity/stats_manager.py | 14 +- 8 files changed, 305 insertions(+), 105 deletions(-) diff --git 
a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index 1b8f0e71b..bb6d511ab 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -79,6 +79,20 @@ def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | N return CalibrationConfig(**calib_dict) +def _parse_target_sparse_ratio( + target_sparse_ratio: dict[str, float], +) -> dict[str, float]: + """Parse target_sparse_ratio dict. + + Args: + target_sparse_ratio: Target sparsity ratio dict with 'prefill' and 'decode' keys + + Returns: + Dict with 'prefill' and 'decode' keys + """ + return target_sparse_ratio + + def create_calibration_forward_loop( calibration_data: list[dict[str, Any]], tokenizer_name_or_path: str, @@ -136,6 +150,73 @@ def forward_loop(model: nn.Module) -> None: return forward_loop +def create_decode_calibration_forward_loop( + calibration_data: list[dict[str, Any]], + tokenizer_name_or_path: str, + num_decode_tokens: int = 10, +) -> Callable: + """Create forward loop for decode phase calibration. + + Uses flash attention for fast prefill, then switches to eager attention + for decode token generation with softmax hook measurement. + + Args: + calibration_data: List of samples with 'input' and 'length' fields + tokenizer_name_or_path: HuggingFace tokenizer path + num_decode_tokens: Number of decode tokens to generate per sample + + Returns: + Forward loop function that takes model as argument + """ + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + def forward_loop(model: nn.Module) -> None: + device = next(model.parameters()).device + + for sample in calibration_data: + inputs = tokenizer( + sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"] + ) + input_ids = inputs["input_ids"].to(device) + + # Save original attention implementation + original_attn_impl = getattr(model.config, "_attn_implementation", "eager") + + with torch.no_grad(): + # Step 1: Fast prefill with flash attention (no measurement) + model.config._attn_implementation = "flash_attention_2" + outputs = model(input_ids, use_cache=True) + past_key_values = outputs.past_key_values + + # Step 2: Switch to eager for decode (enables softmax hook) + model.config._attn_implementation = "eager" + + # Step 3: Manual decode loop for explicit control over token generation + # model.generate() method is not used here because it doesn't allow explicit control over KV cache + # Get the last token's logits and sample next token + next_token = outputs.logits[:, -1:, :].argmax(dim=-1) + + for _ in range(num_decode_tokens): + outputs = model( + next_token, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + next_token = outputs.logits[:, -1:, :].argmax(dim=-1) + + # Restore original attention implementation + model.config._attn_implementation = original_attn_impl + + # Clean up + del past_key_values + torch.cuda.empty_cache() + + return forward_loop + + def calibrate_sparse_attention( model: nn.Module, config: dict[str, Any], @@ -143,14 +224,16 @@ def calibrate_sparse_attention( ) -> dict[str, Any]: """Calibrate sparse attention parameters for optimal sparsity. + Supports both prefill and decode phase calibration with per-phase target sparsity. 
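+
+    For example, ``target_sparse_ratio={"prefill": 0.5, "decode": 0.3}`` (illustrative
+    values) calibrates both phases, while a ratio of 0.0 or an omitted key skips
+    calibration for that phase.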
+ Args: model: Model with sparse attention modules config: Sparse attention configuration dict forward_loop: Callable that forwards calibration data through model. - If None, auto-generates RULER dataset. + If None, auto-generates RULER dataset. Only used for prefill. Returns: - Dictionary with calibration results + Dictionary with calibration results for each phase """ # Extract and validate calibration config calib_config = _extract_calibration_config(config) @@ -159,9 +242,32 @@ def calibrate_sparse_attention( if calib_config is None: return {} - # Generate forward_loop if not provided - if not forward_loop: - tokenizer = _extract_tokenizer_from_model(model) + # Parse target_sparse_ratio into per-phase targets + target_dict = _parse_target_sparse_ratio(calib_config.target_sparse_ratio) + calibrate_prefill = target_dict.get("prefill", 0.0) > 0.0 + calibrate_decode = target_dict.get("decode", 0.0) > 0.0 + + # Skip if both phases are disabled + if not calibrate_prefill and not calibrate_decode: + print("Both prefill and decode target sparsity are 0.0, skipping calibration") + return {} + + # Get sparse attention modules + sparse_modules = [ + (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) + ] + + if not sparse_modules: + print("No sparse attention modules found for calibration") + return {} + + print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") + + # Extract tokenizer and build calibration data if needed + tokenizer = _extract_tokenizer_from_model(model) + calibration_data = None + + if calibrate_prefill or calibrate_decode: builder = RulerDatasetBuilder( samples=calib_config.samples, max_seqlen=calib_config.max_seqlen, @@ -171,41 +277,78 @@ def calibrate_sparse_attention( ) calibration_data = builder.build_calibration_dataset() print(f"Generated {len(calibration_data)} calibration samples") - forward_loop = create_calibration_forward_loop( - calibration_data, tokenizer, chunk_size=calib_config.chunk_size - ) - # Get sparse attention modules - sparse_modules = [ - (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) - ] + # Initialize results + threshold_scale_factor: dict[str, float] = {} + calibration_results: dict[str, Any] = {} - if not sparse_modules: - print("No sparse attention modules found for calibration") - return {} + # Run prefill calibration if enabled + if calibrate_prefill: + print("\n" + "=" * 60) + print("PREFILL PHASE CALIBRATION") + print("=" * 60) - print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") + assert calibration_data is not None, "calibration_data must be built before prefill" + prefill_forward_loop = forward_loop or create_calibration_forward_loop( + calibration_data, tokenizer, chunk_size=calib_config.chunk_size + ) - # Run calibration - calibrator = DynamicThresholdCalibrator( - target_sparse_ratio=calib_config.target_sparse_ratio, - threshold_trials=calib_config.threshold_trials, - ) - calibration_result = calibrator.calibrate(model, forward_loop) + prefill_calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=target_dict, + threshold_trials=calib_config.threshold_trials, + ) + prefill_result = prefill_calibrator.calibrate(model, prefill_forward_loop, phase="prefill") + + if "scale_factor" in prefill_result: + threshold_scale_factor["prefill"] = prefill_result["scale_factor"] + calibration_results["prefill"] = prefill_result + else: + warnings.warn("Prefill calibration did not produce valid results") + + # Run decode 
calibration if enabled + if calibrate_decode: + print("\n" + "=" * 60) + print("DECODE PHASE CALIBRATION") + print("=" * 60) + + assert calibration_data is not None, "calibration_data must be built before decode" + decode_forward_loop = create_decode_calibration_forward_loop( + calibration_data, tokenizer, num_decode_tokens=calib_config.num_decode_tokens + ) - # Print calibration statistics (regardless of success/failure for debugging) - print("\nCalibration complete!") - print_sparse_attention_summary(model) + decode_calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=target_dict, + threshold_trials=calib_config.threshold_trials, + ) + decode_result = decode_calibrator.calibrate(model, decode_forward_loop, phase="decode") - if "scale_factor" not in calibration_result: - warnings.warn("Calibration did not produce valid results") + if "scale_factor" in decode_result: + threshold_scale_factor["decode"] = decode_result["scale_factor"] + calibration_results["decode"] = decode_result + else: + warnings.warn("Decode calibration did not produce valid results") + + # Check if any calibration succeeded + if not threshold_scale_factor: + warnings.warn("No calibration produced valid results") return {} - # Apply calibrated scale factor to all modules - scale_factor = calibration_result["scale_factor"] - print(f"\nApplying calibrated scale factor={scale_factor:.6f} to {len(sparse_modules)} modules") + # Apply combined threshold_scale_factor dict to all modules + print("\n" + "=" * 60) + print("APPLYING CALIBRATION RESULTS") + print("=" * 60) + print(f"Applying threshold_scale_factor to {len(sparse_modules)} modules:") + for phase, scale_factor in threshold_scale_factor.items(): + print(f" {phase}: {scale_factor:.6f}") for module_name, module in sparse_modules: - module._sparse_method_instance.threshold_scale_factor = scale_factor + module._sparse_method_instance.threshold_scale_factor = threshold_scale_factor + + # Print final summary + print("\nCalibration complete!") + print_sparse_attention_summary(model) - return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}} + return { + "threshold_scale_factor": threshold_scale_factor, + "calibration_results": calibration_results, + } diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index 39abdbca8..60e82ff4b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -48,20 +48,18 @@ class SampleSparsity: def __init__( self, - target_sparse_ratio: float = 0.5, + target_sparse_ratio: dict[str, float] | None = None, threshold_trials: list[float] | None = None, ): """Initialize dynamic threshold calibrator. Args: - target_sparse_ratio: Target sparsity ratio (0.0 to 1.0) + target_sparse_ratio: Target sparsity ratio dict with 'prefill' and 'decode' keys. + Each value should be in range (0.0 to 1.0). Set to 0.0 to skip that phase. threshold_trials: List of thresholds to try during calibration - - Note: - Calibration only supports prefill phase (seq_len > 1). - Decode phase uses the same calibrated threshold. 
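Editor's note: a small numeric sketch of the length-based threshold these scale factors feed, matching the example printout above (λ = scale_factor[phase] / sequence_length). The calibrated values shown are made up.

```python
# Hypothetical calibrated scale factors; real values come out of DynamicThresholdCalibrator.
threshold_scale_factor = {"prefill": 437.5, "decode": 500.0}


def dynamic_threshold(phase: str, seq_len: int) -> float:
    # λ = scale_factor[phase] / length; phases without a calibrated factor fall back
    # to the static threshold from the config.
    return threshold_scale_factor[phase] / seq_len


for length in (1024, 4096, 16384):
    print(f"prefill len={length:5d}: λ = {dynamic_threshold('prefill', length):.2e}")
# prefill len= 1024: λ = 4.27e-01
# prefill len= 4096: λ = 1.07e-01
# prefill len=16384: λ = 2.67e-02
```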
""" self.target_sparse_ratio = target_sparse_ratio + self._target_sparse_ratio_dict = self.target_sparse_ratio # Default threshold trials if not provided self.threshold_trials = threshold_trials or [ @@ -82,7 +80,7 @@ def __init__( # Statistics tracking self.sparsity_results = [] - def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: + def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dict[str, Any]: """Find optimal 'a' parameter for length-based threshold. Algorithm: @@ -93,28 +91,39 @@ def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: Args: model: The model with sparse attention modules forward_loop: Callable that takes model and forwards calibration data + phase: Phase to calibrate ('prefill' or 'decode') + + Returns: + Dict with calibration results including scale_factor, or empty dict if failed """ + assert self._target_sparse_ratio_dict is not None, "target_sparse_ratio must be provided" + target_sparsity = self._target_sparse_ratio_dict[phase] + # Extract attention modules attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] if not attention_modules: raise ValueError("No sparse attention modules found for calibration") - print("Starting dynamic threshold calibration") - print(f"Target sparsity: {self.target_sparse_ratio}") + print(f"Starting dynamic threshold calibration ({phase} phase)") + print(f"Target sparsity: {target_sparsity}") print(f"Threshold trials: {len(self.threshold_trials)}") # Stage 1: Collect sparsity for all sample-threshold pairs - print("\nStage 1: Collecting sparsity data...") + print(f"\nStage 1: Collecting {phase} sparsity data...") # Run first threshold to discover samples and initialize results self._set_threshold(attention_modules, self.threshold_trials[0]) self._enable_calibration_mode(attention_modules) with torch.no_grad(): forward_loop(model) - per_sample_stats = self._extract_calibration_stats(attention_modules) + per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase) self._disable_calibration_mode(attention_modules) + if not per_sample_stats: + warnings.warn(f"No {phase} phase statistics collected. 
Check forward loop.") + return {} + # Initialize sparsity_results with sample info self.sparsity_results = [ self.SampleSparsity( @@ -125,29 +134,28 @@ def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: ] # Collect remaining thresholds - for threshold in tqdm(self.threshold_trials[1:], desc="Testing thresholds"): + for threshold in tqdm(self.threshold_trials[1:], desc=f"Testing thresholds ({phase})"): self._set_threshold(attention_modules, threshold) self._enable_calibration_mode(attention_modules) with torch.no_grad(): forward_loop(model) - per_sample_stats = self._extract_calibration_stats(attention_modules) + per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase) self._disable_calibration_mode(attention_modules) for sample_idx, sample_stat in enumerate(per_sample_stats): - self.sparsity_results[sample_idx].threshold_sparsities[threshold] = sample_stat[ - "sparsity" - ] + if sample_idx < len(self.sparsity_results): + self.sparsity_results[sample_idx].threshold_sparsities[threshold] = sample_stat[ + "sparsity" + ] if not self.sparsity_results: - warnings.warn("No valid sparsity measurements collected during calibration") + warnings.warn(f"No valid {phase} sparsity measurements collected during calibration") return {} print(f"Collected statistics for {len(self.sparsity_results)} samples") # Stage 2: Find optimal threshold for each sample and compute 'a' - print( - f"\nStage 2: Finding 'a' parameter for target sparsity {self.target_sparse_ratio:.2f}" - ) + print(f"\nStage 2: Finding 'a' parameter for target sparsity {target_sparsity:.2f}") # Find optimal threshold for each sample optimal_pairs = [] @@ -155,7 +163,7 @@ def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: # Find threshold closest to target sparsity best_threshold, achieved_sparsity = min( sample_result.threshold_sparsities.items(), - key=lambda item: abs(item[1] - self.target_sparse_ratio), + key=lambda item: abs(item[1] - target_sparsity), ) optimal_pairs.append( @@ -163,13 +171,13 @@ def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: "length": sample_result.length, "optimal_threshold": best_threshold, "achieved_sparsity": achieved_sparsity, - "target_sparsity": self.target_sparse_ratio, + "target_sparsity": target_sparsity, } ) if not optimal_pairs: warnings.warn( - f"No optimal threshold pairs found for target sparsity {self.target_sparse_ratio}. " + f"No optimal threshold pairs found for {phase} target sparsity {target_sparsity}. " f"Collected {len(self.sparsity_results)} samples but none achieved target sparsity." 
) return {} @@ -198,40 +206,28 @@ def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: # Calculate average achieved sparsity avg_achieved_sparsity = np.mean([p["achieved_sparsity"] for p in optimal_pairs]) - print("\nCalibration Results:") + print(f"\n{phase.capitalize()} Calibration Results:") print(f" Threshold scale factor: {scale_factor:.6f} (std: {scale_factor_std:.6f})") print(f" R-squared: {r_squared:.4f}") print( - f" Average achieved sparsity: {avg_achieved_sparsity:.2%} (target: {self.target_sparse_ratio:.2%})" + f" Average achieved sparsity: {avg_achieved_sparsity:.2%} (target: {target_sparsity:.2%})" ) print(f"\nExample thresholds with λ = {scale_factor:.6f} / length:") for length in [1024, 2048, 4096, 8192, 16384]: print(f" Length {length:5d}: threshold = {scale_factor / length:.2e}") - # Apply the calibrated scale factor to modules - self._apply_length_based_calibration(attention_modules, scale_factor) - return { + "phase": phase, "scale_factor": scale_factor, "scale_factor_std": scale_factor_std, "r_squared": r_squared, "num_samples": len(optimal_pairs), - "target_sparsity": self.target_sparse_ratio, + "target_sparsity": target_sparsity, "avg_achieved_sparsity": avg_achieved_sparsity, "optimal_pairs": optimal_pairs, "calibration_type": "length_based_dynamic", } - def _apply_length_based_calibration(self, modules: list[nn.Module], scale_factor: float): - """Apply calibrated threshold scale factor to modules. - - Args: - modules: List of attention modules - scale_factor: Calibrated scale factor for λ = scale_factor / length - """ - for module in modules: - module._sparse_method_instance.threshold_scale_factor = scale_factor - def _enable_calibration_mode(self, modules: list[nn.Module]): """Enable calibration mode on sparse attention modules.""" for idx, module in enumerate(modules): @@ -256,11 +252,15 @@ def _disable_calibration_mode(self, modules: list[nn.Module]): module._sparse_method_instance.set_calibration_mode(False) - def _extract_calibration_stats(self, modules: list[nn.Module]) -> list[dict]: + def _extract_calibration_stats( + self, modules: list[nn.Module], phase: str | None = None + ) -> list[dict]: """Extract per-sample calibration statistics from modules. Args: modules: List of attention modules + phase: Optional phase to filter by ('prefill' or 'decode'). + If None, returns all stats. Returns: List of per-sample statistics across all modules @@ -273,7 +273,7 @@ def _extract_calibration_stats(self, modules: list[nn.Module]) -> list[dict]: if not hasattr(module, "_stats_manager") or module._stats_manager is None: continue - manager_stats = module._stats_manager.get_calibration_stats() + manager_stats = module._stats_manager.get_calibration_stats(phase) if manager_stats: all_per_sample_stats.append(manager_stats) diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 8271dd4a2..f98c51d34 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -150,13 +150,18 @@ class CalibrationConfig(ModeloptBaseConfig): """Configuration for automatic threshold calibration using RULER dataset. Calibration learns a dynamic threshold λ = scale_factor / sequence_length that - achieves target sparsity. Only supports prefill phase (seq_len > 1). + achieves target sparsity. Supports both prefill and decode phases with per-phase + target sparsity ratios. 
""" - target_sparse_ratio: float = ModeloptField( - default=0.5, + target_sparse_ratio: dict[str, float] = ModeloptField( + default={"prefill": 0.5, "decode": 0.5}, title="Target sparsity ratio", - description="Target ratio of sparse attention blocks (0.0 to 1.0).", + description=( + "Target ratio of sparse attention blocks (0.0 to 1.0). " + "Dict with 'prefill' and 'decode' keys for per-phase targets. " + "Set a phase value to 0.0 to skip calibration for that phase." + ), ) samples: int = ModeloptField( @@ -187,6 +192,12 @@ class CalibrationConfig(ModeloptBaseConfig): ), ) + num_decode_tokens: int = ModeloptField( + default=10, + title="Number of decode tokens", + description="Number of decode tokens to generate for decode phase calibration.", + ) + threshold_trials: list[float] | None = ModeloptField( default=None, title="Threshold trials", @@ -217,9 +228,24 @@ def validate_threshold_trials(cls, v): @field_validator("target_sparse_ratio") @classmethod def validate_target_sparse_ratio(cls, v): - """Validate target sparsity ratio is between 0 and 1.""" - if not 0.0 <= v <= 1.0: - raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") + """Validate target sparsity ratio dict.""" + if not isinstance(v, dict): + raise ValueError( + f"target_sparse_ratio must be a dict with 'prefill' and 'decode' keys, got {type(v)}" + ) + # Validate phase keys + valid_phases = {"prefill", "decode"} + invalid_keys = set(v.keys()) - valid_phases + if invalid_keys: + raise ValueError( + f"Invalid target_sparse_ratio phases: {invalid_keys}. Valid phases: {valid_phases}" + ) + # Validate all values are in range [0, 1] + for phase, ratio in v.items(): + if not isinstance(ratio, (int, float)) or not 0.0 <= ratio <= 1.0: + raise ValueError( + f"target_sparse_ratio for phase '{phase}' must be between 0.0 and 1.0, got {ratio}" + ) return v @field_validator("samples") @@ -254,6 +280,14 @@ def validate_chunk_size(cls, v): raise ValueError(f"chunk_size must be positive or -1 (disabled), got {v}") return v + @field_validator("num_decode_tokens") + @classmethod + def validate_num_decode_tokens(cls, v): + """Validate num_decode_tokens is positive.""" + if v <= 0: + raise ValueError(f"num_decode_tokens must be positive, got {v}") + return v + class SparseAttentionConfig(ModeloptBaseConfig): """Base configuration for sparse attention optimization. 
@@ -333,8 +367,9 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): SKIP_SOFTMAX_CALIB = { "sparse_cfg": { "calibration": { - "target_sparse_ratio": 0.75, - "samples": 16, + # "target_sparse_ratio": {"prefill": 0.75, "decode": 0.75}, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, + "samples": 64, "max_seqlen": 16384, }, "*attn*": { diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index aa3eb7c29..0d5854532 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -307,7 +307,10 @@ def _format_threshold(info: dict) -> str: """Format threshold info for display.""" t = info.get("type") if t == "dynamic": - return f"λ={info.get('scale_factor', 0):.2f}" + # Per-phase calibrated threshold: λ = scale_factor[phase] / length + scale_factors = info.get("scale_factors", {}) + parts = [f"{phase}={sf:.2f}" for phase, sf in scale_factors.items()] + return f"λ={{{', '.join(parts)}}}" if t == "static": v = info.get("value") if isinstance(v, dict): diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 0e4512c98..cd090ac65 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -141,11 +141,11 @@ def calc_correction_factor_and_p( # Calculate threshold threshold_scale_factor = getattr(self, "threshold_scale_factor", None) - if threshold_scale_factor: - # Use calibrated dynamic threshold: λ = scale_factor / length - log_threshold = np.log(threshold_scale_factor / seq_k) + if threshold_scale_factor is not None and phase in threshold_scale_factor: + # Per-phase calibrated threshold: λ = scale_factor[phase] / length + log_threshold = np.log(threshold_scale_factor[phase] / seq_k) else: - # Use static threshold from config + # Use static threshold from config (no calibration or phase not calibrated) log_threshold = np.log(self.threshold) if phase == "prefill": @@ -191,15 +191,16 @@ def calc_correction_factor_and_p( element_mask = element_mask[:, :, :seq_q, :seq_k] # Step 8: Calculate sparsity statistics - # density = sum(mask) / numel(mask) * N / (N+1) for causal if self.is_causal: - density = float(block_mask.sum() / block_mask.numel()) * ( - num_block_rows / (num_block_rows + 1) - ) + # For causal attention, only count lower triangle blocks (including diagonal) + num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2 + total_valid_blocks = batch_size * num_heads * num_causal_blocks + density = float(block_mask.sum()) / total_valid_blocks + total_blocks = num_causal_blocks else: density = float(block_mask.sum() / block_mask.numel()) + total_blocks = num_block_rows * num_block_cols sparsity = 1 - density - total_blocks = num_block_rows * num_block_cols else: # decode blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( attn_weights, 1, self.bc @@ -315,17 +316,21 @@ def get_threshold_info(self) -> dict[str, Any]: threshold_scale_factor = getattr(self, "threshold_scale_factor", None) if threshold_scale_factor is not None: - # Calibrated dynamic threshold + # Per-phase calibrated dynamic threshold + example_lengths = [1024, 2048, 4096, 8192] + phase_info = {} + for phase, scale_factor in threshold_scale_factor.items(): + phase_info[phase] = { + "scale_factor": scale_factor, 
+ "example_thresholds": { + length: scale_factor / length for length in example_lengths + }, + } return { "type": "dynamic", - "scale_factor": threshold_scale_factor, - "formula": "λ / length", - "example_lengths": { - 1024: threshold_scale_factor / 1024, - 2048: threshold_scale_factor / 2048, - 4096: threshold_scale_factor / 4096, - 8192: threshold_scale_factor / 8192, - }, + "scale_factors": threshold_scale_factor, + "formula": "λ[phase] / length", + "phases": phase_info, } else: # Static threshold (single value or phase-specific dict) diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 996095127..bea775592 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -26,6 +26,9 @@ class SparseAttentionMethod(ABC): """Base class for sparse attention methods.""" + # Flag to indicate calibration mode (set by calibrator) + _calibration_mode: bool = False + @abstractmethod def calculate_sparsity( self, diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 29a73218b..281e11e7d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -208,6 +208,11 @@ def sparse_softmax(input, dim=-1, *args, **kwargs): # Store stats for collection self._last_stats = stats + # Only apply sparsity mask after calibration (not during calibration) + # During calibration, we measure sparsity without modifying the output + if not self._sparse_method_instance._calibration_mode: + input = self._sparse_method_instance.apply_sparsity(input, sparse_mask) + return original_softmax(input, dim, *args, **kwargs) return sparse_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py index 9fc57a0b1..9f37df61f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -127,11 +127,17 @@ def reset(self): } self.per_sample_stats = [] - def get_calibration_stats(self) -> list[dict]: - """Get per-sample calibration statistics. + def get_calibration_stats(self, phase: str | None = None) -> list[dict]: + """Get per-sample calibration statistics, optionally filtered by phase. + + Args: + phase: Optional phase to filter by ('prefill' or 'decode'). + If None, returns all stats. Returns: List of per-sample statistics dictionaries. - Empty list if not in calibration mode. + Empty list if not in calibration mode or no stats for that phase. 
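Editor's note: a sketch of the per-sample records the stats manager accumulates during calibration and the phase filter that `get_calibration_stats` applies; field names follow the unit tests in this patch.

```python
# Records collected per forward pass while calibration mode is enabled.
per_sample_stats = [
    {"sparsity": 0.30, "sample_length": 1024, "phase": "prefill"},
    {"sparsity": 0.45, "sample_length": 2048, "phase": "prefill"},
    {"sparsity": 0.50, "sample_length": 100, "phase": "decode"},
]


def get_calibration_stats(phase: str | None = None) -> list[dict]:
    # Mirrors SparsityStatsManager.get_calibration_stats: optional phase filter.
    if phase is None:
        return per_sample_stats
    return [s for s in per_sample_stats if s.get("phase") == phase]


assert len(get_calibration_stats("prefill")) == 2
assert len(get_calibration_stats("decode")) == 1
```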
""" - return self.per_sample_stats + if phase is None: + return self.per_sample_stats + return [s for s in self.per_sample_stats if s.get("phase") == phase] From 5fd37f05cf55ffbd3449419837d7165d15d597fe Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 15 Dec 2025 07:47:13 +0000 Subject: [PATCH 3/9] Add hf unified checkpoint export for sparse attention Signed-off-by: Kai Xu --- .vscode/settings.json | 3 - modelopt/torch/export/unified_export_hf.py | 11 + .../calibration/calibrate.py | 67 ++--- .../calibration/calibrator.py | 6 +- .../attention_sparsity/calibration/dataset.py | 5 +- .../calibration/ruler_utils.py | 5 +- .../sparsity/attention_sparsity/conversion.py | 25 +- .../methods/flash_skip_softmax.py | 4 +- .../attention_sparsity/methods/registry.py | 7 +- .../attention_sparsity/model_sparsify.py | 2 +- .../test_calibration_gpu.py | 14 +- .../test_sparse_attention_calibration.py | 276 +++++++++++++++--- .../test_sparse_attention_conversion.py | 30 +- .../attention_sparsity/test_threshold_info.py | 55 ++-- 14 files changed, 352 insertions(+), 158 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 1cff4a791..0e8465ad3 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -45,7 +45,4 @@ ], "git.alwaysSignOff": true, "git.enableCommitSigning": true, - "cursorpyright.analysis.extraPaths": [ - "./tests/" - ], } diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 011af533d..e5c77bb77 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -54,6 +54,11 @@ from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names +try: + from modelopt.torch.sparsity.attention_sparsity.conversion import export_sparse_attention_config +except ImportError: + export_sparse_attention_config = None + from .convert_hf_config import convert_hf_quant_config_format from .layer_utils import ( get_expert_linear_names, @@ -991,6 +996,12 @@ def export_hf_checkpoint( if hf_quant_config is not None: config_data["quantization_config"] = hf_quant_config + # Add sparse attention config if available + if export_sparse_attention_config is not None: + sparse_attn_config = export_sparse_attention_config(model) + if sparse_attn_config is not None: + config_data["sparse_attention_config"] = sparse_attn_config + with open(original_config, "w") as file: json.dump(config_data, file, indent=4) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index bb6d511ab..515985244 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -79,20 +79,6 @@ def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | N return CalibrationConfig(**calib_dict) -def _parse_target_sparse_ratio( - target_sparse_ratio: dict[str, float], -) -> dict[str, float]: - """Parse target_sparse_ratio dict. 
- - Args: - target_sparse_ratio: Target sparsity ratio dict with 'prefill' and 'decode' keys - - Returns: - Dict with 'prefill' and 'decode' keys - """ - return target_sparse_ratio - - def create_calibration_forward_loop( calibration_data: list[dict[str, Any]], tokenizer_name_or_path: str, @@ -185,30 +171,31 @@ def forward_loop(model: nn.Module) -> None: original_attn_impl = getattr(model.config, "_attn_implementation", "eager") with torch.no_grad(): - # Step 1: Fast prefill with flash attention (no measurement) - model.config._attn_implementation = "flash_attention_2" - outputs = model(input_ids, use_cache=True) - past_key_values = outputs.past_key_values - - # Step 2: Switch to eager for decode (enables softmax hook) - model.config._attn_implementation = "eager" - - # Step 3: Manual decode loop for explicit control over token generation - # model.generate() method is not used here because it doesn't allow explicit control over KV cache - # Get the last token's logits and sample next token - next_token = outputs.logits[:, -1:, :].argmax(dim=-1) - - for _ in range(num_decode_tokens): - outputs = model( - next_token, - past_key_values=past_key_values, - use_cache=True, - ) + try: + # Step 1: Fast prefill with flash attention (no measurement) + model.config._attn_implementation = "flash_attention_2" + outputs = model(input_ids, use_cache=True) past_key_values = outputs.past_key_values + + # Step 2: Switch to eager for decode (enables softmax hook) + model.config._attn_implementation = "eager" + + # Step 3: Manual decode loop for explicit control over token generation + # model.generate() method is not used here because it doesn't allow explicit control over KV cache + # Get the last token's logits and sample next token next_token = outputs.logits[:, -1:, :].argmax(dim=-1) - # Restore original attention implementation - model.config._attn_implementation = original_attn_impl + for _ in range(num_decode_tokens): + outputs = model( + next_token, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + next_token = outputs.logits[:, -1:, :].argmax(dim=-1) + finally: + # Restore original attention implementation + model.config._attn_implementation = original_attn_impl # Clean up del past_key_values @@ -242,8 +229,8 @@ def calibrate_sparse_attention( if calib_config is None: return {} - # Parse target_sparse_ratio into per-phase targets - target_dict = _parse_target_sparse_ratio(calib_config.target_sparse_ratio) + # Get per-phase targets + target_dict = calib_config.target_sparse_ratio calibrate_prefill = target_dict.get("prefill", 0.0) > 0.0 calibrate_decode = target_dict.get("decode", 0.0) > 0.0 @@ -288,7 +275,8 @@ def calibrate_sparse_attention( print("PREFILL PHASE CALIBRATION") print("=" * 60) - assert calibration_data is not None, "calibration_data must be built before prefill" + if calibration_data is None: + raise RuntimeError("calibration_data must be built before prefill") prefill_forward_loop = forward_loop or create_calibration_forward_loop( calibration_data, tokenizer, chunk_size=calib_config.chunk_size ) @@ -311,7 +299,8 @@ def calibrate_sparse_attention( print("DECODE PHASE CALIBRATION") print("=" * 60) - assert calibration_data is not None, "calibration_data must be built before decode" + if calibration_data is None: + raise RuntimeError("calibration_data must be built before decode") decode_forward_loop = create_decode_calibration_forward_loop( calibration_data, tokenizer, num_decode_tokens=calib_config.num_decode_tokens ) diff --git 
a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index 60e82ff4b..07aae4b70 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -59,7 +59,6 @@ def __init__( threshold_trials: List of thresholds to try during calibration """ self.target_sparse_ratio = target_sparse_ratio - self._target_sparse_ratio_dict = self.target_sparse_ratio # Default threshold trials if not provided self.threshold_trials = threshold_trials or [ @@ -96,8 +95,9 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic Returns: Dict with calibration results including scale_factor, or empty dict if failed """ - assert self._target_sparse_ratio_dict is not None, "target_sparse_ratio must be provided" - target_sparsity = self._target_sparse_ratio_dict[phase] + if self.target_sparse_ratio is None: + raise RuntimeError("target_sparse_ratio must be provided") + target_sparsity = self.target_sparse_ratio[phase] # Extract attention modules attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py index 7603b4e1d..74a4f3aa3 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -435,11 +435,12 @@ def _generate_qa_sample( # Create a simple QA pair answer = self._generate_random_phrase() - question = f"What is the special code mentioned in document {random.randint(1, num_docs)}?" + answer_doc_idx = random.randint(0, num_docs - 1) + question = f"What is the special code mentioned in document {answer_doc_idx + 1}?" for i in range(num_docs): doc_text = self._generate_document_text(200) # Base document - if i == 2: # Insert answer in one document + if i == answer_doc_idx: # Insert answer in the correct document doc_text += f" The special code is {answer}. " documents.append(f"Document {i + 1}:\n{doc_text}\n") diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py index 70d4da81b..9de75c02a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -331,7 +331,7 @@ def generate_niah_sample( context = " ".join(document_sents_list) - if type_haystack == "noise": + elif type_haystack == "noise": haystack_sent = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 
sentences = [haystack_sent] * num_haystack indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) @@ -354,6 +354,9 @@ def generate_niah_sample( sentences.insert(index, element) context = "\n".join(sentences) + else: + raise ValueError(f"Unknown haystack type: {type_haystack}") + # Generate query and answer indices = random.sample(range(num_needle_k), num_needle_q) queries = [keys[i] for i in indices] diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 0d5854532..bc90406ed 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -226,8 +226,6 @@ def update_sparse_attention_metadata( if isinstance(module, SparseAttentionModule): module_name = get_unwrapped_name(name, model) - # Save the method configuration that was used - # _method_config already contains the validated config dict # Save the method configuration that was used # _method_config already contains the validated config dict module_state = { @@ -243,6 +241,29 @@ def update_sparse_attention_metadata( ) +def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: + """Extract sparse attention config for export to config.json. + + Extracts the global threshold_scale_factor from the first sparse attention + module that has calibrated thresholds. + + Args: + model: Model with sparse attention applied + + Returns: + Dictionary with sparse attention config, or None if no calibrated config found. + Format: {"threshold_scale_factor": {"prefill": float, "decode": float}} + """ + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + threshold_scale_factor = getattr( + module._sparse_method_instance, "threshold_scale_factor", None + ) + if threshold_scale_factor is not None: + return {"threshold_scale_factor": threshold_scale_factor} + return None + + def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): """Disable sparse attention for matching modules. diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index cd090ac65..ec006e58a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -43,6 +43,7 @@ def __init__(self, method_config: dict | None = None): method_config: Configuration dict with threshold, br, bc, is_causal, etc. All required fields should have defaults from SparseAttentionAttributeConfig. 
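Editor's note: a sketch of how a `FlashSkipSoftmax` instance reports its threshold before and after calibration, based on the `get_threshold_info` structure and constructor keys in this patch; values are illustrative.

```python
from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax

method = FlashSkipSoftmax(
    method_config={"threshold": 0.001, "br": 64, "bc": 64, "is_causal": True}
)
print(method.get_threshold_info()["type"])  # "static" -> uses the configured threshold directly

# After calibration, per-phase scale factors are attached and the threshold becomes dynamic.
method.threshold_scale_factor = {"prefill": 437.5, "decode": 500.0}
info = method.get_threshold_info()
assert info["type"] == "dynamic"
assert info["phases"]["prefill"]["example_thresholds"][1024] == 437.5 / 1024
```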
""" + super().__init__() config = method_config or {} # Extract configuration @@ -55,9 +56,6 @@ def __init__(self, method_config: dict | None = None): # Optional parameters not in Pydantic config self.phase = config.get("phase", None) - # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold - self._calibration_mode = False - # Initialize threshold if isinstance(self.threshold_config, dict): self.threshold = self.threshold_config.get( diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index bea775592..14ee4ce61 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -26,8 +26,11 @@ class SparseAttentionMethod(ABC): """Base class for sparse attention methods.""" - # Flag to indicate calibration mode (set by calibrator) - _calibration_mode: bool = False + def __init__(self): + """Initialize base sparse attention method.""" + # Flag to indicate calibration mode (set by calibrator) + # Instance attribute to prevent shared state across multiple models + self._calibration_mode: bool = False @abstractmethod def calculate_sparsity( diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py index b6b1e809f..b79e25bd8 100644 --- a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -136,7 +136,7 @@ def forward_loop(model) -> float: from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_pretrained(b + model = AutoModelForCausalLM.from_pretrained( model_path, attn_implementation="eager", # Required for sparse attention torch_dtype=torch.bfloat16, diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py index 913dc24a0..762bcc1d0 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -124,7 +124,7 @@ def test_calibration_simple_model(self, simple_model): "backend": "pytorch", "enable": True, "calibration": { - "target_sparse_ratio": 0.5, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, "samples": 4, "max_seqlen": 1024, }, @@ -162,7 +162,7 @@ def test_calibration_pytorch_backend(self, simple_model): "backend": "pytorch", "enable": True, "calibration": { - "target_sparse_ratio": 0.5, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, "samples": 2, "max_seqlen": 1024, }, @@ -193,7 +193,7 @@ def test_simplified_calibration(self, simple_model): "threshold": 1e-3, "enable": True, "calibration": { - "target_sparse_ratio": 0.5, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, "samples": 4, "max_seqlen": 1024, }, @@ -220,7 +220,7 @@ def test_calibration_persistence(self, simple_model): "threshold": 1e-3, "enable": True, "calibration": { - "target_sparse_ratio": 0.5, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, "samples": 2, "max_seqlen": 1024, }, @@ -268,7 +268,7 @@ def test_calibrated_model_inference(self, simple_model_setup): "backend": "pytorch", "enable": True, "calibration": { - "target_sparse_ratio": 0.5, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, "samples": 2, "max_seqlen": 1024, }, @@ -303,7 +303,7 @@ def test_calibrated_vs_fixed_threshold(self, 
simple_model_setup): "threshold": 1e-3, "enable": True, "calibration": { - "target_sparse_ratio": 0.5, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, "samples": 2, "max_seqlen": 1024, }, @@ -362,7 +362,7 @@ def test_memory_usage(self, simple_model_setup): "threshold": 1e-3, "enable": True, "calibration": { - "target_sparse_ratio": 0.5, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, "samples": 2, "max_seqlen": 1024, }, diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py index 4558ca22b..c44a45bbe 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -203,30 +203,7 @@ def test_uneven_sample_distribution(self): class TestDynamicThresholdCalibrator: - """Test calibration algorithm correctness.""" - - def test_calibrator_initialization(self): - """Test that calibrator initializes correctly.""" - calibrator = DynamicThresholdCalibrator( - target_sparse_ratio=0.5, - threshold_trials=[1e-4, 1e-3, 1e-2], - ) - - assert calibrator.target_sparse_ratio == 0.5 - assert len(calibrator.threshold_trials) == 3 - - def test_calibrator_default_threshold_trials(self): - """Test that calibrator has default threshold trials.""" - calibrator = DynamicThresholdCalibrator( - target_sparse_ratio=0.5, - ) - - # Should have default threshold trials - assert calibrator.threshold_trials is not None - assert len(calibrator.threshold_trials) == 12 - # Check they are positive and in valid range - trials = calibrator.threshold_trials - assert all(0 < t < 1 for t in trials) + """Test calibration algorithm correctness (regression calculations).""" def test_regression_calculation_synthetic(self): """Test 'a' parameter calculation with synthetic data.""" @@ -344,7 +321,7 @@ def test_sparsify_with_calibration_requires_forward_loop(self): config = { "sparse_cfg": { "calibration": { - "target_sparse_ratio": 0.5, + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, "samples": 4, "max_seqlen": 1024, }, @@ -384,41 +361,49 @@ def test_calibration_config_validation(self): """Test CalibrationConfig validation.""" # Valid config config = CalibrationConfig( - target_sparse_ratio=0.5, + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, samples=48, max_seqlen=32768, ) - assert config.target_sparse_ratio == 0.5 + assert config.target_sparse_ratio == {"prefill": 0.5, "decode": 0.5} assert config.samples == 48 assert config.max_seqlen == 32768 # Invalid target_sparse_ratio (> 1.0) - with pytest.raises(ValueError, match="target_sparse_ratio must be between"): - CalibrationConfig(target_sparse_ratio=1.5, samples=48, max_seqlen=32768) + with pytest.raises(ValueError, match="target_sparse_ratio.*must be between 0.0 and 1.0"): + CalibrationConfig( + target_sparse_ratio={"prefill": 1.5, "decode": 0.5}, samples=48, max_seqlen=32768 + ) # Invalid target_sparse_ratio (< 0.0) - with pytest.raises(ValueError, match="target_sparse_ratio must be between"): - CalibrationConfig(target_sparse_ratio=-0.1, samples=48, max_seqlen=32768) + with pytest.raises(ValueError, match="target_sparse_ratio.*must be between 0.0 and 1.0"): + CalibrationConfig( + target_sparse_ratio={"prefill": -0.1, "decode": 0.5}, samples=48, max_seqlen=32768 + ) # Invalid samples with pytest.raises(ValueError, match="samples must be positive"): - CalibrationConfig(target_sparse_ratio=0.5, samples=0, 
max_seqlen=32768) + CalibrationConfig( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, samples=0, max_seqlen=32768 + ) # Invalid max_seqlen with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): - CalibrationConfig(target_sparse_ratio=0.5, samples=48, max_seqlen=512) + CalibrationConfig( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, samples=48, max_seqlen=512 + ) def test_threshold_trials_validation(self): """Test threshold_trials validation.""" # Valid custom threshold_trials config = CalibrationConfig( - target_sparse_ratio=0.5, + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, threshold_trials=[1e-5, 1e-4, 1e-3, 1e-2], ) assert config.threshold_trials == [1e-5, 1e-4, 1e-3, 1e-2] # None (use defaults) - config_default = CalibrationConfig(target_sparse_ratio=0.5) + config_default = CalibrationConfig(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) assert config_default.threshold_trials is None # Invalid: empty list @@ -462,7 +447,7 @@ def test_set_threshold(self): assert len(modules) > 0 # Create calibrator and set threshold - calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) calibrator._set_threshold(modules, 0.05) # Verify threshold was set @@ -487,7 +472,7 @@ def test_enable_disable_calibration_mode(self): modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] - calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) # Enable calibration mode calibrator._enable_calibration_mode(modules) @@ -523,7 +508,7 @@ def test_extract_calibration_stats_no_stats(self): modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] - calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) # Extract stats without running any forward passes stats = calibrator._extract_calibration_stats(modules) @@ -531,16 +516,213 @@ def test_extract_calibration_stats_no_stats(self): # Should return empty list assert stats == [] - def test_calibrator_with_single_sample(self): - """Test calibrator edge case with only one sample.""" + def test_enable_calibration_mode_with_existing_stats_manager(self): + """Test _enable_calibration_mode when stats manager already exists.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + "collect_stats": True, # Enable stats manager initially + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + # Stats manager should already exist + for module in modules: + assert module._stats_manager is not None + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + + # Disable stats manager first + for module in modules: + module._stats_manager.enabled = False + + # Enable calibration mode - should re-enable existing stats manager + calibrator._enable_calibration_mode(modules) + + for module in modules: + assert module._stats_manager.enabled is True + assert module._stats_manager.calibration_mode is True + + def test_extract_calibration_stats_with_phase_filter(self): + """Test _extract_calibration_stats 
with phase filtering.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + + # Enable calibration mode and manually add some stats + calibrator._enable_calibration_mode(modules) + + for module in modules: + # Manually add stats for different phases + module._stats_manager.per_sample_stats = [ + {"sparsity": 0.3, "sample_length": 1024, "phase": "prefill"}, + {"sparsity": 0.4, "sample_length": 2048, "phase": "prefill"}, + {"sparsity": 0.5, "sample_length": 100, "phase": "decode"}, + ] + + # Extract only prefill stats + prefill_stats = calibrator._extract_calibration_stats(modules, phase="prefill") + assert len(prefill_stats) == 2 + assert prefill_stats[0]["sample_length"] == 1024 + assert prefill_stats[1]["sample_length"] == 2048 + + # Extract only decode stats + decode_stats = calibrator._extract_calibration_stats(modules, phase="decode") + assert len(decode_stats) == 1 + assert decode_stats[0]["sample_length"] == 100 + + def test_extract_calibration_stats_module_without_stats_manager(self): + """Test _extract_calibration_stats with module missing stats manager.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + "collect_stats": False, # No stats manager + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + + # Stats manager should be None + for module in modules: + assert module._stats_manager is None + + # Should return empty list + stats = calibrator._extract_calibration_stats(modules) + assert stats == [] + + def test_calibrate_no_sparse_modules(self): + """Test calibrate raises error when no sparse modules found.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + # Don't apply sparse attention + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, + threshold_trials=[0.001, 0.01], + ) + + def dummy_forward_loop(m): + pass + + with pytest.raises(ValueError, match="No sparse attention modules found"): + calibrator.calibrate(model, dummy_forward_loop, "prefill") + + def test_calibrate_empty_stats(self): + """Test calibrate handles empty stats gracefully.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + calibrator = DynamicThresholdCalibrator( - target_sparse_ratio=0.5, - threshold_trials=[0.001, 0.01, 0.1], + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, + threshold_trials=[0.001], # Only one threshold for speed + ) + + # Forward loop that doesn't generate any stats + def empty_forward_loop(m): + pass + + # Should return empty dict + result = calibrator.calibrate(sparse_model, empty_forward_loop, "prefill") + assert result == {} + + def 
test_calibrator_default_threshold_trials_values(self): + """Test that default threshold trials have expected values.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, + ) + + # Should have 12 default trials + assert len(calibrator.threshold_trials) == 12 + + # Check specific values + expected_trials = [ + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 5e-2, + 1e-1, + 5e-1, + ] + assert calibrator.threshold_trials == expected_trials + + def test_calibrator_custom_threshold_trials(self): + """Test calibrator with custom threshold trials.""" + custom_trials = [0.001, 0.005, 0.01, 0.05, 0.1] + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, + threshold_trials=custom_trials, + ) + + assert calibrator.threshold_trials == custom_trials + + def test_calibrator_sparsity_results_initialization(self): + """Test that sparsity_results is initialized as empty list.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, + ) + + assert calibrator.sparsity_results == [] + + def test_sample_sparsity_dataclass(self): + """Test SampleSparsity dataclass.""" + sample = DynamicThresholdCalibrator.SampleSparsity( + length=1024, + threshold_sparsities={0.001: 0.3, 0.01: 0.5, 0.1: 0.7}, ) - # Even with one sample, regression should work - assert calibrator.target_sparse_ratio == 0.5 - assert len(calibrator.threshold_trials) == 3 + assert sample.length == 1024 + assert sample.threshold_sparsities[0.001] == 0.3 + assert sample.threshold_sparsities[0.01] == 0.5 + assert sample.threshold_sparsities[0.1] == 0.7 class TestCalibrateFunction: @@ -574,7 +756,7 @@ def test_extract_calibration_config(self): config = { "sparse_cfg": { "calibration": { - "target_sparse_ratio": 0.3, + "target_sparse_ratio": {"prefill": 0.3, "decode": 0.3}, "samples": 12, "max_seqlen": 2048, }, @@ -587,7 +769,7 @@ def test_extract_calibration_config(self): calib_config = _extract_calibration_config(config) assert calib_config is not None - assert calib_config.target_sparse_ratio == 0.3 + assert calib_config.target_sparse_ratio == {"prefill": 0.3, "decode": 0.3} assert calib_config.samples == 12 assert calib_config.max_seqlen == 2048 diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index de954ae97..49a8ccd39 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -60,28 +60,6 @@ def test_basic_replacement(self): # Verify replacement occurred assert sparse_attention_count > 0 - def test_enable_disable_toggle(self): - """Test enabling and disabling sparse attention.""" - model = SimpleAttentionModel() - model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) - - # Check initially enabled - for module in model.modules(): - if isinstance(module, SparseAttentionModule): - assert module.is_enabled - - # Disable all sparse attention modules - disable_sparse_attention(model, "*") - for module in model.modules(): - if isinstance(module, SparseAttentionModule): - assert not module.is_enabled - - # Re-enable all sparse attention modules - enable_sparse_attention(model, "*") - for module in model.modules(): - if isinstance(module, SparseAttentionModule): - assert module.is_enabled - def test_pattern_based_replacement(self): 
"""Test pattern-based selective replacement.""" model = SimpleTransformerEncoderLayer() @@ -151,10 +129,6 @@ def test_no_matching_modules(self): def test_disable_enable_functions(self): """Test disable/enable utility functions.""" - from modelopt.torch.sparsity.attention_sparsity.conversion import ( - disable_sparse_attention, - enable_sparse_attention, - ) model = SimpleAttentionModel() model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) @@ -181,8 +155,8 @@ def test_print_sparse_attention_summary(self, capsys): # Capture output captured = capsys.readouterr() - assert "Total sparse attention modules:" in captured.out - assert "Enabled:" in captured.out + assert "Sparse attention:" in captured.out + assert "modules enabled" in captured.out def test_restore_sparse_attention_model(self): """Test save/restore via modelopt_state.""" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py index ac9f46a54..6022e1396 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -47,7 +47,7 @@ def test_static_threshold(self): assert info["value"] == 0.001 def test_phased_threshold(self): - """Test threshold info for phase-specific thresholds.""" + """Test threshold info for phase-specific static thresholds.""" method = FlashSkipSoftmax( method_config={ "threshold": {"prefill": 0.001, "decode": 0.0001}, @@ -60,11 +60,11 @@ def test_phased_threshold(self): info = method.get_threshold_info() - assert info["type"] == "static_phased" - assert "thresholds" in info - assert info["thresholds"]["prefill"] == 0.001 - assert info["thresholds"]["decode"] == 0.0001 - assert "current" in info + # Static phased thresholds are reported as type "static" with dict value + assert info["type"] == "static" + assert isinstance(info["value"], dict) + assert info["value"]["prefill"] == 0.001 + assert info["value"]["decode"] == 0.0001 def test_dynamic_calibrated_threshold(self): """Test threshold info for calibrated dynamic threshold.""" @@ -78,17 +78,21 @@ def test_dynamic_calibrated_threshold(self): } ) - # Simulate calibration setting scale factor - method.threshold_scale_factor = 437.5 + # Simulate calibration setting per-phase scale factors + method.threshold_scale_factor = {"prefill": 437.5, "decode": 500.0} info = method.get_threshold_info() assert info["type"] == "dynamic" - assert info["scale_factor"] == 437.5 - assert info["formula"] == "λ / length" - assert "example_lengths" in info - assert abs(info["example_lengths"][1024] - 437.5 / 1024) < 1e-6 - assert abs(info["example_lengths"][2048] - 437.5 / 2048) < 1e-6 + assert info["scale_factors"] == {"prefill": 437.5, "decode": 500.0} + assert info["formula"] == "λ[phase] / length" + assert "phases" in info + assert "prefill" in info["phases"] + assert "decode" in info["phases"] + # Check example thresholds for prefill + prefill_examples = info["phases"]["prefill"]["example_thresholds"] + assert abs(prefill_examples[1024] - 437.5 / 1024) < 1e-6 + assert abs(prefill_examples[2048] - 437.5 / 2048) < 1e-6 def test_threshold_info_structure(self): """Test that threshold info has expected structure.""" @@ -163,17 +167,22 @@ def test_module_with_calibrated_threshold(self): sparse_model = sparsify(model, config) - # Find module and set calibrated threshold + # Find module and set calibrated threshold (per-phase dict format) + module = None for module in 
sparse_model.modules(): if isinstance(module, SparseAttentionModule): - module._sparse_method_instance.threshold_scale_factor = 500.0 + module._sparse_method_instance.threshold_scale_factor = { + "prefill": 500.0, + "decode": 500.0, + } break + assert module is not None, "No SparseAttentionModule found" # Get threshold info info = module.get_threshold_info() assert info["type"] == "dynamic" - assert info["scale_factor"] == 500.0 + assert info["scale_factors"] == {"prefill": 500.0, "decode": 500.0} def test_module_without_method_instance(self): """Test get_threshold_info when sparse method instance doesn't exist.""" @@ -233,7 +242,7 @@ def test_summary_displays_static_threshold(self, capsys): print_sparse_attention_summary(sparse_model) captured = capsys.readouterr() - assert "Static (1.00e-03)" in captured.out + assert "threshold=1.00e-03" in captured.out assert "flash_skip_softmax" in captured.out def test_summary_displays_dynamic_threshold(self, capsys): @@ -258,13 +267,19 @@ def test_summary_displays_dynamic_threshold(self, capsys): sparse_model = sparsify(model, config) - # Set calibrated threshold + # Set calibrated threshold (per-phase dict format) for module in sparse_model.modules(): if isinstance(module, SparseAttentionModule): - module._sparse_method_instance.threshold_scale_factor = 437.5 + module._sparse_method_instance.threshold_scale_factor = { + "prefill": 437.5, + "decode": 500.0, + } print_sparse_attention_summary(sparse_model) captured = capsys.readouterr() - assert "Dynamic (λ=437.500000)" in captured.out + # Output format: λ={prefill=437.50, decode=500.00} + assert "λ=" in captured.out + assert "prefill=" in captured.out + assert "decode=" in captured.out assert "flash_skip_softmax" in captured.out From 8908b1a79a8ccd6f69e9ca6b115961a866b995ee Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 30 Dec 2025 08:29:10 +0000 Subject: [PATCH 4/9] Address feedbacks Signed-off-by: Kai Xu --- examples/llm_eval/lm_eval_hf.py | 8 ----- examples/llm_eval/mmlu.py | 8 ----- .../llm_sparsity/attention_sparsity/README.md | 6 ++-- .../attention_sparsity/requirements.txt | 2 -- .../attention_sparsity/calibration/dataset.py | 7 ++-- .../sparsity/attention_sparsity/config.py | 7 ++-- .../sparsity/attention_sparsity/conversion.py | 9 ++--- .../methods/flash_skip_softmax.py | 16 +++++---- .../attention_sparsity/plugins/__init__.py | 16 +++++++-- .../attention_sparsity/plugins/huggingface.py | 33 +++++++++++++++++++ setup.py | 3 ++ .../test_calibration_gpu.py | 2 +- .../test_sparse_attention_calibration.py | 2 +- .../attention_sparsity/test_threshold_info.py | 2 +- 14 files changed, 80 insertions(+), 41 deletions(-) delete mode 100644 examples/llm_sparsity/attention_sparsity/requirements.txt diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 24dcb28f6..405e8590a 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -68,14 +68,6 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | additional_config = {} if additional_config is None else additional_config additional_config = {k: v for k, v in additional_config.items() if v is not None} - # Force eager attention if sparse attention is requested - if sparse_cfg: - additional_config["attn_implementation"] = "eager" - warnings.warn( - "Sparse attention requires attn_implementation='eager'. " - "Forcing eager attention implementation." 
- ) - # Enable automatic save/load of modelopt state huggingface checkpointing mto.enable_huggingface_checkpointing() diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index 0bf47fcd3..316f443bb 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -269,14 +269,6 @@ def main( max_batch_size=1, ) else: - # Force eager attention if sparse attention is requested - if sparse_cfg: - kwargs["attn_implementation"] = "eager" - warnings.warn( - "Sparse attention requires attn_implementation='eager'. " - "Forcing eager attention implementation." - ) - model = select_model( max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs ) diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md index 708947683..6bd580a92 100644 --- a/examples/llm_sparsity/attention_sparsity/README.md +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -50,10 +50,12 @@ model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB) ## Prerequisites -### Install Requirements +### Local Installation + +For Hugging Face models, install Model Optimizer with `hf` dependencies using `pip` from [PyPI](https://pypi.org/project/nvidia-modelopt/) and install the requirements for the example: ```bash -pip install -r requirements.txt +pip install nvidia-modelopt[hf] ``` ### Download RULER Calibration Data (Required for Calibration) diff --git a/examples/llm_sparsity/attention_sparsity/requirements.txt b/examples/llm_sparsity/attention_sparsity/requirements.txt deleted file mode 100644 index a3e0dfa17..000000000 --- a/examples/llm_sparsity/attention_sparsity/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -nltk -wonderwords diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py index 74a4f3aa3..dc46413c1 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -20,9 +20,6 @@ from dataclasses import dataclass from typing import Any -from tqdm import tqdm -from transformers import AutoTokenizer - from . import ruler_utils @@ -232,6 +229,8 @@ def __init__( # Initialize tokenizer if isinstance(tokenizer_name_or_path, str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) else: self.tokenizer = tokenizer_name_or_path @@ -243,6 +242,8 @@ def build_calibration_dataset(self) -> list[dict[str, Any]]: Returns: List of calibration samples with 'input' and 'length' fields """ + from tqdm import tqdm + all_samples = [] # Generate calibration samples diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index f98c51d34..9452b7e6f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -167,7 +167,11 @@ class CalibrationConfig(ModeloptBaseConfig): samples: int = ModeloptField( default=24, title="Calibration samples", - description="Total number of RULER samples for calibration (distributed across length bins).", + description=( + "Total number of RULER samples for calibration (distributed across length bins). " + "Default (24) provides 1 sample per task per length bin (4 bins * 6 RULER tasks). " + "Increase for more robust calibration." 
+ ), ) max_seqlen: int = ModeloptField( @@ -318,7 +322,6 @@ class SparseAttentionConfig(ModeloptBaseConfig): class FlashSkipSoftmaxConfig(SparseAttentionConfig): """Configuration for Flash Attention-aware softmax skip sparse attention.""" - # Override sparse_cfg with flash_skip_softmax specific defaults # Override sparse_cfg with flash_skip_softmax specific defaults sparse_cfg: SparseAttentionCfgType = ModeloptField( default={ diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index bc90406ed..006c99eea 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -23,10 +23,10 @@ from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict -from modelopt.torch.utils import get_unwrapped_name +from modelopt.torch.utils import atomic_print, get_unwrapped_name from .config import SparseAttentionConfig -from .plugins.huggingface import register_sparse_attention_on_the_fly +from .plugins import register_custom_model_plugins_on_the_fly from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry @@ -59,8 +59,8 @@ def convert_to_sparse_attention_model( # Initialize the true module if necessary model = model.init_modellike() if isinstance(model, ModelLikeModule) else model - # Register sparse attention modules dynamically - register_sparse_attention_on_the_fly(model) + # Apply custom model plugins + register_custom_model_plugins_on_the_fly(model) # Replace attention modules with sparse versions replace_sparse_attention_modules(model, version=ModeloptStateManager(model).state_version) @@ -340,6 +340,7 @@ def _format_threshold(info: dict) -> str: return "threshold=N/A" +@atomic_print def print_sparse_attention_summary(model: nn.Module): """Print summary of sparse attention modules in the model. 
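Editor's note: the `samples` default documented above is sized to the RULER grid (4 length bins × 6 tasks). A minimal sketch of overriding it is shown below, assuming `CalibrationConfig` can be constructed directly with keyword arguments; the field names come from this patch, but the values and the standalone-construction usage are illustrative, not part of the diff.

```python
# Hedged sketch: field names are taken from CalibrationConfig in this patch;
# building the config standalone like this is an assumption, not shown in the diff.
from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig

calib = CalibrationConfig(
    target_sparse_ratio={"prefill": 0.5, "decode": 0.5},
    samples=48,        # 2 samples per task per length bin (2 * 6 tasks * 4 bins)
    max_seqlen=16384,  # longest RULER sample to generate
)
```

Keeping `samples` a multiple of 24 keeps each length bin and RULER task evenly populated, which is why the default description calls out the 4-bin × 6-task breakdown.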
diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index ec006e58a..7944510c5 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -166,7 +166,7 @@ def calc_correction_factor_and_p( # Used by Flash Attention to adjust running sum when max increases block_max_larger = torch.ones_like(block_max) block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] - correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item() # Step 4: Normalize attention scores by cumulative max # p represents log-space difference: log(score) - log(cummax) @@ -193,12 +193,13 @@ def calc_correction_factor_and_p( # For causal attention, only count lower triangle blocks (including diagonal) num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2 total_valid_blocks = batch_size * num_heads * num_causal_blocks - density = float(block_mask.sum()) / total_valid_blocks + dense_blocks = block_mask.sum() total_blocks = num_causal_blocks else: - density = float(block_mask.sum() / block_mask.numel()) + dense_blocks = block_mask.sum() # Keep as tensor + total_valid_blocks = block_mask.numel() total_blocks = num_block_rows * num_block_cols - sparsity = 1 - density + sparsity = 1.0 - dense_blocks.item() / total_valid_blocks else: # decode blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( attn_weights, 1, self.bc @@ -219,7 +220,7 @@ def calc_correction_factor_and_p( # Tracks how often the maximum increases (needed for Flash Attention rescaling) block_max_larger = torch.ones_like(block_max) block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] - correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item() # Step 4: Normalize scores by cumulative max # p = log(score) - log(cummax) in log-space @@ -236,8 +237,9 @@ def calc_correction_factor_and_p( element_mask = element_mask[:, :, :seq_q, :seq_k] # Step 7: Calculate sparsity statistics - density = float(block_mask.sum() / block_mask.numel()) - sparsity = 1 - density + dense_blocks = block_mask.sum() + total_valid_blocks = block_mask.numel() + sparsity = 1.0 - dense_blocks.item() / total_valid_blocks total_blocks = num_block_cols # Create stats dictionary diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py index ba8c8b821..0d43f525e 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -15,8 +15,20 @@ """Plugins for sparse attention integration with various frameworks.""" -from .huggingface import register_sparse_attention_on_the_fly +# List of model plugins that are called during conversion +# Each plugin is a callable that takes (model) and performs validation/setup +CUSTOM_MODEL_PLUGINS: set = set() + + +def register_custom_model_plugins_on_the_fly(model): + """Applies all registered custom model plugins.""" + for callback in CUSTOM_MODEL_PLUGINS: + callback(model) + + +from . 
import huggingface # noqa: E402 __all__ = [ - "register_sparse_attention_on_the_fly", + "CUSTOM_MODEL_PLUGINS", + "register_custom_model_plugins_on_the_fly", ] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 0c4a8baf9..a7d193347 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -15,12 +15,15 @@ """Dynamic sparse attention registration for HuggingFace models.""" +import warnings + import torch.nn as nn import transformers from modelopt.torch.opt.dynamic import DynamicModule from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from . import CUSTOM_MODEL_PLUGINS class _GenericSparseAttention(SparseAttentionModule): @@ -118,3 +121,33 @@ def _is_supported_model(model: nn.Module) -> bool: # Support any PyTorch model with attention modules return isinstance(model, nn.Module) + + +def validate_eager_attention(model: nn.Module) -> None: + """Validate and enforce eager attention for HuggingFace models. + + Sparse attention requires attn_implementation='eager' because it + patches torch.nn.functional.softmax, which is only called in eager mode. + + Args: + model: Model to validate + """ + if not isinstance(model, transformers.PreTrainedModel): + return + + attn_impl = getattr(model.config, "_attn_implementation", None) + if attn_impl and attn_impl != "eager": + warnings.warn( + f"Sparse attention requires attn_implementation='eager', but model uses '{attn_impl}'. " + "Forcing eager attention implementation." + ) + model.config._attn_implementation = "eager" + + +# Register plugins +CUSTOM_MODEL_PLUGINS.update( + [ + validate_eager_attention, + register_sparse_attention_on_the_fly, + ] +) diff --git a/setup.py b/setup.py index a87a7e93e..192501a57 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,9 @@ "huggingface_hub>=0.24.0", "peft>=0.17.0", "transformers>=4.53,<5.0", # Should match modelopt/torch/__init__.py and tox.ini + "deepspeed>=0.9.6 ; platform_system != 'Darwin' and platform_system != 'Windows'", + "nltk", + "wonderwords", ], # linter tools "dev-lint": [ diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py index 762bcc1d0..8f5396241 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -17,7 +17,7 @@ import pytest import torch -from _test_utils.torch_sparsity.sparse_attention_common import SimpleTransformerEncoderLayer +from _test_utils.torch.sparsity.sparse_attention_common import SimpleTransformerEncoderLayer import modelopt.torch.opt as mto from modelopt.torch.sparsity.attention_sparsity import sparsify diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py index c44a45bbe..168418a16 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -20,7 +20,7 @@ pytest.importorskip("transformers") import numpy as np -from _test_utils.torch_sparsity.sparse_attention_common import ( +from _test_utils.torch.sparsity.sparse_attention_common import ( SimpleAttentionModel, SimpleTransformerEncoder, ) diff --git 
a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py index 6022e1396..aaa058cfd 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -19,7 +19,7 @@ pytest.importorskip("transformers") -from _test_utils.torch_sparsity.sparse_attention_common import SimpleAttentionModel +from _test_utils.torch.sparsity.sparse_attention_common import SimpleAttentionModel from modelopt.torch.sparsity.attention_sparsity import sparsify from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax From d984cd8593edb13b22bb8058bf1ccd1cda823f4d Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 20 Jan 2026 06:43:07 +0000 Subject: [PATCH 5/9] update default threshold_trials Signed-off-by: Kai Xu --- examples/llm_sparsity/attention_sparsity/README.md | 6 ++++-- .../sparsity/attention_sparsity/calibration/calibrator.py | 4 ++++ modelopt/torch/sparsity/attention_sparsity/config.py | 4 +++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md index 6bd580a92..b76a60fee 100644 --- a/examples/llm_sparsity/attention_sparsity/README.md +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -102,7 +102,7 @@ The calibration process: |----------|---------|-------------| | `--pyt_ckpt_path` | Required | HuggingFace model path or name | | `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` | -| `--backend` | `pytorch` | Backend: `pytorch` or `triton` | +| `--backend` | `pytorch` | Backend: `pytorch` (only supported backend) | | `--seq_len` | `2048` | Maximum sequence length for input prompts | | `--export_dir` | `None` | Directory to export the sparsified model | @@ -137,9 +137,11 @@ You can create custom sparse attention configurations: custom_config = { "sparse_cfg": { "calibration": { # Optional: omit for fixed threshold - "target_sparse_ratio": 0.5, # Target 50% sparsity + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, # Target 50% sparsity "samples": 128, # Number of calibration samples "max_seqlen": 8192, # Maximum sequence length + # Optional: customize threshold trials for calibration + "threshold_trials": [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1], }, "*attn*": { # Pattern to match attention modules "method": "flash_skip_softmax", diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index 07aae4b70..149909c56 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -71,9 +71,13 @@ def __init__( 1e-3, 5e-3, 1e-2, + 2e-2, 5e-2, 1e-1, + 2e-1, + 3e-1, 5e-1, + 7e-1, ] # Statistics tracking diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 9452b7e6f..e189fd0b0 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -207,7 +207,9 @@ class CalibrationConfig(ModeloptBaseConfig): title="Threshold trials", description=( "List of threshold values to test during calibration. 
" - "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" + "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, " + "1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1]. " + "Increasing the number of trials improves calibration accuracy but slows down calibration." ), ) From 42a049931fe35a5d8a68e0a0c730cd341c0c923a Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 27 Jan 2026 14:54:20 -0800 Subject: [PATCH 6/9] Update sparse attention config Signed-off-by: Kai Xu --- examples/llm_eval/sparse_attention_utils.py | 39 +------------ .../llm_sparsity/attention_sparsity/README.md | 2 +- .../sparsity/attention_sparsity/config.py | 56 +++++++++++-------- .../torch/sparsity/sparse_attention_common.py | 2 +- .../test_calibration_gpu.py | 16 +++--- .../test_integration_gpu.py | 8 +-- .../test_flash_skip_softmax.py | 43 ++++---------- .../test_sparse_attention_calibration.py | 31 ++++++---- .../test_sparse_attention_conversion.py | 12 ++-- .../attention_sparsity/test_threshold_info.py | 37 ++++-------- 10 files changed, 97 insertions(+), 149 deletions(-) diff --git a/examples/llm_eval/sparse_attention_utils.py b/examples/llm_eval/sparse_attention_utils.py index dc7a1b14e..8dc560851 100644 --- a/examples/llm_eval/sparse_attention_utils.py +++ b/examples/llm_eval/sparse_attention_utils.py @@ -17,36 +17,6 @@ import modelopt.torch.sparsity.attention_sparsity as mtsa -# Custom sparse attention configurations -CUSTOM_SPARSE_CONFIG = { - "SPARSE_CONSERVATIVE": { - "sparse_cfg": { - "*attn*": { - "method": "flash_skip_softmax", - "threshold": {"prefill": 5e-4, "decode": 1e-5}, - "br": 128, - "bc": 128, - "backend": "pytorch", - "enable": True, - }, - "default": {"enable": False}, - }, - }, - "SPARSE_AGGRESSIVE": { - "sparse_cfg": { - "*attn*": { - "method": "flash_skip_softmax", - "threshold": {"prefill": 5e-3, "decode": 5e-4}, - "br": 128, - "bc": 128, - "backend": "pytorch", - "enable": True, - }, - "default": {"enable": False}, - }, - }, -} - def _extract_model(model_obj): """Extract actual model from wrapper (HFLM or EvalModel).""" @@ -82,13 +52,10 @@ def sparsify_model( # Resolve config if isinstance(sparse_cfg, str): - # Try custom configs first - mtsa_cfg = CUSTOM_SPARSE_CONFIG.get(sparse_cfg) - if mtsa_cfg is None: - # Try predefined configs - mtsa_cfg = getattr(mtsa, sparse_cfg, None) + # Get config from mtsa module (e.g., SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_DEFAULT) + mtsa_cfg = getattr(mtsa, sparse_cfg, None) if mtsa_cfg is None: - raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}") + raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}.") else: mtsa_cfg = sparse_cfg diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md index b76a60fee..72c629202 100644 --- a/examples/llm_sparsity/attention_sparsity/README.md +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -145,7 +145,7 @@ custom_config = { }, "*attn*": { # Pattern to match attention modules "method": "flash_skip_softmax", - "threshold": 1e-4, # Fixed threshold (ignored if calibration is used) + "threshold": {"prefill": 1e-3, "decode": 1e-4}, # Phase-specific thresholds (ignored if calibration is used) "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns "backend": "pytorch", diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index e189fd0b0..8f9445d17 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py 
+++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -46,12 +46,12 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): description="If True, enables sparse attention. If False, bypasses sparsity.", ) - threshold: float | dict[str, float] = ModeloptField( - default=1e-3, + threshold: dict[str, float] = ModeloptField( + default={"prefill": 1e-3, "decode": 1e-4}, title="Sparsity threshold.", description=( "Threshold for determining which attention values to skip. " - "Can be a float or dict with phase-specific values." + "Must be a dict with 'prefill' and 'decode' keys." ), ) @@ -123,26 +123,24 @@ def validate_block_size(cls, v): @field_validator("threshold") @classmethod def validate_threshold(cls, v): - """Validate threshold is in valid range (0, 1) or dict with valid phases.""" - if isinstance(v, dict): - # Validate phase keys - valid_phases = {"prefill", "decode", "default"} - invalid_keys = set(v.keys()) - valid_phases - if invalid_keys: + """Validate threshold is a dict with valid phases and values in range (0, 1).""" + if not isinstance(v, dict): + raise ValueError( + f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}" + ) + # Validate phase keys + valid_phases = {"prefill", "decode"} + invalid_keys = set(v.keys()) - valid_phases + if invalid_keys: + raise ValueError( + f"Invalid threshold phases: {invalid_keys}. Valid phases: {valid_phases}" + ) + # Validate all values are in range (0, 1) + for phase, threshold in v.items(): + if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: raise ValueError( - f"Invalid threshold phases: {invalid_keys}. Valid phases: {valid_phases}" + f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}" ) - # Validate all values are in range (0, 1) - for phase, threshold in v.items(): - if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: - raise ValueError( - f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}" - ) - elif isinstance(v, (int, float)): - if v <= 0 or v >= 1: - raise ValueError(f"Threshold must be in range (0, 1), got {v}") - else: - raise ValueError(f"Threshold must be a number in range (0, 1) or dict, got {type(v)}") return v @@ -213,6 +211,16 @@ class CalibrationConfig(ModeloptBaseConfig): ), ) + cache_dir: str | None = ModeloptField( + default=None, + title="Cache directory", + description=( + "Directory to cache generated calibration samples. " + "If None, uses MODELOPT_CACHE_DIR env var or ~/.cache/modelopt/sparse_attention/. " + "Caching avoids regenerating samples on repeated calibration runs." 
+ ), + ) + @field_validator("threshold_trials") @classmethod def validate_threshold_trials(cls, v): @@ -372,10 +380,10 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): SKIP_SOFTMAX_CALIB = { "sparse_cfg": { "calibration": { - # "target_sparse_ratio": {"prefill": 0.75, "decode": 0.75}, - "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, + "target_sparse_ratio": {"prefill": 0.9, "decode": 0.9}, "samples": 64, - "max_seqlen": 16384, + "max_seqlen": 65536, + "chunk_size": 4096, }, "*attn*": { "method": "flash_skip_softmax", diff --git a/tests/_test_utils/torch/sparsity/sparse_attention_common.py b/tests/_test_utils/torch/sparsity/sparse_attention_common.py index 5ed079966..b9feea358 100644 --- a/tests/_test_utils/torch/sparsity/sparse_attention_common.py +++ b/tests/_test_utils/torch/sparsity/sparse_attention_common.py @@ -95,7 +95,7 @@ def get_input(cls, d_model=128, seq_len=10, batch_size=2): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 1e-4, + "threshold": {"prefill": 1e-4, "decode": 1e-4}, "br": 128, "bc": 128, "enable": True, diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py index 8f5396241..3f1030f02 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -118,7 +118,7 @@ def test_calibration_simple_model(self, simple_model): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 64, "bc": 64, "backend": "pytorch", @@ -158,7 +158,7 @@ def test_calibration_pytorch_backend(self, simple_model): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "backend": "pytorch", "enable": True, "calibration": { @@ -190,7 +190,7 @@ def test_simplified_calibration(self, simple_model): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "enable": True, "calibration": { "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, @@ -217,7 +217,7 @@ def test_calibration_persistence(self, simple_model): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "enable": True, "calibration": { "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, @@ -264,7 +264,7 @@ def test_calibrated_model_inference(self, simple_model_setup): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "backend": "pytorch", "enable": True, "calibration": { @@ -300,7 +300,7 @@ def test_calibrated_vs_fixed_threshold(self, simple_model_setup): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "enable": True, "calibration": { "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, @@ -316,7 +316,7 @@ def test_calibrated_vs_fixed_threshold(self, simple_model_setup): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "enable": True, } }, @@ -359,7 +359,7 @@ def test_memory_usage(self, simple_model_setup): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "enable": True, 
"calibration": { "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py index c90b99bba..df4cfaa65 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -66,7 +66,7 @@ def test_load_and_sparsify(self, tinyllama_model): sparse_cfg={ "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -94,7 +94,7 @@ def test_forward_prefill(self, tinyllama_model, tinyllama_tokenizer): config = SparseAttentionConfig( sparse_cfg={ "*attn*": { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "backend": "pytorch", "enable": True, } @@ -124,7 +124,7 @@ def test_forward_decode(self, tinyllama_model): config = SparseAttentionConfig( sparse_cfg={ "*attn*": { - "threshold": 1e-5, # More conservative for decode + "threshold": {"prefill": 1e-3, "decode": 1e-5}, # More conservative for decode "backend": "pytorch", "enable": True, } @@ -163,7 +163,7 @@ def test_gqa_attention(self, tinyllama_model): sparse_config = SparseAttentionConfig( sparse_cfg={ "*attn*": { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "backend": "pytorch", "enable": True, } diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py index d9bbee157..f61988dd5 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py @@ -30,7 +30,7 @@ def test_phase_inference(self): """Test phase detection from attention score shape.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -70,30 +70,11 @@ def test_threshold_update_dict_config(self): method._update_threshold("prefill") assert method.threshold == 1e-3 - def test_threshold_update_static_config(self): - """Test threshold with static float config.""" - method = FlashSkipSoftmax( - { - "threshold": 5e-4, - "br": 128, - "bc": 128, - "backend": "pytorch", - "is_causal": True, - } - ) - - initial_threshold = method.threshold - assert initial_threshold == 5e-4 - - # Should not change for static config - method._update_threshold("decode") - assert method.threshold == 5e-4 - def test_block_reshaping_divisible(self): """Test block reshaping with divisible sequence lengths.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -116,7 +97,7 @@ def test_block_reshaping_with_padding(self): """Test block reshaping with non-divisible lengths.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -139,7 +120,7 @@ def test_correction_factor_calculation_prefill(self): """Test correction factor for prefill phase.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -166,7 +147,7 @@ def test_correction_factor_calculation_decode(self): """Test correction factor for decode phase.""" method = FlashSkipSoftmax( { - "threshold": 1e-5, + 
"threshold": {"prefill": 1e-3, "decode": 1e-5}, "br": 128, "bc": 128, "backend": "pytorch", @@ -189,7 +170,7 @@ def test_sparsity_statistics(self): """Test sparsity statistics structure.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -210,7 +191,7 @@ def test_block_mask_correctness(self): """Test block mask shape and type.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -229,7 +210,7 @@ def test_block_mask_correctness(self): def test_causal_vs_noncausal(self): """Test total_blocks calculation for causal vs non-causal.""" config_base = { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -252,7 +233,7 @@ def test_calculate_sparsity_assertions(self): """Test calculate_sparsity input validation.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -268,7 +249,7 @@ def test_apply_sparsity_with_mask(self): """Test apply_sparsity with pre-computed mask.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -295,7 +276,7 @@ def test_apply_sparsity_without_mask(self): """Test apply_sparsity calculates mask internally when None.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -315,7 +296,7 @@ def test_name_property(self): """Test method name property.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py index 168418a16..884c82a76 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -293,7 +293,7 @@ def test_calibration_disabled(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 64, "bc": 64, "enable": True, @@ -327,7 +327,7 @@ def test_sparsify_with_calibration_requires_forward_loop(self): }, "*attention*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 64, "bc": 64, "enable": True, @@ -344,7 +344,14 @@ def test_multiple_sparse_modules(self): model = SimpleTransformerEncoder() config = { - "sparse_cfg": {"*attn*": {"threshold": 1e-3, "br": 64, "bc": 64, "enable": True}}, + "sparse_cfg": { + "*attn*": { + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "br": 64, + "bc": 64, + "enable": True, + } + }, } sparse_model = sparsify(model, config) @@ -433,7 +440,7 @@ def test_set_threshold(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.1, + "threshold": {"prefill": 0.1, "decode": 0.1}, "br": 64, "bc": 64, "enable": True, @@ -461,7 +468,7 @@ def test_enable_disable_calibration_mode(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.1, + "threshold": {"prefill": 0.1, "decode": 0.1}, "br": 64, "bc": 64, "enable": 
True, @@ -497,7 +504,7 @@ def test_extract_calibration_stats_no_stats(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.1, + "threshold": {"prefill": 0.1, "decode": 0.1}, "br": 64, "bc": 64, "enable": True, @@ -523,7 +530,7 @@ def test_enable_calibration_mode_with_existing_stats_manager(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.1, + "threshold": {"prefill": 0.1, "decode": 0.1}, "br": 64, "bc": 64, "enable": True, @@ -559,7 +566,7 @@ def test_extract_calibration_stats_with_phase_filter(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.1, + "threshold": {"prefill": 0.1, "decode": 0.1}, "br": 64, "bc": 64, "enable": True, @@ -601,7 +608,7 @@ def test_extract_calibration_stats_module_without_stats_manager(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.1, + "threshold": {"prefill": 0.1, "decode": 0.1}, "br": 64, "bc": 64, "enable": True, @@ -646,7 +653,7 @@ def test_calibrate_empty_stats(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.1, + "threshold": {"prefill": 0.1, "decode": 0.1}, "br": 64, "bc": 64, "enable": True, @@ -737,7 +744,7 @@ def test_calibrate_no_config(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.1, + "threshold": {"prefill": 0.1, "decode": 0.1}, "br": 64, "bc": 64, "enable": True, @@ -780,7 +787,7 @@ def test_extract_calibration_config_none(self): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": 0.1, + "threshold": {"prefill": 0.1, "decode": 0.1}, } }, } diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 49a8ccd39..d8913f51d 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -69,7 +69,7 @@ def test_pattern_based_replacement(self): "sparse_cfg": { "*self_attn*": { "method": "flash_skip_softmax", - "threshold": 1e-4, + "threshold": {"prefill": 1e-4, "decode": 1e-4}, "br": 128, "bc": 128, "enable": True, @@ -100,7 +100,7 @@ def filter_func(name): "sparse_cfg": { filter_func: { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "enable": True, }, }, @@ -118,7 +118,7 @@ def test_no_matching_modules(self): "sparse_cfg": { "*nonexistent*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "enable": True, }, }, @@ -192,7 +192,7 @@ def test_get_stats_with_stats_manager(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.001, + "threshold": {"prefill": 0.001, "decode": 0.0001}, "br": 64, "bc": 64, "collect_stats": True, # Enable stats collection @@ -228,7 +228,7 @@ def test_get_stats_without_stats_manager(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.001, + "threshold": {"prefill": 0.001, "decode": 0.0001}, "br": 64, "bc": 64, "collect_stats": False, # Disable stats collection @@ -257,7 +257,7 @@ def test_get_threshold_info(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.005, + "threshold": {"prefill": 0.005, "decode": 0.001}, "br": 64, "bc": 64, "enable": True, diff --git 
a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py index aaa058cfd..c958fab22 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -29,23 +29,6 @@ class TestFlashSkipSoftmaxThresholdInfo: """Test FlashSkipSoftmax.get_threshold_info() method.""" - def test_static_threshold(self): - """Test threshold info for static threshold.""" - method = FlashSkipSoftmax( - method_config={ - "threshold": 0.001, - "br": 128, - "bc": 128, - "backend": "pytorch", - "is_causal": True, - } - ) - - info = method.get_threshold_info() - - assert info["type"] == "static" - assert info["value"] == 0.001 - def test_phased_threshold(self): """Test threshold info for phase-specific static thresholds.""" method = FlashSkipSoftmax( @@ -70,7 +53,7 @@ def test_dynamic_calibrated_threshold(self): """Test threshold info for calibrated dynamic threshold.""" method = FlashSkipSoftmax( method_config={ - "threshold": 0.001, + "threshold": {"prefill": 0.001, "decode": 0.0001}, "br": 128, "bc": 128, "backend": "pytorch", @@ -98,7 +81,7 @@ def test_threshold_info_structure(self): """Test that threshold info has expected structure.""" method = FlashSkipSoftmax( method_config={ - "threshold": 0.001, + "threshold": {"prefill": 0.001, "decode": 0.0001}, "br": 128, "bc": 128, "backend": "pytorch", @@ -124,7 +107,7 @@ def test_module_delegates_to_method(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.005, + "threshold": {"prefill": 0.005, "decode": 0.001}, "br": 64, "bc": 64, "enable": True, @@ -147,7 +130,8 @@ def test_module_delegates_to_method(self): info = sparse_module.get_threshold_info() assert info["type"] == "static" - assert info["value"] == 0.005 + assert info["value"]["prefill"] == 0.005 + assert info["value"]["decode"] == 0.001 def test_module_with_calibrated_threshold(self): """Test module reports calibrated threshold correctly.""" @@ -157,7 +141,7 @@ def test_module_with_calibrated_threshold(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.001, + "threshold": {"prefill": 0.001, "decode": 0.0001}, "br": 64, "bc": 64, "enable": True, @@ -192,7 +176,7 @@ def test_module_without_method_instance(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.001, + "threshold": {"prefill": 0.001, "decode": 0.0001}, "br": 64, "bc": 64, "enable": True, @@ -230,7 +214,7 @@ def test_summary_displays_static_threshold(self, capsys): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.001, + "threshold": {"prefill": 0.001, "decode": 0.0001}, "br": 64, "bc": 64, "enable": True, @@ -242,7 +226,8 @@ def test_summary_displays_static_threshold(self, capsys): print_sparse_attention_summary(sparse_model) captured = capsys.readouterr() - assert "threshold=1.00e-03" in captured.out + assert "prefill" in captured.out + assert "decode" in captured.out assert "flash_skip_softmax" in captured.out def test_summary_displays_dynamic_threshold(self, capsys): @@ -257,7 +242,7 @@ def test_summary_displays_dynamic_threshold(self, capsys): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 0.001, + "threshold": {"prefill": 0.001, "decode": 0.0001}, "br": 64, "bc": 64, "enable": True, From 08bbc62b280837377a6b3f9a420104b1d536cd11 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 27 Jan 2026 15:26:41 
-0800 Subject: [PATCH 7/9] Move the data folder under example Signed-off-by: Kai Xu --- .../llm_sparsity/attention_sparsity/.gitignore | 2 ++ .../attention_sparsity}/download_ruler_data.sh | 0 .../calibration/ruler_utils.py | 17 +++++++++-------- .../sparsity/attention_sparsity/config.py | 18 +++++++++++++----- .../attention_sparsity/stats_manager.py | 6 +++++- 5 files changed, 29 insertions(+), 14 deletions(-) create mode 100644 examples/llm_sparsity/attention_sparsity/.gitignore rename {modelopt/torch/sparsity/attention_sparsity/calibration => examples/llm_sparsity/attention_sparsity}/download_ruler_data.sh (100%) mode change 100755 => 100644 diff --git a/examples/llm_sparsity/attention_sparsity/.gitignore b/examples/llm_sparsity/attention_sparsity/.gitignore new file mode 100644 index 000000000..480901bac --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/.gitignore @@ -0,0 +1,2 @@ +# Data directory for calibration +data diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh b/examples/llm_sparsity/attention_sparsity/download_ruler_data.sh old mode 100755 new mode 100644 similarity index 100% rename from modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh rename to examples/llm_sparsity/attention_sparsity/download_ruler_data.sh diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py index 9de75c02a..741b621f5 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -99,8 +99,10 @@ 100, ] -# Data directory for RULER calibration files (downloaded via download_ruler_data.sh) -DATA_DIR = Path(__file__).parent / "data" +# Data directory for RULER calibration files (in examples folder) +# Downloaded via examples/llm_sparsity/attention_sparsity/download_ruler_data.sh +_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent.parent +DATA_DIR = _REPO_ROOT / "examples" / "llm_sparsity" / "attention_sparsity" / "data" RULER_URLS_FILE = DATA_DIR / "PaulGrahamEssays_URLs.txt" ESSAYS_DIR = DATA_DIR / "essays" @@ -109,11 +111,10 @@ def _get_data_dir() -> Path: """Get data directory for RULER data. 
Returns: - Path to data directory under calibration/ (created if doesn't exist) + Path to data directory under examples/llm_sparsity/attention_sparsity/ (created if doesn't exist) """ - data_dir = Path(__file__).parent / "data" - data_dir.mkdir(parents=True, exist_ok=True) - return data_dir + DATA_DIR.mkdir(parents=True, exist_ok=True) + return DATA_DIR def _load_paul_graham_essays_from_files() -> str: @@ -132,7 +133,7 @@ def _load_paul_graham_essays_from_files() -> str: raise RuntimeError( f"Essays directory not found at {ESSAYS_DIR}.\n" "Please run the download script first:\n" - " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + " bash examples/llm_sparsity/attention_sparsity/download_ruler_data.sh" ) essay_files = list(ESSAYS_DIR.glob("*.txt")) @@ -140,7 +141,7 @@ def _load_paul_graham_essays_from_files() -> str: raise RuntimeError( f"No essay files found in {ESSAYS_DIR}.\n" "Please run the download script first:\n" - " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + " bash examples/llm_sparsity/attention_sparsity/download_ruler_data.sh" ) logger.info(f"Loading {len(essay_files)} Paul Graham essays from local files...") diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 8f9445d17..12b343efe 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -147,9 +147,19 @@ def validate_threshold(cls, v): class CalibrationConfig(ModeloptBaseConfig): """Configuration for automatic threshold calibration using RULER dataset. - Calibration learns a dynamic threshold λ = scale_factor / sequence_length that - achieves target sparsity. Supports both prefill and decode phases with per-phase - target sparsity ratios. + Calibration fits an Inverse Power model to determine dynamic thresholds that + achieve target sparsity. The model learns parameters k and p per phase: + + scale_factor = k / (1 - target_sparsity)^p + + At inference time, the threshold is computed as: + + threshold = scale_factor / sequence_length + + Key benefits: + - Target sparsity can be changed at runtime without recalibration + - Threshold automatically adapts to sequence length + - Supports independent prefill and decode phase calibration """ target_sparse_ratio: dict[str, float] = ModeloptField( @@ -216,7 +226,6 @@ class CalibrationConfig(ModeloptBaseConfig): title="Cache directory", description=( "Directory to cache generated calibration samples. " - "If None, uses MODELOPT_CACHE_DIR env var or ~/.cache/modelopt/sparse_attention/. " "Caching avoids regenerating samples on repeated calibration runs." ), ) @@ -403,7 +412,6 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): "SKIP_SOFTMAX_DEFAULT", "CalibrationConfig", "FlashSkipSoftmaxConfig", - "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", "SparseAttentionCfgType", "SparseAttentionConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py index 9f37df61f..b84a3cade 100644 --- a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -130,13 +130,17 @@ def reset(self): def get_calibration_stats(self, phase: str | None = None) -> list[dict]: """Get per-sample calibration statistics, optionally filtered by phase. + Note: Returns historical stats collected while calibration_mode was enabled. 
+ Stats remain accessible even after calibration_mode is disabled. + New stats are only collected when calibration_mode is True. + Args: phase: Optional phase to filter by ('prefill' or 'decode'). If None, returns all stats. Returns: List of per-sample statistics dictionaries. - Empty list if not in calibration mode or no stats for that phase. + Empty list if no stats were collected or no stats match the phase. """ if phase is None: return self.per_sample_stats From da96f1b108e57e948b8e2b554e845b08eade605a Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 27 Jan 2026 16:27:05 -0800 Subject: [PATCH 8/9] Implement Inverse Power calibration for sparse attention Signed-off-by: Kai Xu --- .../llm_sparsity/attention_sparsity/hf_sa.py | 68 +++-- .../calibration/calibrate.py | 125 +++++++-- .../calibration/calibrator.py | 262 +++++++++--------- .../attention_sparsity/calibration/dataset.py | 35 +-- .../sparsity/attention_sparsity/conversion.py | 34 ++- .../methods/flash_skip_softmax.py | 60 ++-- .../test_calibration_gpu.py | 10 +- .../test_sparse_attention_calibration.py | 53 ++-- .../test_sparse_attention_config.py | 35 ++- .../test_sparse_attention_conversion.py | 2 +- .../attention_sparsity/test_threshold_info.py | 58 ++-- 11 files changed, 432 insertions(+), 310 deletions(-) diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 29a2b53aa..eaec7fe4f 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -17,12 +17,12 @@ """Example script for applying sparse attention to HuggingFace models.""" import argparse +import copy import random from pathlib import Path import numpy as np import torch -from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer import modelopt.torch.opt as mto @@ -46,41 +46,13 @@ } -def get_narrativeqa_samples(num_samples=3): - """Load samples from NarrativeQA dataset for testing. - - Args: - num_samples: Number of samples to generate - - Raises: - RuntimeError: If dataset loading fails - ValueError: If no valid samples could be loaded - """ - # Load NarrativeQA dataset with retry logic - try: - dataset = load_dataset("narrativeqa", split="test", streaming=True) - except Exception as e: - raise RuntimeError(f"Failed to load NarrativeQA dataset: {e}") - - samples = [] - for i, item in enumerate(dataset): - if i >= num_samples: - break - - # Combine document context and question - context = item.get("document", {}).get("text", "") - question = item.get("question", {}).get("text", "") - - if context and question: - # Use the full context as-is - prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:" - samples.append(prompt) - - if not samples: - raise ValueError("Could not load NarrativeQA samples") - - print(f"Loaded {len(samples)} NarrativeQA samples") - return samples +def get_test_prompts(): + """Get simple test prompts for sample output generation.""" + return [ + "What is the capital of France? 
Answer:", + "Explain the theory of relativity in simple terms:", + "Write a short poem about the ocean:", + ] def truncate_text(text: str, tokenizer, max_length: int): @@ -130,7 +102,7 @@ def generate_sample_output(model, tokenizer, args): Tuple of (generated_text, input_prompt, input_ids) """ # Load test sample - prompts = get_narrativeqa_samples(num_samples=1) + prompts = get_test_prompts() prompt = prompts[0] # Prepare inputs @@ -198,6 +170,20 @@ def main(args): # Apply sparse attention with optional calibration print(f"\nApplying sparse attention: {args.sparse_attn}") sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] + + # Override target_sparse_ratio if provided via CLI + if args.target_sparse_ratio is not None: + sparse_config = copy.deepcopy(sparse_config) + sparse_cfg = sparse_config.get("sparse_cfg", {}) + if isinstance(sparse_cfg, dict) and "calibration" in sparse_cfg: + calibration_cfg = sparse_cfg["calibration"] + if isinstance(calibration_cfg, dict): + calibration_cfg["target_sparse_ratio"] = { + "prefill": args.target_sparse_ratio, + "decode": args.target_sparse_ratio, + } + print(f"Overriding target_sparse_ratio to {args.target_sparse_ratio}") + model = mtsa.sparsify(model, config=sparse_config) print("Sparse attention applied successfully!") @@ -287,5 +273,13 @@ def main(args): help="Directory to export the model with sparse attention applied", ) + # Calibration arguments + parser.add_argument( + "--target_sparse_ratio", + type=float, + default=None, + help="Target sparsity ratio for calibration (0.0 to 1.0). Overrides config value.", + ) + args = parser.parse_args() main(args) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index 515985244..aa3f2a408 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -15,8 +15,11 @@ """Calibration functions for sparse attention.""" +import hashlib +import json import warnings from collections.abc import Callable +from pathlib import Path from typing import Any import torch @@ -30,6 +33,54 @@ from .dataset import RulerDatasetBuilder +def _get_cache_path( + tokenizer_path: str, samples: int, max_seqlen: int, cache_dir: str | None = None +) -> Path: + """Generate cache file path based on calibration parameters. + + Args: + tokenizer_path: Path to tokenizer (used in hash) + samples: Number of calibration samples + max_seqlen: Maximum sequence length + cache_dir: Optional cache directory. 
If None, uses ~/.cache/modelopt/sparse_attention/ + """ + # Create a hash of the parameters for the cache filename + key = f"{tokenizer_path}_{samples}_{max_seqlen}" + hash_str = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()[:12] + filename = f"ruler_cache_{samples}s_{max_seqlen}l_{hash_str}.json" + + if cache_dir: + base_dir = Path(cache_dir) + else: + base_dir = Path.home() / ".cache" / "modelopt" / "sparse_attention" + + return base_dir / filename + + +def _load_cached_data(cache_path: Path) -> list[dict[str, Any]] | None: + """Load calibration data from cache if it exists.""" + if cache_path.exists(): + try: + with open(cache_path) as f: + data = json.load(f) + print(f"Loaded {len(data)} cached calibration samples from {cache_path}") + return data + except Exception as e: + print(f"Warning: Failed to load cache: {e}") + return None + + +def _save_cached_data(cache_path: Path, data: list[dict[str, Any]]) -> None: + """Save calibration data to cache.""" + try: + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(data, f) + print(f"Saved calibration samples to cache: {cache_path}") + except Exception as e: + print(f"Warning: Failed to save cache: {e}") + + def _extract_tokenizer_from_model(model: nn.Module) -> str: """Extract tokenizer name/path from model config. @@ -255,18 +306,31 @@ def calibrate_sparse_attention( calibration_data = None if calibrate_prefill or calibrate_decode: - builder = RulerDatasetBuilder( - samples=calib_config.samples, - max_seqlen=calib_config.max_seqlen, - tokenizer_name_or_path=tokenizer, - num_length_bins=calib_config.num_length_bins, - max_length_filter=int(calib_config.max_seqlen * 1.5), + # Try to load from cache first + cache_path = _get_cache_path( + tokenizer, + calib_config.samples, + calib_config.max_seqlen, + cache_dir=calib_config.cache_dir, ) - calibration_data = builder.build_calibration_dataset() - print(f"Generated {len(calibration_data)} calibration samples") + calibration_data = _load_cached_data(cache_path) + + # Generate if not cached + if calibration_data is None: + builder = RulerDatasetBuilder( + samples=calib_config.samples, + max_seqlen=calib_config.max_seqlen, + tokenizer_name_or_path=tokenizer, + num_length_bins=calib_config.num_length_bins, + max_length_filter=int(calib_config.max_seqlen * 1.5), + ) + calibration_data = builder.build_calibration_dataset() + print(f"Generated {len(calibration_data)} calibration samples") + + # Save to cache for future runs + _save_cached_data(cache_path, calibration_data) # Initialize results - threshold_scale_factor: dict[str, float] = {} calibration_results: dict[str, Any] = {} # Run prefill calibration if enabled @@ -282,13 +346,11 @@ def calibrate_sparse_attention( ) prefill_calibrator = DynamicThresholdCalibrator( - target_sparse_ratio=target_dict, threshold_trials=calib_config.threshold_trials, ) prefill_result = prefill_calibrator.calibrate(model, prefill_forward_loop, phase="prefill") - if "scale_factor" in prefill_result: - threshold_scale_factor["prefill"] = prefill_result["scale_factor"] + if "k" in prefill_result and "p" in prefill_result: calibration_results["prefill"] = prefill_result else: warnings.warn("Prefill calibration did not produce valid results") @@ -306,38 +368,57 @@ def calibrate_sparse_attention( ) decode_calibrator = DynamicThresholdCalibrator( - target_sparse_ratio=target_dict, threshold_trials=calib_config.threshold_trials, ) decode_result = decode_calibrator.calibrate(model, decode_forward_loop, 
phase="decode") - if "scale_factor" in decode_result: - threshold_scale_factor["decode"] = decode_result["scale_factor"] + if "k" in decode_result and "p" in decode_result: calibration_results["decode"] = decode_result else: warnings.warn("Decode calibration did not produce valid results") # Check if any calibration succeeded - if not threshold_scale_factor: + if not calibration_results: warnings.warn("No calibration produced valid results") return {} - # Apply combined threshold_scale_factor dict to all modules + # Extract k and p for each phase + calibration_params: dict[str, dict[str, float]] = {} + for phase in ["prefill", "decode"]: + if phase in calibration_results: + result = calibration_results[phase] + calibration_params[phase] = { + "k": result["k"], + "p": result["p"], + } + + # Apply calibration params to all modules print("\n" + "=" * 60) print("APPLYING CALIBRATION RESULTS") print("=" * 60) - print(f"Applying threshold_scale_factor to {len(sparse_modules)} modules:") - for phase, scale_factor in threshold_scale_factor.items(): - print(f" {phase}: {scale_factor:.6f}") + print(f"Applying calibration to {len(sparse_modules)} modules:") + for phase, params in calibration_params.items(): + result = calibration_results[phase] + print(f" {phase}:") + print(f" Model: scale_factor = {params['k']:.4f} / (1 - sparsity)^{params['p']:.4f}") + print(f" R-squared: {result['r_squared']:.6f}") for module_name, module in sparse_modules: - module._sparse_method_instance.threshold_scale_factor = threshold_scale_factor + module._sparse_method_instance.calibration_params = calibration_params + module._sparse_method_instance.target_sparse_ratio = target_dict # Print final summary print("\nCalibration complete!") + print( + f"Target sparsity: prefill={target_dict.get('prefill', 0):.0%}, " + f"decode={target_dict.get('decode', 0):.0%}" + ) + print("\nTo change target sparsity at inference time, update:") + print(" module._sparse_method_instance.target_sparse_ratio = {'prefill': X, 'decode': Y}") print_sparse_attention_summary(model) return { - "threshold_scale_factor": threshold_scale_factor, + "calibration_params": calibration_params, + "target_sparse_ratio": target_dict, "calibration_results": calibration_results, } diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index 149909c56..edf7c9079 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -16,13 +16,14 @@ """Calibration framework for sparse attention methods.""" import warnings +from collections import defaultdict from collections.abc import Callable -from dataclasses import dataclass from typing import Any import numpy as np import torch import torch.nn as nn +from scipy.optimize import curve_fit from tqdm import tqdm from ..sparse_attention import SparseAttentionModule @@ -30,36 +31,38 @@ class DynamicThresholdCalibrator: - """Dynamic threshold calibrator using length-based linear relationship. + """Dynamic threshold calibrator using Inverse Power model. - Implements calibration algorithm: - 1. Find hyperparameter 'a' where threshold λ = a / context_length - 2. Use dataset with different lengths and test multiple thresholds - 3. For each sample, find optimal threshold closest to target sparsity - 4. Use linear regression to fit: threshold = a * (1/length) - """ + Calibration Algorithm: + 1. 
For each threshold λ_j in threshold_trials: + - Run ALL samples through forward_loop + - For each sample i with length L_i, collect sparsity S_ij + - Compute scale_factor_ij = λ_j × L_i + + 2. Fit Inverse Power model to ALL individual (sf_ij, S_ij) pairs: + scale_factor = k / (1 - sparsity)^p - @dataclass - class SampleSparsity: - """Sparsity results for a single calibration sample.""" + 3. Return fitted k and p parameters (model-specific) - length: int - threshold_sparsities: dict[float, float] + At inference time (user specifies target_sparsity S*): + scale_factor = k / (1 - S*)^p + threshold = scale_factor / seqlen + + Key insight: Using all individual data points (N_thresholds × N_samples) + instead of per-threshold averages provides more accurate fitting without + additional calibration time cost. + """ def __init__( self, - target_sparse_ratio: dict[str, float] | None = None, threshold_trials: list[float] | None = None, ): """Initialize dynamic threshold calibrator. Args: - target_sparse_ratio: Target sparsity ratio dict with 'prefill' and 'decode' keys. - Each value should be in range (0.0 to 1.0). Set to 0.0 to skip that phase. - threshold_trials: List of thresholds to try during calibration + threshold_trials: List of thresholds to try during calibration. + Should span a range that achieves sparsities from ~10% to ~95%. """ - self.target_sparse_ratio = target_sparse_ratio - # Default threshold trials if not provided self.threshold_trials = threshold_trials or [ 1e-6, @@ -78,18 +81,28 @@ def __init__( 3e-1, 5e-1, 7e-1, + 8e-1, + 9e-1, + 9.5e-1, + 9.9e-1, ] - # Statistics tracking - self.sparsity_results = [] - def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dict[str, Any]: - """Find optimal 'a' parameter for length-based threshold. + """Calibrate k and p parameters for Inverse Power model. Algorithm: - 1. Test all threshold trials by running forward_loop multiple times - 2. For each sample, find optimal threshold closest to target sparsity - 3. Use regression to find 'a' in: threshold = a / length + 1. For each threshold λ_j in threshold_trials: + - Run ALL samples, collect sparsities S_ij for each sample i + - Compute scale_factor_ij = λ_j × L_i (where L_i is sample length) + + 2. Fit Inverse Power model to ALL (sf_ij, S_ij) pairs: + scale_factor = k / (1 - sparsity)^p + + 3. 
Return fitted k and p parameters + + At inference time (user specifies target_sparsity S*): + scale_factor = k / (1 - S*)^p + threshold = scale_factor / seqlen Args: model: The model with sparse attention modules @@ -97,48 +110,24 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic phase: Phase to calibrate ('prefill' or 'decode') Returns: - Dict with calibration results including scale_factor, or empty dict if failed + Dict with calibration results including k, p, r_squared, and num_data_points """ - if self.target_sparse_ratio is None: - raise RuntimeError("target_sparse_ratio must be provided") - target_sparsity = self.target_sparse_ratio[phase] - # Extract attention modules attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] if not attention_modules: raise ValueError("No sparse attention modules found for calibration") - print(f"Starting dynamic threshold calibration ({phase} phase)") - print(f"Target sparsity: {target_sparsity}") + print(f"Starting Inverse Power model calibration ({phase} phase)") print(f"Threshold trials: {len(self.threshold_trials)}") - # Stage 1: Collect sparsity for all sample-threshold pairs - print(f"\nStage 1: Collecting {phase} sparsity data...") + # Stage 1: Collect ALL (scale_factor, sparsity) pairs for all thresholds and samples + print(f"\nStage 1: Collecting {phase} sparsity data for all thresholds...") - # Run first threshold to discover samples and initialize results - self._set_threshold(attention_modules, self.threshold_trials[0]) - self._enable_calibration_mode(attention_modules) - with torch.no_grad(): - forward_loop(model) - per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase) - self._disable_calibration_mode(attention_modules) - - if not per_sample_stats: - warnings.warn(f"No {phase} phase statistics collected. 
Check forward loop.") - return {} - - # Initialize sparsity_results with sample info - self.sparsity_results = [ - self.SampleSparsity( - length=stat["sample_length"], - threshold_sparsities={self.threshold_trials[0]: stat["sparsity"]}, - ) - for stat in per_sample_stats - ] + # Collect ALL individual data points (not averaged) + all_data_points = [] # List of {"threshold", "length", "scale_factor", "sparsity"} - # Collect remaining thresholds - for threshold in tqdm(self.threshold_trials[1:], desc=f"Testing thresholds ({phase})"): + for threshold in tqdm(self.threshold_trials, desc=f"Testing thresholds ({phase})"): self._set_threshold(attention_modules, threshold) self._enable_calibration_mode(attention_modules) with torch.no_grad(): @@ -146,90 +135,115 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase) self._disable_calibration_mode(attention_modules) - for sample_idx, sample_stat in enumerate(per_sample_stats): - if sample_idx < len(self.sparsity_results): - self.sparsity_results[sample_idx].threshold_sparsities[threshold] = sample_stat[ - "sparsity" - ] + if not per_sample_stats: + continue + + # Collect individual (scale_factor, sparsity) pairs for each sample + for sample_stat in per_sample_stats: + length = sample_stat["sample_length"] + sparsity = sample_stat["sparsity"] + scale_factor = threshold * length + + all_data_points.append( + { + "threshold": threshold, + "length": length, + "scale_factor": scale_factor, + "sparsity": sparsity, + } + ) - if not self.sparsity_results: - warnings.warn(f"No valid {phase} sparsity measurements collected during calibration") + if len(all_data_points) < 10: + warnings.warn( + f"Not enough data points for {phase} calibration. " + f"Got {len(all_data_points)}, need at least 10." + ) return {} - print(f"Collected statistics for {len(self.sparsity_results)} samples") + print(f"Collected {len(all_data_points)} individual (scale_factor, sparsity) pairs") - # Stage 2: Find optimal threshold for each sample and compute 'a' - print(f"\nStage 2: Finding 'a' parameter for target sparsity {target_sparsity:.2f}") + # Stage 2: Fit Inverse Power model: scale_factor = k / (1 - sparsity)^p + print("\nStage 2: Fitting Inverse Power model to all data points...") - # Find optimal threshold for each sample - optimal_pairs = [] - for sample_result in self.sparsity_results: - # Find threshold closest to target sparsity - best_threshold, achieved_sparsity = min( - sample_result.threshold_sparsities.items(), - key=lambda item: abs(item[1] - target_sparsity), - ) + # Extract data for fitting + scale_factors = np.array([p["scale_factor"] for p in all_data_points]) + sparsities = np.array([p["sparsity"] for p in all_data_points]) - optimal_pairs.append( - { - "length": sample_result.length, - "optimal_threshold": best_threshold, - "achieved_sparsity": achieved_sparsity, - "target_sparsity": target_sparsity, - } - ) + # Filter out invalid sparsities (must be in (0, 1)) + valid_mask = (sparsities > 0.01) & (sparsities < 0.99) + scale_factors = scale_factors[valid_mask] + sparsities = sparsities[valid_mask] - if not optimal_pairs: + if len(scale_factors) < 3: warnings.warn( - f"No optimal threshold pairs found for {phase} target sparsity {target_sparsity}. " - f"Collected {len(self.sparsity_results)} samples but none achieved target sparsity." + f"Not enough valid data points after filtering. Got {len(scale_factors)}." 
) return {} - # Linear regression: threshold = a * (1/length) - lengths = np.array([p["length"] for p in optimal_pairs]) - thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) - - # X = 1/length, Y = threshold - x = 1.0 / lengths - y = thresholds - - # Least squares: scale_factor = sum(x*y) / sum(x^2) - scale_factor = np.sum(x * y) / np.sum(x**2) - - # Calculate statistics - scale_factors_per_sample = y * lengths - scale_factor_std = np.std(scale_factors_per_sample) + # Define Inverse Power model: sf = k / (1 - S)^p + def inverse_power(sparsity, k, p): + return k / np.power(1 - sparsity, p) + + # Fit the model + try: + popt, pcov = curve_fit( + inverse_power, + sparsities, + scale_factors, + p0=[100, 1.5], # Initial guess + bounds=([0.1, 0.1], [1e7, 10]), # Bounds for k and p + maxfev=10000, + ) + k, p = popt + except Exception as e: + warnings.warn(f"Curve fitting failed: {e}") + return {} - # Calculate R-squared for quality metric - y_pred = scale_factor * x - ss_res = np.sum((y - y_pred) ** 2) - ss_tot = np.sum((y - np.mean(y)) ** 2) + # Calculate R-squared + pred_scale_factors = inverse_power(sparsities, k, p) + ss_res = np.sum((scale_factors - pred_scale_factors) ** 2) + ss_tot = np.sum((scale_factors - np.mean(scale_factors)) ** 2) r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 - # Calculate average achieved sparsity - avg_achieved_sparsity = np.mean([p["achieved_sparsity"] for p in optimal_pairs]) - - print(f"\n{phase.capitalize()} Calibration Results:") - print(f" Threshold scale factor: {scale_factor:.6f} (std: {scale_factor_std:.6f})") - print(f" R-squared: {r_squared:.4f}") - print( - f" Average achieved sparsity: {avg_achieved_sparsity:.2%} (target: {target_sparsity:.2%})" - ) - print(f"\nExample thresholds with λ = {scale_factor:.6f} / length:") - for length in [1024, 2048, 4096, 8192, 16384]: - print(f" Length {length:5d}: threshold = {scale_factor / length:.2e}") + print(f"\n{phase.capitalize()} Calibration Results (Inverse Power Model):") + print(" Model: scale_factor = k / (1 - sparsity)^p") + print(f" Fitted k: {k:.4f}") + print(f" Fitted p: {p:.4f}") + print(f" R-squared: {r_squared:.6f}") + print(f" Data points used: {int(np.sum(valid_mask))} / {len(all_data_points)}") + + # Show scale_factor for various target sparsities + print("\nScale factors for different target sparsities:") + print(f" {'Target':<10} {'Scale Factor':<15}") + print(f" {'-' * 10} {'-' * 15}") + for target in [0.5, 0.7, 0.8, 0.9, 0.95]: + sf = k / (1 - target) ** p + print(f" {target:<10.0%} {sf:<15.2f}") + + # Print calibration data summary by threshold + print("\nCalibration data summary (per threshold):") + print(f" {'Threshold':<12} {'Avg SF':<12} {'Avg Sparsity':<12} {'Samples':<8}") + print(f" {'-' * 12} {'-' * 12} {'-' * 12} {'-' * 8}") + + # Group by threshold for summary + by_threshold = defaultdict(list) + for point in all_data_points: + by_threshold[point["threshold"]].append(point) + + for threshold in sorted(by_threshold.keys()): + points = by_threshold[threshold] + avg_sf = np.mean([p["scale_factor"] for p in points]) + avg_s = np.mean([p["sparsity"] for p in points]) + print(f" {threshold:<12.4f} {avg_sf:<12.2f} {avg_s:<12.2%} {len(points):<8}") return { "phase": phase, - "scale_factor": scale_factor, - "scale_factor_std": scale_factor_std, - "r_squared": r_squared, - "num_samples": len(optimal_pairs), - "target_sparsity": target_sparsity, - "avg_achieved_sparsity": avg_achieved_sparsity, - "optimal_pairs": optimal_pairs, - "calibration_type": 
"length_based_dynamic", + "k": float(k), + "p": float(p), + "r_squared": float(r_squared), + "num_data_points": int(np.sum(valid_mask)), + "total_samples": len(all_data_points), + "calibration_type": "inverse_power", } def _enable_calibration_mode(self, modules: list[nn.Module]): diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py index dc46413c1..221ea2344 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -20,6 +20,8 @@ from dataclasses import dataclass from typing import Any +from tqdm import tqdm + from . import ruler_utils @@ -242,26 +244,27 @@ def build_calibration_dataset(self) -> list[dict[str, Any]]: Returns: List of calibration samples with 'input' and 'length' fields """ - from tqdm import tqdm - all_samples = [] - # Generate calibration samples - for num_samples, target_length in tqdm( - zip(self.samples_per_length, self.target_lengths), - desc="Generating RULER calibration samples", - total=len(self.target_lengths), - ): - samples_per_task = max(num_samples // len(self.subtasks), 1) - - # Generate equal samples for each task - for task_name in self.subtasks: - for sample_idx in range(samples_per_task): - sample = self._generate_sample(task_name, target_length, sample_idx) - if sample and sample["length"] <= self.max_length_filter: - all_samples.append(sample) + print( + f"Generating {self.total_samples} calibration samples " + f"across {len(self.target_lengths)} length bins: {self.target_lengths}" + ) + + # Generate calibration samples with sample-level progress + with tqdm(total=self.total_samples, desc="Generating RULER samples") as pbar: + for num_samples, target_length in zip(self.samples_per_length, self.target_lengths): + samples_per_task = max(num_samples // len(self.subtasks), 1) + + for task_name in self.subtasks: + for sample_idx in range(samples_per_task): + sample = self._generate_sample(task_name, target_length, sample_idx) + if sample and sample["length"] <= self.max_length_filter: + all_samples.append(sample) + pbar.update(1) random.shuffle(all_samples) + print(f"Generated {len(all_samples)} valid samples") return all_samples def _generate_sample( diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 006c99eea..716149bc9 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -244,23 +244,27 @@ def update_sparse_attention_metadata( def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: """Extract sparse attention config for export to config.json. - Extracts the global threshold_scale_factor from the first sparse attention - module that has calibrated thresholds. + Extracts the calibration parameters (k, p) and target_sparse_ratio from the first + sparse attention module that has calibrated thresholds. Args: model: Model with sparse attention applied Returns: Dictionary with sparse attention config, or None if no calibrated config found. - Format: {"threshold_scale_factor": {"prefill": float, "decode": float}} + Contains "calibration_params" with k and p per phase, and "target_sparse_ratio". 
""" for module in model.modules(): if isinstance(module, SparseAttentionModule): - threshold_scale_factor = getattr( - module._sparse_method_instance, "threshold_scale_factor", None + calibration_params = getattr(module._sparse_method_instance, "calibration_params", None) + target_sparse_ratio = getattr( + module._sparse_method_instance, "target_sparse_ratio", None ) - if threshold_scale_factor is not None: - return {"threshold_scale_factor": threshold_scale_factor} + if calibration_params is not None: + return { + "calibration_params": calibration_params, + "target_sparse_ratio": target_sparse_ratio, + } return None @@ -327,11 +331,17 @@ def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Cal def _format_threshold(info: dict) -> str: """Format threshold info for display.""" t = info.get("type") - if t == "dynamic": - # Per-phase calibrated threshold: λ = scale_factor[phase] / length - scale_factors = info.get("scale_factors", {}) - parts = [f"{phase}={sf:.2f}" for phase, sf in scale_factors.items()] - return f"λ={{{', '.join(parts)}}}" + if t == "dynamic_calibrated": + # Inverse Power model: threshold = k / (1 - sparsity)^p / seqlen + params = info.get("calibration_params", {}) + target = info.get("target_sparse_ratio", {}) + parts = [] + for phase in ["prefill", "decode"]: + if phase in params: + k, p = params[phase]["k"], params[phase]["p"] + s = target.get(phase, 0.5) + parts.append(f"{phase}: k={k:.1f}, p={p:.2f}, target={s:.0%}") + return f"calibrated({', '.join(parts)})" if t == "static": v = info.get("value") if isinstance(v, dict): diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 7944510c5..c7e168abc 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -56,13 +56,11 @@ def __init__(self, method_config: dict | None = None): # Optional parameters not in Pydantic config self.phase = config.get("phase", None) - # Initialize threshold - if isinstance(self.threshold_config, dict): - self.threshold = self.threshold_config.get( - "default", self.threshold_config.get("prefill", 1e-4) - ) - else: - self.threshold = self.threshold_config + # Initialize threshold from dict config (prefill phase as default) + self.threshold = self.threshold_config.get("prefill", 1e-3) + + # Calibration mode flag (prevents threshold updates during calibration) + self._calibration_mode = False def set_calibration_mode(self, enabled: bool): """Set calibration mode to prevent _update_threshold from modifying the threshold.""" @@ -70,10 +68,7 @@ def set_calibration_mode(self, enabled: bool): def _update_threshold(self, phase: str): """Update threshold based on phase.""" - if isinstance(self.threshold_config, dict): - self.threshold = self.threshold_config.get( - phase, self.threshold_config.get("default", self.threshold) - ) + self.threshold = self.threshold_config.get(phase, self.threshold) def _infer_phase(self, attention_scores: torch.Tensor) -> str: """Infer phase from attention scores shape.""" @@ -138,10 +133,21 @@ def calc_correction_factor_and_p( batch_size, num_heads, seq_q, seq_k = attn_weights.shape # Calculate threshold - threshold_scale_factor = getattr(self, "threshold_scale_factor", None) - if threshold_scale_factor is not None and phase in threshold_scale_factor: - # Per-phase calibrated threshold: λ = scale_factor[phase] / length - log_threshold 
= np.log(threshold_scale_factor[phase] / seq_k) + calibration_params = getattr(self, "calibration_params", None) + target_sparse_ratio = getattr(self, "target_sparse_ratio", None) + + if ( + calibration_params is not None + and phase in calibration_params + and target_sparse_ratio is not None + ): + # Use calibrated k, p to compute dynamic threshold + # scale_factor = k / (1 - target_sparsity)^p + k = calibration_params[phase]["k"] + p = calibration_params[phase]["p"] + target_sparsity = target_sparse_ratio.get(phase, 0.5) + scale_factor = k / ((1 - target_sparsity) ** p) + log_threshold = np.log(scale_factor / seq_k) else: # Use static threshold from config (no calibration or phase not calibrated) log_threshold = np.log(self.threshold) @@ -313,23 +319,31 @@ def get_threshold_info(self) -> dict[str, Any]: Returns: Dictionary with threshold configuration and calibration info. """ - threshold_scale_factor = getattr(self, "threshold_scale_factor", None) + calibration_params = getattr(self, "calibration_params", None) + target_sparse_ratio = getattr(self, "target_sparse_ratio", None) - if threshold_scale_factor is not None: - # Per-phase calibrated dynamic threshold - example_lengths = [1024, 2048, 4096, 8192] + if calibration_params is not None and target_sparse_ratio is not None: + # Per-phase calibrated dynamic threshold using Inverse Power model + example_lengths = [1024, 4096, 16384, 65536, 131072] phase_info = {} - for phase, scale_factor in threshold_scale_factor.items(): + for phase, params in calibration_params.items(): + k, p = params["k"], params["p"] + target_sparsity = target_sparse_ratio.get(phase, 0.5) + scale_factor = k / ((1 - target_sparsity) ** p) phase_info[phase] = { + "k": k, + "p": p, + "target_sparsity": target_sparsity, "scale_factor": scale_factor, "example_thresholds": { length: scale_factor / length for length in example_lengths }, } return { - "type": "dynamic", - "scale_factors": threshold_scale_factor, - "formula": "λ[phase] / length", + "type": "dynamic_calibrated", + "formula": "threshold = k / (1 - target_sparsity)^p / seqlen", + "calibration_params": calibration_params, + "target_sparse_ratio": target_sparse_ratio, "phases": phase_info, } else: diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py index 3f1030f02..dd7c4c6f5 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -143,12 +143,14 @@ def forward_loop(model): sparse_modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] assert len(sparse_modules) > 0 - # Verify calibration was applied + # Verify calibration was applied (Inverse Power model params) for module in sparse_modules: method = module._sparse_method_instance - # Check if calibrated threshold scale factor is set - if hasattr(method, "threshold_scale_factor") and method.threshold_scale_factor: - assert method.threshold_scale_factor > 0 + # Check if calibration params (k, p) are set + if hasattr(method, "calibration_params") and method.calibration_params: + for params in method.calibration_params.values(): + assert "k" in params and params["k"] > 0 + assert "p" in params and params["p"] > 0 def test_calibration_pytorch_backend(self, simple_model): """Test calibration with pytorch backend.""" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py 
b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py index 884c82a76..c093356bb 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -312,7 +312,7 @@ def test_calibration_disabled(self): for module in sparse_model.modules(): if isinstance(module, SparseAttentionModule): method = module._sparse_method_instance - assert not getattr(method, "threshold_scale_factor", None) + assert not getattr(method, "calibration_params", None) def test_sparsify_with_calibration_requires_forward_loop(self): """Test that calibration requires forward_loop or proper model config.""" @@ -454,7 +454,7 @@ def test_set_threshold(self): assert len(modules) > 0 # Create calibrator and set threshold - calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + calibrator = DynamicThresholdCalibrator() calibrator._set_threshold(modules, 0.05) # Verify threshold was set @@ -479,7 +479,7 @@ def test_enable_disable_calibration_mode(self): modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] - calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + calibrator = DynamicThresholdCalibrator() # Enable calibration mode calibrator._enable_calibration_mode(modules) @@ -515,7 +515,7 @@ def test_extract_calibration_stats_no_stats(self): modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] - calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + calibrator = DynamicThresholdCalibrator() # Extract stats without running any forward passes stats = calibrator._extract_calibration_stats(modules) @@ -546,7 +546,7 @@ def test_enable_calibration_mode_with_existing_stats_manager(self): for module in modules: assert module._stats_manager is not None - calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + calibrator = DynamicThresholdCalibrator() # Disable stats manager first for module in modules: @@ -577,7 +577,7 @@ def test_extract_calibration_stats_with_phase_filter(self): modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] - calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + calibrator = DynamicThresholdCalibrator() # Enable calibration mode and manually add some stats calibrator._enable_calibration_mode(modules) @@ -620,7 +620,7 @@ def test_extract_calibration_stats_module_without_stats_manager(self): modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] - calibrator = DynamicThresholdCalibrator(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + calibrator = DynamicThresholdCalibrator() # Stats manager should be None for module in modules: @@ -636,7 +636,6 @@ def test_calibrate_no_sparse_modules(self): # Don't apply sparse attention calibrator = DynamicThresholdCalibrator( - target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, threshold_trials=[0.001, 0.01], ) @@ -663,7 +662,6 @@ def test_calibrate_empty_stats(self): sparse_model = sparsify(model, config) calibrator = DynamicThresholdCalibrator( - target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, threshold_trials=[0.001], # Only one threshold for speed ) @@ -677,12 +675,10 @@ def empty_forward_loop(m): def test_calibrator_default_threshold_trials_values(self): """Test that default threshold 
trials have expected values.""" - calibrator = DynamicThresholdCalibrator( - target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, - ) + calibrator = DynamicThresholdCalibrator() - # Should have 12 default trials - assert len(calibrator.threshold_trials) == 12 + # Should have 20 default trials (expanded range for Inverse Power model) + assert len(calibrator.threshold_trials) == 20 # Check specific values expected_trials = [ @@ -695,9 +691,17 @@ def test_calibrator_default_threshold_trials_values(self): 1e-3, 5e-3, 1e-2, + 2e-2, 5e-2, 1e-1, + 2e-1, + 3e-1, 5e-1, + 7e-1, + 8e-1, + 9e-1, + 9.5e-1, + 9.9e-1, ] assert calibrator.threshold_trials == expected_trials @@ -705,32 +709,11 @@ def test_calibrator_custom_threshold_trials(self): """Test calibrator with custom threshold trials.""" custom_trials = [0.001, 0.005, 0.01, 0.05, 0.1] calibrator = DynamicThresholdCalibrator( - target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, threshold_trials=custom_trials, ) assert calibrator.threshold_trials == custom_trials - def test_calibrator_sparsity_results_initialization(self): - """Test that sparsity_results is initialized as empty list.""" - calibrator = DynamicThresholdCalibrator( - target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, - ) - - assert calibrator.sparsity_results == [] - - def test_sample_sparsity_dataclass(self): - """Test SampleSparsity dataclass.""" - sample = DynamicThresholdCalibrator.SampleSparsity( - length=1024, - threshold_sparsities={0.001: 0.3, 0.01: 0.5, 0.1: 0.7}, - ) - - assert sample.length == 1024 - assert sample.threshold_sparsities[0.001] == 0.3 - assert sample.threshold_sparsities[0.01] == 0.5 - assert sample.threshold_sparsities[0.1] == 0.7 - class TestCalibrateFunction: """Test calibrate_sparse_attention function.""" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py index 1824825f9..ac1bc650d 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -35,13 +35,13 @@ def test_valid_config(self): """Test creating valid config.""" config = SparseAttentionAttributeConfig( method="flash_skip_softmax", - threshold=1e-4, + threshold={"prefill": 1e-4, "decode": 1e-4}, br=128, bc=128, enable=True, ) assert config.method == "flash_skip_softmax" - assert config.threshold == 1e-4 + assert config.threshold == {"prefill": 1e-4, "decode": 1e-4} assert config.br == 128 assert config.bc == 128 @@ -65,18 +65,22 @@ def test_block_size_validation_large(self): assert config.br == 2048 def test_threshold_validation_range(self): - """Test threshold must be in range (0, 1).""" - with pytest.raises(ValidationError, match="Threshold must be in range"): - SparseAttentionAttributeConfig(threshold=0) + """Test threshold dict values must be in range (0, 1).""" + # Zero value + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 0, "decode": 1e-4}) - with pytest.raises(ValidationError, match="Threshold must be in range"): - SparseAttentionAttributeConfig(threshold=-0.1) + # Negative value + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -0.1, "decode": 1e-4}) - with pytest.raises(ValidationError, match="Threshold must be in range"): - SparseAttentionAttributeConfig(threshold=1.0) + # Value equals 1.0 + with 
pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0, "decode": 1e-4}) - with pytest.raises(ValidationError, match="Threshold must be in range"): - SparseAttentionAttributeConfig(threshold=1.5) + # Value greater than 1.0 + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.5, "decode": 1e-4}) def test_threshold_validation_dict(self): """Test threshold dict validation.""" @@ -97,8 +101,13 @@ def test_threshold_validation_dict(self): SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) def test_threshold_validation_type(self): - """Test threshold type validation.""" - with pytest.raises(ValidationError, match="Input should be a valid"): + """Test threshold must be a dict (not single value or string).""" + # Single float value not allowed + with pytest.raises(ValidationError, match="Input should be a valid dictionary"): + SparseAttentionAttributeConfig(threshold=1e-4) + + # String not allowed + with pytest.raises(ValidationError, match="Input should be a valid dictionary"): SparseAttentionAttributeConfig(threshold="invalid") diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index d8913f51d..2a283faa0 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -275,5 +275,5 @@ def test_get_threshold_info(self): assert isinstance(info, dict) assert "type" in info assert info["type"] == "static" - assert info["value"] == 0.005 + assert info["value"] == {"prefill": 0.005, "decode": 0.001} break diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py index c958fab22..b02056480 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -61,21 +61,26 @@ def test_dynamic_calibrated_threshold(self): } ) - # Simulate calibration setting per-phase scale factors - method.threshold_scale_factor = {"prefill": 437.5, "decode": 500.0} + # Simulate calibration setting k and p parameters + method.calibration_params = { + "prefill": {"k": 150.0, "p": 1.5}, + "decode": {"k": 200.0, "p": 1.8}, + } + method.target_sparse_ratio = {"prefill": 0.9, "decode": 0.9} info = method.get_threshold_info() - assert info["type"] == "dynamic" - assert info["scale_factors"] == {"prefill": 437.5, "decode": 500.0} - assert info["formula"] == "λ[phase] / length" + assert info["type"] == "dynamic_calibrated" + assert info["formula"] == "threshold = k / (1 - target_sparsity)^p / seqlen" + assert "calibration_params" in info + assert "target_sparse_ratio" in info assert "phases" in info assert "prefill" in info["phases"] assert "decode" in info["phases"] - # Check example thresholds for prefill - prefill_examples = info["phases"]["prefill"]["example_thresholds"] - assert abs(prefill_examples[1024] - 437.5 / 1024) < 1e-6 - assert abs(prefill_examples[2048] - 437.5 / 2048) < 1e-6 + # Check that k and p are in phase info + assert info["phases"]["prefill"]["k"] == 150.0 + assert info["phases"]["prefill"]["p"] == 1.5 + assert info["phases"]["prefill"]["target_sparsity"] == 0.9 def test_threshold_info_structure(self): """Test that threshold info has expected 
structure.""" @@ -151,13 +156,17 @@ def test_module_with_calibrated_threshold(self): sparse_model = sparsify(model, config) - # Find module and set calibrated threshold (per-phase dict format) + # Find module and set calibrated params (Inverse Power model) module = None for module in sparse_model.modules(): if isinstance(module, SparseAttentionModule): - module._sparse_method_instance.threshold_scale_factor = { - "prefill": 500.0, - "decode": 500.0, + module._sparse_method_instance.calibration_params = { + "prefill": {"k": 150.0, "p": 1.5}, + "decode": {"k": 200.0, "p": 1.8}, + } + module._sparse_method_instance.target_sparse_ratio = { + "prefill": 0.9, + "decode": 0.9, } break @@ -165,8 +174,8 @@ def test_module_with_calibrated_threshold(self): # Get threshold info info = module.get_threshold_info() - assert info["type"] == "dynamic" - assert info["scale_factors"] == {"prefill": 500.0, "decode": 500.0} + assert info["type"] == "dynamic_calibrated" + assert info["calibration_params"]["prefill"]["k"] == 150.0 def test_module_without_method_instance(self): """Test get_threshold_info when sparse method instance doesn't exist.""" @@ -252,19 +261,22 @@ def test_summary_displays_dynamic_threshold(self, capsys): sparse_model = sparsify(model, config) - # Set calibrated threshold (per-phase dict format) + # Set calibrated params (Inverse Power model) for module in sparse_model.modules(): if isinstance(module, SparseAttentionModule): - module._sparse_method_instance.threshold_scale_factor = { - "prefill": 437.5, - "decode": 500.0, + module._sparse_method_instance.calibration_params = { + "prefill": {"k": 150.0, "p": 1.5}, + "decode": {"k": 200.0, "p": 1.8}, + } + module._sparse_method_instance.target_sparse_ratio = { + "prefill": 0.9, + "decode": 0.9, } print_sparse_attention_summary(sparse_model) captured = capsys.readouterr() - # Output format: λ={prefill=437.50, decode=500.00} - assert "λ=" in captured.out - assert "prefill=" in captured.out - assert "decode=" in captured.out + # Output should show calibrated info assert "flash_skip_softmax" in captured.out + assert "prefill" in captured.out + assert "decode" in captured.out From a5136e891025e7f1044c4dee85d49804ef7f0dd6 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 29 Jan 2026 14:14:53 -0800 Subject: [PATCH 9/9] Switch to exponential model for fitting from inverse power Signed-off-by: Kai Xu --- .../calibration/calibrate.py | 12 +-- .../calibration/calibrator.py | 74 ++++++++------- .../calibration/ruler_utils.py | 6 +- .../sparsity/attention_sparsity/config.py | 7 +- .../sparsity/attention_sparsity/conversion.py | 91 +++++++++++++++---- .../methods/flash_skip_softmax.py | 22 ++--- .../test_calibration_gpu.py | 8 +- .../test_sparse_attention_calibration.py | 2 +- .../attention_sparsity/test_threshold_info.py | 28 +++--- 9 files changed, 156 insertions(+), 94 deletions(-) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index aa3f2a408..d469c40d0 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -350,7 +350,7 @@ def calibrate_sparse_attention( ) prefill_result = prefill_calibrator.calibrate(model, prefill_forward_loop, phase="prefill") - if "k" in prefill_result and "p" in prefill_result: + if "a" in prefill_result and "b" in prefill_result: calibration_results["prefill"] = prefill_result else: warnings.warn("Prefill calibration did not 
produce valid results") @@ -372,7 +372,7 @@ def calibrate_sparse_attention( ) decode_result = decode_calibrator.calibrate(model, decode_forward_loop, phase="decode") - if "k" in decode_result and "p" in decode_result: + if "a" in decode_result and "b" in decode_result: calibration_results["decode"] = decode_result else: warnings.warn("Decode calibration did not produce valid results") @@ -382,14 +382,14 @@ def calibrate_sparse_attention( warnings.warn("No calibration produced valid results") return {} - # Extract k and p for each phase + # Extract a and b for each phase calibration_params: dict[str, dict[str, float]] = {} for phase in ["prefill", "decode"]: if phase in calibration_results: result = calibration_results[phase] calibration_params[phase] = { - "k": result["k"], - "p": result["p"], + "a": result["a"], + "b": result["b"], } # Apply calibration params to all modules @@ -400,7 +400,7 @@ def calibrate_sparse_attention( for phase, params in calibration_params.items(): result = calibration_results[phase] print(f" {phase}:") - print(f" Model: scale_factor = {params['k']:.4f} / (1 - sparsity)^{params['p']:.4f}") + print(f" Model: scale_factor = {params['a']:.6f} * exp({params['b']:.4f} * sparsity)") print(f" R-squared: {result['r_squared']:.6f}") for module_name, module in sparse_modules: diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index edf7c9079..8f3759352 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -31,7 +31,7 @@ class DynamicThresholdCalibrator: - """Dynamic threshold calibrator using Inverse Power model. + """Dynamic threshold calibrator using Exponential model. Calibration Algorithm: 1. For each threshold λ_j in threshold_trials: @@ -39,13 +39,13 @@ class DynamicThresholdCalibrator: - For each sample i with length L_i, collect sparsity S_ij - Compute scale_factor_ij = λ_j × L_i - 2. Fit Inverse Power model to ALL individual (sf_ij, S_ij) pairs: - scale_factor = k / (1 - sparsity)^p + 2. Fit Exponential model to ALL individual (sf_ij, S_ij) pairs: + scale_factor = a * exp(b * sparsity) - 3. Return fitted k and p parameters (model-specific) + 3. Return fitted a and b parameters At inference time (user specifies target_sparsity S*): - scale_factor = k / (1 - S*)^p + scale_factor = a * exp(b * S*) threshold = scale_factor / seqlen Key insight: Using all individual data points (N_thresholds × N_samples) @@ -88,20 +88,20 @@ def __init__( ] def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dict[str, Any]: - """Calibrate k and p parameters for Inverse Power model. + """Calibrate a and b parameters for Exponential model. Algorithm: 1. For each threshold λ_j in threshold_trials: - Run ALL samples, collect sparsities S_ij for each sample i - Compute scale_factor_ij = λ_j × L_i (where L_i is sample length) - 2. Fit Inverse Power model to ALL (sf_ij, S_ij) pairs: - scale_factor = k / (1 - sparsity)^p + 2. Fit Exponential model to ALL (sf_ij, S_ij) pairs: + scale_factor = a * exp(b * sparsity) - 3. Return fitted k and p parameters + 3. 
Return fitted a and b parameters At inference time (user specifies target_sparsity S*): - scale_factor = k / (1 - S*)^p + scale_factor = a * exp(b * S*) threshold = scale_factor / seqlen Args: @@ -110,7 +110,7 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic phase: Phase to calibrate ('prefill' or 'decode') Returns: - Dict with calibration results including k, p, r_squared, and num_data_points + Dict with calibration results including a, b, r_squared, and num_data_points """ # Extract attention modules attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] @@ -118,7 +118,7 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic if not attention_modules: raise ValueError("No sparse attention modules found for calibration") - print(f"Starting Inverse Power model calibration ({phase} phase)") + print(f"Starting Exponential model calibration ({phase} phase)") print(f"Threshold trials: {len(self.threshold_trials)}") # Stage 1: Collect ALL (scale_factor, sparsity) pairs for all thresholds and samples @@ -162,15 +162,16 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic print(f"Collected {len(all_data_points)} individual (scale_factor, sparsity) pairs") - # Stage 2: Fit Inverse Power model: scale_factor = k / (1 - sparsity)^p - print("\nStage 2: Fitting Inverse Power model to all data points...") + # Stage 2: Fit Exponential model: scale_factor = a * exp(b * sparsity) + print("\nStage 2: Fitting Exponential model to all data points...") # Extract data for fitting - scale_factors = np.array([p["scale_factor"] for p in all_data_points]) - sparsities = np.array([p["sparsity"] for p in all_data_points]) + scale_factors = np.array([pt["scale_factor"] for pt in all_data_points]) + sparsities = np.array([pt["sparsity"] for pt in all_data_points]) - # Filter out invalid sparsities (must be in (0, 1)) - valid_mask = (sparsities > 0.01) & (sparsities < 0.99) + # Filter out extreme sparsities (must be in (10%, 90%)) + # Extreme values are unreliable for fitting + valid_mask = (sparsities >= 0.10) & (sparsities <= 0.90) scale_factors = scale_factors[valid_mask] sparsities = sparsities[valid_mask] @@ -180,36 +181,38 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic ) return {} - # Define Inverse Power model: sf = k / (1 - S)^p - def inverse_power(sparsity, k, p): - return k / np.power(1 - sparsity, p) + # Define Exponential model: sf = a * exp(b * S) + def exponential(sparsity, a, b): + return a * np.exp(b * sparsity) # Fit the model try: popt, pcov = curve_fit( - inverse_power, + exponential, sparsities, scale_factors, - p0=[100, 1.5], # Initial guess - bounds=([0.1, 0.1], [1e7, 10]), # Bounds for k and p + p0=[1.0, 5.0], # Initial guess + bounds=([0.0, 0.0], [np.inf, 20.0]), # Bounds for a and b maxfev=10000, ) - k, p = popt + a, b = popt except Exception as e: warnings.warn(f"Curve fitting failed: {e}") return {} - # Calculate R-squared - pred_scale_factors = inverse_power(sparsities, k, p) + # Calculate R-squared and RMSE + pred_scale_factors = exponential(sparsities, a, b) ss_res = np.sum((scale_factors - pred_scale_factors) ** 2) ss_tot = np.sum((scale_factors - np.mean(scale_factors)) ** 2) r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + rmse = np.sqrt(np.mean((scale_factors - pred_scale_factors) ** 2)) - print(f"\n{phase.capitalize()} Calibration Results (Inverse Power Model):") - print(" Model: scale_factor = k / (1 - 
sparsity)^p") - print(f" Fitted k: {k:.4f}") - print(f" Fitted p: {p:.4f}") + print(f"\n{phase.capitalize()} Calibration Results (Exponential Model):") + print(" Model: scale_factor = a * exp(b * sparsity)") + print(f" Fitted a: {a:.6f}") + print(f" Fitted b: {b:.4f}") print(f" R-squared: {r_squared:.6f}") + print(f" RMSE: {rmse:.2f}") print(f" Data points used: {int(np.sum(valid_mask))} / {len(all_data_points)}") # Show scale_factor for various target sparsities @@ -217,7 +220,7 @@ def inverse_power(sparsity, k, p): print(f" {'Target':<10} {'Scale Factor':<15}") print(f" {'-' * 10} {'-' * 15}") for target in [0.5, 0.7, 0.8, 0.9, 0.95]: - sf = k / (1 - target) ** p + sf = a * np.exp(b * target) print(f" {target:<10.0%} {sf:<15.2f}") # Print calibration data summary by threshold @@ -238,12 +241,13 @@ def inverse_power(sparsity, k, p): return { "phase": phase, - "k": float(k), - "p": float(p), + "a": float(a), + "b": float(b), "r_squared": float(r_squared), + "rmse": float(rmse), "num_data_points": int(np.sum(valid_mask)), "total_samples": len(all_data_points), - "calibration_type": "inverse_power", + "calibration_type": "exponential", } def _enable_calibration_mode(self, modules: list[nn.Module]): diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py index 741b621f5..87da4b76c 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -461,8 +461,8 @@ def find_optimal_haystack_size( upper_bound = max(estimated_max, incremental * 2) optimal_num_haystack = None - logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") - logger.info(f"Binary search bounds: {lower_bound} to {upper_bound}") + logger.debug(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") + logger.debug(f"Binary search bounds: {lower_bound} to {upper_bound}") while lower_bound <= upper_bound: mid = (lower_bound + upper_bound) // 2 @@ -486,6 +486,6 @@ def find_optimal_haystack_size( upper_bound = mid - 1 final_size = optimal_num_haystack if optimal_num_haystack is not None else incremental - logger.info(f"Optimal haystack size: {final_size}") + logger.debug(f"Optimal haystack size: {final_size}") return final_size diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 12b343efe..754ee8765 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -147,10 +147,10 @@ def validate_threshold(cls, v): class CalibrationConfig(ModeloptBaseConfig): """Configuration for automatic threshold calibration using RULER dataset. - Calibration fits an Inverse Power model to determine dynamic thresholds that - achieve target sparsity. The model learns parameters k and p per phase: + Calibration fits an Exponential model to determine dynamic thresholds that + achieve target sparsity. 
The model learns parameters a and b per phase: - scale_factor = k / (1 - target_sparsity)^p + scale_factor = a * exp(b * target_sparsity) At inference time, the threshold is computed as: @@ -160,6 +160,7 @@ class CalibrationConfig(ModeloptBaseConfig): - Target sparsity can be changed at runtime without recalibration - Threshold automatically adapts to sequence length - Supports independent prefill and decode phase calibration + - Exponential model provides better fit (lower RMSE) """ target_sparse_ratio: dict[str, float] = ModeloptField( diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 716149bc9..29317c27f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -244,28 +244,85 @@ def update_sparse_attention_metadata( def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: """Extract sparse attention config for export to config.json. - Extracts the calibration parameters (k, p) and target_sparse_ratio from the first - sparse attention module that has calibrated thresholds. + Extracts the calibration parameters (a, b) for the exponential threshold model + from the first sparse attention module that has calibrated thresholds. + + The exported config allows computing threshold at runtime: + scale_factor = a * exp(b * target_sparsity) + threshold = scale_factor / seqlen Args: model: Model with sparse attention applied Returns: - Dictionary with sparse attention config, or None if no calibrated config found. - Contains "calibration_params" with k and p per phase, and "target_sparse_ratio". + Dictionary with sparse attention config for HuggingFace config.json export. + Returns None if no calibrated sparse attention modules found. 
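As an illustrative aside (not part of the patch): a downstream runtime that reads the exported block shown in the example output below could recover a per-sequence-length threshold as follows. The function is hypothetical; only the exponential formula and the field names ("threshold_scale_factor", "a", "b") are taken from this docstring.

    import math

    def threshold_from_config(
        export_cfg: dict, phase: str, target_sparsity: float, seqlen: int
    ) -> float:
        ab = export_cfg["threshold_scale_factor"][phase]  # {"a": ..., "b": ...}
        scale_factor = ab["a"] * math.exp(ab["b"] * target_sparsity)
        return scale_factor / seqlen

    # Using the prefill values from the example output (a=7.93, b=8.61),
    # a 90% sparsity target, and a 16k context:
    cfg = {"threshold_scale_factor": {"prefill": {"a": 7.93, "b": 8.61}}}
    print(threshold_from_config(cfg, "prefill", target_sparsity=0.9, seqlen=16384))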
+ + Example output:: + + { + "config_groups": { + "group_0": {"sparse_algo": "softmax_skip", "targets": ["LlamaAttention"]} + }, + "threshold_scale_factor": { + "formula": "a * exp(b * target_sparsity)", + "prefill": {"a": 7.93, "b": 8.61}, + "decode": {"a": 0.12, "b": 9.85}, + }, + "producer": {"name": "modelopt", "version": "0.37.0"}, + } """ + import modelopt + + # Collect sparse attention module info + calibration_params = None + target_classes: set[str] = set() + for module in model.modules(): if isinstance(module, SparseAttentionModule): - calibration_params = getattr(module._sparse_method_instance, "calibration_params", None) - target_sparse_ratio = getattr( - module._sparse_method_instance, "target_sparse_ratio", None - ) - if calibration_params is not None: - return { - "calibration_params": calibration_params, - "target_sparse_ratio": target_sparse_ratio, - } - return None + # Get the original wrapped module's class name + if hasattr(module, "get_original_cls_by_level"): + original_cls = module.get_original_cls_by_level(level=0) + if original_cls is not None: + target_classes.add(original_cls.__name__) + + # Get calibration params from first module that has them + if calibration_params is None: + calibration_params = getattr( + module._sparse_method_instance, "calibration_params", None + ) + + # Return None if no calibration params found + if calibration_params is None: + return None + + # Build threshold_scale_factor with model parameters + threshold_scale_factor: dict[str, Any] = { + "formula": "a * exp(b * target_sparsity)", + } + for phase in ["prefill", "decode"]: + if phase in calibration_params: + threshold_scale_factor[phase] = { + "a": calibration_params[phase]["a"], + "b": calibration_params[phase]["b"], + } + + # Build the export config + export_config: dict[str, Any] = { + "config_groups": { + "group_0": { + "sparse_algo": "softmax_skip", + "targets": sorted(target_classes) if target_classes else ["Attention"], + } + }, + "threshold_scale_factor": threshold_scale_factor, + "producer": { + "name": "modelopt", + "version": modelopt.__version__, + }, + } + + return export_config def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): @@ -332,15 +389,15 @@ def _format_threshold(info: dict) -> str: """Format threshold info for display.""" t = info.get("type") if t == "dynamic_calibrated": - # Inverse Power model: threshold = k / (1 - sparsity)^p / seqlen + # Exponential model: threshold = a * exp(b * sparsity) / seqlen params = info.get("calibration_params", {}) target = info.get("target_sparse_ratio", {}) parts = [] for phase in ["prefill", "decode"]: if phase in params: - k, p = params[phase]["k"], params[phase]["p"] + a, b = params[phase]["a"], params[phase]["b"] s = target.get(phase, 0.5) - parts.append(f"{phase}: k={k:.1f}, p={p:.2f}, target={s:.0%}") + parts.append(f"{phase}: a={a:.4f}, b={b:.2f}, target={s:.0%}") return f"calibrated({', '.join(parts)})" if t == "static": v = info.get("value") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index c7e168abc..7cd9799e8 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -141,12 +141,12 @@ def calc_correction_factor_and_p( and phase in calibration_params and target_sparse_ratio is not None ): - # Use calibrated k, p to compute dynamic threshold - # scale_factor = 
k / (1 - target_sparsity)^p - k = calibration_params[phase]["k"] - p = calibration_params[phase]["p"] + # Use calibrated a, b to compute dynamic threshold + # Exponential model: scale_factor = a * exp(b * target_sparsity) + a = calibration_params[phase]["a"] + b = calibration_params[phase]["b"] target_sparsity = target_sparse_ratio.get(phase, 0.5) - scale_factor = k / ((1 - target_sparsity) ** p) + scale_factor = a * np.exp(b * target_sparsity) log_threshold = np.log(scale_factor / seq_k) else: # Use static threshold from config (no calibration or phase not calibrated) @@ -323,16 +323,16 @@ def get_threshold_info(self) -> dict[str, Any]: target_sparse_ratio = getattr(self, "target_sparse_ratio", None) if calibration_params is not None and target_sparse_ratio is not None: - # Per-phase calibrated dynamic threshold using Inverse Power model + # Per-phase calibrated dynamic threshold using Exponential model example_lengths = [1024, 4096, 16384, 65536, 131072] phase_info = {} for phase, params in calibration_params.items(): - k, p = params["k"], params["p"] + a, b = params["a"], params["b"] target_sparsity = target_sparse_ratio.get(phase, 0.5) - scale_factor = k / ((1 - target_sparsity) ** p) + scale_factor = a * np.exp(b * target_sparsity) phase_info[phase] = { - "k": k, - "p": p, + "a": a, + "b": b, "target_sparsity": target_sparsity, "scale_factor": scale_factor, "example_thresholds": { @@ -341,7 +341,7 @@ def get_threshold_info(self) -> dict[str, Any]: } return { "type": "dynamic_calibrated", - "formula": "threshold = k / (1 - target_sparsity)^p / seqlen", + "formula": "threshold = a * exp(b * target_sparsity) / seqlen", "calibration_params": calibration_params, "target_sparse_ratio": target_sparse_ratio, "phases": phase_info, diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py index dd7c4c6f5..02a305b5d 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -143,14 +143,14 @@ def forward_loop(model): sparse_modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] assert len(sparse_modules) > 0 - # Verify calibration was applied (Inverse Power model params) + # Verify calibration was applied (Exponential model params) for module in sparse_modules: method = module._sparse_method_instance - # Check if calibration params (k, p) are set + # Check if calibration params (a, b) are set if hasattr(method, "calibration_params") and method.calibration_params: for params in method.calibration_params.values(): - assert "k" in params and params["k"] > 0 - assert "p" in params and params["p"] > 0 + assert "a" in params and params["a"] > 0 + assert "b" in params and params["b"] > 0 def test_calibration_pytorch_backend(self, simple_model): """Test calibration with pytorch backend.""" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py index c093356bb..985c64244 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -677,7 +677,7 @@ def test_calibrator_default_threshold_trials_values(self): """Test that default threshold trials have expected values.""" calibrator = DynamicThresholdCalibrator() - # Should have 20 default trials 
(expanded range for Inverse Power model) + # Should have 20 default trials (expanded range for Exponential model) assert len(calibrator.threshold_trials) == 20 # Check specific values diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py index b02056480..fa1562ca6 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -61,25 +61,25 @@ def test_dynamic_calibrated_threshold(self): } ) - # Simulate calibration setting k and p parameters + # Simulate calibration setting a and b parameters method.calibration_params = { - "prefill": {"k": 150.0, "p": 1.5}, - "decode": {"k": 200.0, "p": 1.8}, + "prefill": {"a": 150.0, "b": 1.5}, + "decode": {"a": 200.0, "b": 1.8}, } method.target_sparse_ratio = {"prefill": 0.9, "decode": 0.9} info = method.get_threshold_info() assert info["type"] == "dynamic_calibrated" - assert info["formula"] == "threshold = k / (1 - target_sparsity)^p / seqlen" + assert info["formula"] == "threshold = a * exp(b * target_sparsity) / seqlen" assert "calibration_params" in info assert "target_sparse_ratio" in info assert "phases" in info assert "prefill" in info["phases"] assert "decode" in info["phases"] - # Check that k and p are in phase info - assert info["phases"]["prefill"]["k"] == 150.0 - assert info["phases"]["prefill"]["p"] == 1.5 + # Check that a and b are in phase info + assert info["phases"]["prefill"]["a"] == 150.0 + assert info["phases"]["prefill"]["b"] == 1.5 assert info["phases"]["prefill"]["target_sparsity"] == 0.9 def test_threshold_info_structure(self): @@ -156,13 +156,13 @@ def test_module_with_calibrated_threshold(self): sparse_model = sparsify(model, config) - # Find module and set calibrated params (Inverse Power model) + # Find module and set calibrated params (Exponential model) module = None for module in sparse_model.modules(): if isinstance(module, SparseAttentionModule): module._sparse_method_instance.calibration_params = { - "prefill": {"k": 150.0, "p": 1.5}, - "decode": {"k": 200.0, "p": 1.8}, + "prefill": {"a": 150.0, "b": 1.5}, + "decode": {"a": 200.0, "b": 1.8}, } module._sparse_method_instance.target_sparse_ratio = { "prefill": 0.9, @@ -175,7 +175,7 @@ def test_module_with_calibrated_threshold(self): info = module.get_threshold_info() assert info["type"] == "dynamic_calibrated" - assert info["calibration_params"]["prefill"]["k"] == 150.0 + assert info["calibration_params"]["prefill"]["a"] == 150.0 def test_module_without_method_instance(self): """Test get_threshold_info when sparse method instance doesn't exist.""" @@ -261,12 +261,12 @@ def test_summary_displays_dynamic_threshold(self, capsys): sparse_model = sparsify(model, config) - # Set calibrated params (Inverse Power model) + # Set calibrated params (Exponential model) for module in sparse_model.modules(): if isinstance(module, SparseAttentionModule): module._sparse_method_instance.calibration_params = { - "prefill": {"k": 150.0, "p": 1.5}, - "decode": {"k": 200.0, "p": 1.8}, + "prefill": {"a": 150.0, "b": 1.5}, + "decode": {"a": 200.0, "b": 1.8}, } module._sparse_method_instance.target_sparse_ratio = { "prefill": 0.9,
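A closing illustrative sketch (not part of the patch series): the Stage-2 fit that PATCH 9/9 switches to can be reproduced in isolation with scipy's curve_fit on (sparsity, scale_factor) pairs. The data below is synthetic; only the model form a * exp(b * sparsity), the initial guess, the bounds, and the final threshold formula mirror the calibrator and flash_skip_softmax changes above.

    import numpy as np
    from scipy.optimize import curve_fit

    def exponential(sparsity, a, b):
        # Exponential model fitted during calibration: scale_factor = a * exp(b * sparsity)
        return a * np.exp(b * sparsity)

    # Synthetic (sparsity, scale_factor) pairs, kept inside the (10%, 90%) filter window
    rng = np.random.default_rng(0)
    sparsities = rng.uniform(0.10, 0.90, size=200)
    scale_factors = 2.0 * np.exp(6.0 * sparsities) * rng.normal(1.0, 0.05, size=200)

    (a, b), _ = curve_fit(
        exponential,
        sparsities,
        scale_factors,
        p0=[1.0, 5.0],
        bounds=([0.0, 0.0], [np.inf, 20.0]),
        maxfev=10000,
    )

    # Threshold applied at inference for a 90% sparsity target and an 8k context:
    threshold = a * np.exp(b * 0.90) / 8192
    print(f"a={a:.3f}, b={b:.3f}, threshold@8k={threshold:.2e}")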