
Conversation

@kaix-nv
Contributor

@kaix-nv kaix-nv commented Nov 11, 2025

What does this PR do?

Type of change: new feature

Overview:

  • Adds the sparse attention calibration algorithm
  • Adds chunked prefill to support long ctx_len
  • Separates calibration for prefill and decode

Usage

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB

# Apply sparse attention with calibration
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB)

# Print a summary - now shows the actual calibrated thresholds
mtsa.print_sparse_attention_summary(model)
# Output:
# Method: flash_skip_softmax, Threshold: Dynamic (λ=437.395926)

# Or use the llm_eval integration (HuggingFace sparse attention example):
python examples/llm_sparsity/attention_sparsity/hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-4B \
    --sparse_attn skip_softmax_calib

Calibration Algorithm

  • Implements an inverse power model: scale_factor = k / (1 - sparsity)^p
  • Fits the model parameters (k, p) per phase (prefill/decode) using scipy.optimize.curve_fit
  • At inference: threshold = k / (1 - target_sparsity)^p / seqlen (see the sketch below)
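A minimal sketch of that fit and the resulting dynamic threshold, assuming per-phase (sparsity, scale_factor) pairs have already been collected during calibration; scipy.optimize.curve_fit is the tool named above, but the function names and the p0 initial guess are illustrative, not the PR's actual API:

```python
import numpy as np
from scipy.optimize import curve_fit


def inverse_power(sparsity, k, p):
    """Inverse power model: scale_factor = k / (1 - sparsity)^p."""
    return k / (1.0 - sparsity) ** p


def fit_scale_factor_model(sparsities, scale_factors):
    """Fit (k, p) for one phase (prefill or decode) from observed pairs."""
    x = np.asarray(sparsities, dtype=np.float64)
    y = np.asarray(scale_factors, dtype=np.float64)
    (k, p), _ = curve_fit(inverse_power, x, y, p0=(1.0, 1.0), maxfev=10000)
    return k, p


def dynamic_threshold(k, p, target_sparsity, seqlen):
    """At inference: threshold = k / (1 - target_sparsity)^p / seqlen."""
    return k / (1.0 - target_sparsity) ** p / seqlen
```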

Why the inverse power model?

The inverse power model fits the relationship between sparsity ratio and threshold_scale_factor better than the alternative models evaluated, as shown in the figure below.

(Figure: sparsity_model_analysis)

Runtime Flexibility

  • Target sparsity can be changed at inference time without recalibration
  • Users can adjust module._sparse_method_instance.target_sparse_ratio dynamically (see the sketch below)
  • The threshold automatically adapts to the sequence length
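A minimal sketch of that knob, assuming the sparsified model exposes submodules carrying a _sparse_method_instance with a per-phase target_sparse_ratio dict as described above; the helper name and traversal below are illustrative:

```python
import torch.nn as nn


def set_target_sparse_ratio(model: nn.Module, prefill: float, decode: float) -> None:
    """Adjust the calibrated target sparsity at inference time; no recalibration needed."""
    for _, module in model.named_modules():
        method = getattr(module, "_sparse_method_instance", None)
        if method is not None and hasattr(method, "target_sparse_ratio"):
            # Per-phase targets follow the {"prefill": ..., "decode": ...} convention
            method.target_sparse_ratio = {"prefill": prefill, "decode": decode}


# Example: push prefill sparsity to 90% while keeping decode conservative
# set_target_sparse_ratio(model, prefill=0.9, decode=0.5)
```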

Testing

The calibration results for Qwen/Qwen3-30B-A3B-Thinking-2507 are shown below and are mostly consistent with the ground-truth numbers collected from the kernel side.

Prefill Calibration Results:
  Model: scale_factor = k / (1 - sparsity)^p
  Fitted k: 1003.3990
  Fitted p: 1.2589
  R-squared: 0.827549

Scale factors for different target sparsities:
  Target     Scale Factor
  ---------- ---------------
  50%        2401.35
  70%        4568.26
  80%        7610.98
  90%        18214.70
  95%        43591.65
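As a quick sanity check, the scale factors in this table can be reproduced directly from the fitted parameters; the values agree with the table up to rounding of k and p:

```python
k, p = 1003.3990, 1.2589  # fitted prefill parameters from the run above

for target in (0.50, 0.70, 0.80, 0.90, 0.95):
    scale_factor = k / (1.0 - target) ** p
    print(f"{target:.0%}: {scale_factor:.2f}")
# Prints roughly 2401, 4569, 7611, 18220, and 43500 for the five targets,
# matching the table above to within rounding of the fitted parameters.
```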

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

@kaix-nv kaix-nv requested review from a team as code owners November 11, 2025 22:38
@kaix-nv kaix-nv requested review from RalphMao and removed request for RalphMao November 11, 2025 22:38
@codecov

codecov bot commented Nov 11, 2025

Codecov Report

❌ Patch coverage is 69.23077% with 264 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.25%. Comparing base (02c5f29) to head (0553ec6).

Files with missing lines Patch % Lines
...arsity/attention_sparsity/calibration/calibrate.py 26.58% 127 Missing ⚠️
...rsity/attention_sparsity/calibration/calibrator.py 57.60% 53 Missing ⚠️
...sity/attention_sparsity/calibration/ruler_utils.py 73.88% 41 Missing ⚠️
...pt/torch/sparsity/attention_sparsity/conversion.py 58.92% 23 Missing ⚠️
...ch/sparsity/attention_sparsity/sparse_attention.py 63.15% 7 Missing ⚠️
...delopt/torch/sparsity/attention_sparsity/config.py 90.62% 6 Missing ⚠️
...y/attention_sparsity/methods/flash_skip_softmax.py 90.00% 5 Missing ⚠️
...sparsity/attention_sparsity/calibration/dataset.py 99.39% 1 Missing ⚠️
...ch/sparsity/attention_sparsity/methods/registry.py 80.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #538      +/-   ##
==========================================
- Coverage   73.38%   73.25%   -0.13%     
==========================================
  Files         193      199       +6     
  Lines       19893    20713     +820     
==========================================
+ Hits        14598    15174     +576     
- Misses       5295     5539     +244     

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 8c7ee86 to da6f627 Compare November 12, 2025 00:17
@kaix-nv kaix-nv changed the title [3/n] Adds sparse attention integration to the llm_eval examples [OMNIML-2850] [3/n] Adds sparse attention integration to the llm_eval examples Nov 12, 2025
@kaix-nv kaix-nv changed the title [OMNIML-2850] [3/n] Adds sparse attention integration to the llm_eval examples [OMNIML-2850][3/n] Adds sparse attention integration to the llm_eval examples Nov 12, 2025
@kaix-nv kaix-nv changed the title [OMNIML-2850][3/n] Adds sparse attention integration to the llm_eval examples [OMNIML-2850] [3/n] Adds sparse attention integration to the llm_eval examples Nov 12, 2025
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 4 times, most recently from 525a119 to c9d7008 Compare November 13, 2025 07:40
@kaix-nv kaix-nv changed the title [OMNIML-2850] [3/n] Adds sparse attention integration to the llm_eval examples [OMNIML-2850] [3/n] Adds sparse attention calibration; Adds llm_eval support Nov 14, 2025
@kaix-nv kaix-nv changed the title [OMNIML-2850] [3/n] Adds sparse attention calibration; Adds llm_eval support [OMNIML-2850] [3/n] Adds sparse attention calibration Nov 14, 2025
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 5 times, most recently from 7727793 to 2864629 Compare December 1, 2025 11:35
@kaix-nv kaix-nv requested a review from a team as a code owner December 1, 2025 11:35
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 2864629 to ca7e24e Compare December 1, 2025 15:19
@kaix-nv kaix-nv removed the request for review from kevalmorabia97 December 1, 2025 15:25
@kevalmorabia97
Collaborator

kevalmorabia97 commented Dec 1, 2025

@kaix-nv GitHub is showing 7,000+ lines of code as part of this PR. Is that accurate?
It shouldn’t be that much. Less than half of the code should remain after rebasing on the preceding PR.

@kaix-nv kaix-nv requested a review from jy-yuan December 8, 2025 21:52
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 4 times, most recently from 3474b6f to 74a29ea Compare December 13, 2025 21:00
Comment on lines 272 to 279
# Force eager attention if sparse attention is requested
if sparse_cfg:
    kwargs["attn_implementation"] = "eager"
    warnings.warn(
        "Sparse attention requires attn_implementation='eager'. "
        "Forcing eager attention implementation."
    )

Contributor

Can we move this to sparsity/plugins/huggingface? We should detect if this is an HF model and, if yes, apply this (see

AutoQuantizeGradientSearcher.register_custom_support(
)

Contributor Author

I've moved this check to the huggingface plugin.

Contributor

Same as https://github.com/NVIDIA/Model-Optimizer/pull/538/files#r2646356349; please avoid repeating the attention modification.

Contributor Author

I've removed this check.

from .dataset import RulerDatasetBuilder


def _extract_tokenizer_from_model(model: nn.Module) -> str:
Contributor

This only works with Hugging Face Transformers, right?

Contributor Author

Yes, Megatron support (e.g., adding forward_loop) will be added in future PRs.

Contributor
@realAsma realAsma Dec 24, 2025

Can we move this to sparsity/plugins?

Otherwise this change will make the transformers library a required dependency of ModelOpt.

Contributor

Please use local imports of third party libraries wherever necessary

Contributor Author

I’ve switched to local imports. Since most of the dataset construction logic can be reused by Megatron, I’d prefer to keep this file under the calibration folder.

# For causal attention, only count lower triangle blocks (including diagonal)
num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2
total_valid_blocks = batch_size * num_heads * num_causal_blocks
density = float(block_mask.sum()) / total_valid_blocks
Contributor

Can we keep this as a torch tensor? Calling float() on a GPU tensor causes an unnecessary CPU-GPU sync.

Contributor Author

Good catch. It’s been fixed.

    density = float(block_mask.sum()) / total_valid_blocks
    total_blocks = num_causal_blocks
else:
    density = float(block_mask.sum() / block_mask.numel())
Contributor Author

It’s been fixed.
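For reference, a minimal sketch of the pattern the reviewer is asking for in this thread: keep the reduction as a tensor on the same device and defer the .item() / float() conversion, which forces a CPU-GPU sync, until a host-side number is actually needed. This is illustrative, not the exact code in the PR:

```python
import torch


def block_density(block_mask: torch.Tensor, total_valid_blocks: int) -> torch.Tensor:
    # The sum stays on block_mask's device; no implicit synchronization happens here.
    return block_mask.sum() / total_valid_blocks


# Convert to a Python float only when logging/aggregating at the end:
# density_value = block_density(mask, n_blocks).item()
```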

Collaborator
@jy-yuan jy-yuan left a comment

Reviewed the PR, and the calibration logic looks correct (dynamic thresholding via regression and phase separation seems solid). I tested the code and verified that the unit tests pass locally.

One small observation regarding dependencies: I noticed I had to manually install nltk and wonderwords to run the tests/calibration. It seems they are currently added to dev-test in setup.py, so they aren't included in pip install nvidia-modelopt[all]. If the RULER-based calibration is intended to be a supported feature for users (i.e., when they don't provide a custom forward loop), we might want to consider moving these to a user-facing optional dependency group (like calibration or hf) or catching the ModuleNotFoundError to suggest installation.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 2 times, most recently from c1692d4 to e65d3e1 Compare December 30, 2025 22:32
@kaix-nv
Contributor Author

kaix-nv commented Dec 30, 2025

> Reviewed the PR, and the calibration logic looks correct (dynamic thresholding via regression and phase separation seems solid). I tested the code and verified that the unit tests pass locally.
>
> One small observation regarding dependencies: I noticed I had to manually install nltk and wonderwords to run the tests/calibration. It seems they are currently added to dev-test in setup.py, so they aren't included in pip install nvidia-modelopt[all]. If the RULER-based calibration is intended to be a supported feature for users (i.e., when they don't provide a custom forward loop), we might want to consider moving these to a user-facing optional dependency group (like calibration or hf) or catching the ModuleNotFoundError to suggest installation.

Thanks Jiayi. Good catch. I’ve moved the dependency to the hf group, so it can now be installed via pip install -U nvidia-modelopt[hf]. cc @kevalmorabia97

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 3 times, most recently from 30a9794 to ed213d9 Compare December 31, 2025 08:02
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 130f1e9 to a3dcd9d Compare January 20, 2026 06:46
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from a3dcd9d to 33d0025 Compare January 28, 2026 00:28
@coderabbitai
Contributor

coderabbitai bot commented Jan 28, 2026

📝 Walkthrough

This pull request introduces comprehensive sparse attention calibration support for PyTorch LLMs. It adds a calibration framework with dynamic threshold computation using RULER datasets, updates sparse attention configuration to support per-phase thresholds (prefill/decode), integrates statistics collection into sparse modules, and extends example workflows with sparse attention options.

Changes

Cohort / File(s) Summary
Example Workflows
examples/llm_eval/lm_eval_hf.py, examples/llm_eval/mmlu.py, examples/llm_eval/modeling.py
Added sparse attention support with sparse_cfg parameter handling. Extended SeqToSeqModel with optional attn_implementation field for propagation to model loading. Sparse attention applied post-quantization with duplicate-check warnings.
Sparse Attention Utilities & Examples
examples/llm_eval/sparse_attention_utils.py, examples/llm_sparsity/attention_sparsity/hf_sa.py, examples/llm_sparsity/attention_sparsity/...
New utility module for sparsifying models with string/dict-based configs and backend override. Refactored HF sparsity example from dataset-based to test-prompt-based workflow; added --target_sparse_ratio CLI option; replaced verification flow with pre/post-sparse output generation. Added documentation and calibration data download script.
Calibration Infrastructure - Core
modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py, .../calibrate.py, .../calibrator.py, .../dataset.py, .../ruler_utils.py
New calibration module with entry point re-exports. calibrate.py orchestrates calibration: manages caching, builds/loads calibration data, runs prefill/decode phases via DynamicThresholdCalibrator. calibrator.py fits inverse power model to threshold-sparsity pairs, extracts per-sample stats, computes k and p parameters. dataset.py provides RulerDatasetBuilder with synthetic RULER task generation (NIAH, VT, FWE, QA). ruler_utils.py implements RULER helpers (Paul Graham essays, word lists, random generation, NIAH sample construction).
Configuration Updates
modelopt/torch/sparsity/attention_sparsity/config.py
Changed threshold from float to dict[str, float] with per-phase keys (prefill/decode). Added collect_stats field to SparseAttentionAttributeConfig. Introduced new CalibrationConfig class with target sparsity, sample counts, sequence lengths, and threshold trials. Added presets SKIP_SOFTMAX_DEFAULT and SKIP_SOFTMAX_CALIB. Updated validators and expanded __all__ exports.
Sparse Attention Execution & Stats
modelopt/torch/sparsity/attention_sparsity/sparse_attention.py, .../stats_manager.py, .../methods/flash_skip_softmax.py, .../methods/registry.py
Enhanced SparseAttentionModule with optional SparseAttentionStatsManager for per-call statistics collection. New get_threshold_info() method exposes threshold metadata. Updated _create_sparse_softmax to separate sparsity calculation (stats) from mask application (calibration-mode aware). FlashSkipSoftmax extended with dict-based threshold handling, calibration mode, calculate_sparsity(), and dynamic threshold logic. Base SparseAttentionMethod updated with abstract methods. New SparseAttentionStatsManager class tracks aggregated and per-sample sparsity, supports phase-based filtering.
Conversion & Export
modelopt/torch/sparsity/attention_sparsity/conversion.py, .../model_sparsify.py, .../plugins/__init__.py, .../plugins/huggingface.py, modelopt/torch/export/unified_export_hf.py
Added export_sparse_attention_config() to extract calibration params, print_sparse_attention_summary() for module-level sparsity reporting, and _format_threshold() helper. New plugin registry system (CUSTOM_MODEL_PLUGINS, register_custom_model_plugins_on_the_fly) replaces direct registration. HuggingFace plugin enforces eager attention via validate_eager_attention(). New calibrate() function in model_sparsify.py bridges sparsification and calibration. Unified export attaches sparse_attention_config to checkpoints when available.
Tests - Configuration & Methods
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py, .../test_flash_skip_softmax.py, .../test_threshold_info.py, .../test_sparse_attention_conversion.py
New config validation tests cover threshold dict format, phase keys, and presets. Updated flash-skip-softmax tests to use dict-based thresholds; added tests for calculate_sparsity() and apply_sparsity() with optional masks. Comprehensive threshold info tests validate static/dynamic threshold reporting and summary formatting. Conversion tests validate new methods and threshold formats.
Tests - Calibration
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py, .../test_stats_manager.py, tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py, .../test_integration_gpu.py
New unit tests for calibration config, RULER dataset generation, DynamicThresholdCalibrator regression logic, and calibrate_sparse_attention() orchestration. Stats manager unit tests cover initialization, collection, phase handling, calibration mode, summary generation, and reset. GPU integration tests exercise end-to-end calibration with model persistence, inference validation, and memory monitoring; compare calibrated vs fixed-threshold behavior.
Tests - LLM Eval & Common
tests/examples/llm_eval/test_llm_eval.py, tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py, tests/_test_utils/torch/sparsity/sparse_attention_common.py, tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
New test for sparse-attention-only LLM evaluation. Updated test config to use dict-based threshold, extended save_restore_test() with atol parameter for tolerance control, and updated assertions with batch shape diagnostics. Integration tests updated with dict-based thresholds across prefill/decode phases.
Dependencies & License
setup.py, examples/llm_sparsity/weight_sparsity/finetune.py, examples/llm_sparsity/attention_sparsity/.gitignore, .../README.md, .../download_ruler_data.sh
Added optional runtime dependencies (nltk, wonderwords) for HF and dev-test extras. Updated license header in finetune.py. Added .gitignore rule for calibration data directory. New README documenting sparse attention workflow, calibration data setup, and CLI usage. New bash script for RULER essay data download.

Sequence Diagram(s)

sequenceDiagram
    participant User as User/CLI
    participant ModelOpt as ModelOpt Sparsity
    participant DataBuilder as RULER DatasetBuilder
    participant Calibrator as DynamicThresholdCalibrator
    participant Module as SparseAttentionModule
    participant StatsManager as StatsManager

    User->>ModelOpt: sparsify(model, config with calibration)
    ModelOpt->>Module: apply sparse attention
    ModelOpt->>ModelOpt: extract calibration config
    
    alt Calibration enabled
        ModelOpt->>DataBuilder: build_calibration_dataset()
        DataBuilder-->>ModelOpt: dataset samples (varying lengths)
        
        ModelOpt->>Calibrator: calibrate(model, forward_loop, phase="prefill")
        Calibrator->>Module: set_calibration_mode(True)
        Calibrator->>StatsManager: collect(stats)
        StatsManager-->>Calibrator: per-sample sparsity data
        Calibrator->>Calibrator: fit inverse power model (k, p)
        Calibrator-->>ModelOpt: calibration_results (prefill)
        
        ModelOpt->>Calibrator: calibrate(model, forward_loop, phase="decode")
        Calibrator->>Module: set_calibration_mode(True)
        Calibrator->>StatsManager: collect(stats)
        StatsManager-->>Calibrator: per-sample sparsity data
        Calibrator->>Calibrator: fit inverse power model (k, p)
        Calibrator-->>ModelOpt: calibration_results (decode)
        
        ModelOpt->>Module: apply calibrated thresholds
        Module->>Module: set threshold_scale_factor (dynamic)
    end
    
    ModelOpt-->>User: calibrated model with sparse attention

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~110 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check: ✅ Passed. The pull request title clearly describes the main feature: adding sparse attention calibration, which is the primary change across the files. The title is specific, concise, and accurately represents the core objective.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 92.14%, which is sufficient. The required threshold is 80.00%.
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.


Contributor
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 18

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
examples/llm_eval/modeling.py (1)

175-194: Add error handling for unsupported attention implementations.

Lines 192–193 and 247–248 pass attn_implementation to from_pretrained()—supported since Transformers v4.36.0 (project requires ≥4.55.0). However, not all model architectures support all attention backends. If a model doesn't support the requested backend ("eager", "sdpa", or "flash_attention_2"), from_pretrained() may error or silently fall back. Add a try/except with a clear error message indicating which attention implementation is unsupported for the given model.

examples/llm_eval/lm_eval_hf.py (1)

76-106: Early return on pre-quantized models skips sparse attention.

If the model is already quantized (line 76-79), the function returns immediately, bypassing the sparse attention application (lines 99-106). This means users cannot apply sparse attention to pre-quantized models via this path.

If this is intentional, consider documenting this behavior. If sparse attention should still be applicable to pre-quantized models, the early return should be restructured.

🔧 Suggested fix to allow sparse attention on pre-quantized models
     if is_quantized(model_obj.model):
         # return if model is already quantized
         warnings.warn("Skipping quantization: model is already quantized.")
-        return model_obj
+    elif quant_cfg:
-
-    if quant_cfg:
         if not calib_batch_size:
             calib_batch_size = model_obj.batch_size

         quantize_model(
             model=model_obj,
             quant_cfg=quant_cfg.split(",") if auto_quantize_bits is not None else quant_cfg,
             tokenizer=model_obj.tokenizer,
             batch_size=calib_batch_size,
             calib_size=calib_size,
             auto_quantize_bits=auto_quantize_bits,
             auto_quantize_method=auto_quantize_method,
             auto_quantize_score_size=auto_quantize_score_size,
             test_generated=False,
             compress=compress,
             auto_quantize_checkpoint=auto_quantize_checkpoint,
         )

     if sparse_cfg:
         if is_attn_sparsified(model_obj.model):
             warnings.warn("Skipping sparse attention: model already has sparse attention applied.")
         else:
             sparsify_model(
                 model=model_obj,
                 sparse_cfg=sparse_cfg,
             )

     return model_obj
tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py (1)

143-144: Tighten sparsity assertion to 0 <= stats["sparsity"] <= 1.

The formula sparsity = 1.0 - dense_blocks.item() / total_valid_blocks cannot produce negative values since dense_blocks (a sum of boolean mask) is always between 0 and total_valid_blocks. The current assertion allowing -1 <= sparsity <= 1 is overly permissive and the comment about negative sparsity is incorrect. Change to assert 0 <= stats["sparsity"] <= 1.

examples/llm_sparsity/attention_sparsity/hf_sa.py (1)

93-134: Avoid undefined generated_text when CUDA is unavailable.

generate_sample_output only sets generated_text inside the CUDA branch; if the helper is reused in a CPU-only context, it will raise UnboundLocalError. Make the GPU requirement explicit (or remove the inner CUDA guard since main already enforces it).

🛠 Suggested fix
 def generate_sample_output(model, tokenizer, args):
-    # Load test sample
+    if not torch.cuda.is_available():
+        raise OSError("GPU is required for inference.")
+    # Load test sample
@@
-    if torch.cuda.is_available():
-        inputs = {k: v.cuda() for k, v in inputs.items()}
-
-        # Generate
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=args.max_new_tokens,
-                do_sample=args.do_sample,
-                temperature=args.temperature if args.do_sample else 1.0,
-                pad_token_id=tokenizer.pad_token_id,
-            )
-            input_length = inputs["input_ids"].shape[1]
-            generated_ids = outputs[0][input_length:]
-        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+    inputs = {k: v.cuda() for k, v in inputs.items()}
+    # Generate
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=args.max_new_tokens,
+            do_sample=args.do_sample,
+            temperature=args.temperature if args.do_sample else 1.0,
+            pad_token_id=tokenizer.pad_token_id,
+        )
+        input_length = inputs["input_ids"].shape[1]
+        generated_ids = outputs[0][input_length:]
+    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
🤖 Fix all issues with AI agents
In `@examples/llm_sparsity/attention_sparsity/download_ruler_data.sh`:
- Around line 33-45: The loop currently increments count regardless of curl
success; change it so count is incremented only when curl succeeds and failures
are printed—after building raw_url and calling curl in the while loop (the block
using variables url, filename, filepath, raw_url and command curl -fsSL),
capture curl's exit status and only run count=$((count + 1)) when curl exits
successfully, otherwise print a visible failure message including the
url/filepath so failed downloads are surfaced.

In `@examples/llm_sparsity/attention_sparsity/README.md`:
- Around line 9-22: The example in README.md is missing imports required to run
the snippet; add an import for torch and for AutoModelForCausalLM from
transformers at the top so the symbols AutoModelForCausalLM and torch referenced
when loading the model (and later calling mtsa.sparsify with
SKIP_SOFTMAX_DEFAULT) are defined; update the snippet to include these two
imports before importing modelopt.torch.sparsity.attention_sparsity and using
mtsa.sparsify.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 192-257: The forward_loop in
create_decode_calibration_forward_loop calls torch.cuda.empty_cache()
unconditionally which will fail on CPU-only runs; change it to only call
empty_cache when CUDA is available (e.g., guard with torch.cuda.is_available()
or check device.type == "cuda") after deleting past_key_values so CPU-only
execution doesn't crash while preserving GPU cleanup behavior.
- Around line 135-189: The function create_calibration_forward_loop
(specifically the nested forward_loop) calls torch.cuda.empty_cache() unguarded
in both the chunked prefill branch (after deleting past_key_values) and the full
prefill branch; wrap both calls in a runtime GPU check by only calling
torch.cuda.empty_cache() when torch.cuda.is_available() to avoid AssertionError
on CPU-only PyTorch builds (i.e., add an if torch.cuda.is_available(): guard
around the existing empty_cache() invocations).

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 249-272: Enable calibration should stash the stats-manager's prior
existence and enabled state on the module so we can restore it later; modify
_enable_calibration_mode to record whether module._stats_manager existed and its
previous enabled flag (e.g., store tuple in module._prev_stats_manager_state or
similar) before creating/enabling a SparseAttentionStatsManager and calling
set_calibration_mode(True), and set a marker if you created the manager; then
modify _disable_calibration_mode to consult that stash and restore the original
state: if the manager was newly created by the calibrator, remove it (or set to
None), otherwise restore module._stats_manager.enabled to the previous value and
call set_calibration_mode(False); use the existing symbols
_enable_calibration_mode, _disable_calibration_mode, module._stats_manager,
SparseAttentionStatsManager, and set_calibration_mode to locate and implement
these changes.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py`:
- Around line 229-265: The current build_calibration_dataset logic can generate
more than self.total_samples because samples_per_task = max(num_samples //
len(self.subtasks), 1) forces at least one sample per subtask/bin; change the
allocation to cap generation so total produced does not exceed
self.total_samples by computing samples_per_task = num_samples //
len(self.subtasks) and only using the >0 fallback when there is remaining budget
across bins (or by tracking produced_count and breaking when produced_count >=
self.total_samples); update code paths in build_calibration_dataset that use
samples_per_task, the pbar update, and any loop over
self.subtasks/_generate_sample so they stop when the global produced_count
reaches self.total_samples (and ensure pbar.total equals self.total_samples).

In `@modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py`:
- Around line 102-117: The hardcoded DATA_DIR (and derived
RULER_URLS_FILE/ESSAYS_DIR) lives under the package tree which can be read-only;
make the data directory user-configurable by changing _get_data_dir to prefer an
environment variable (e.g., RULER_DATA_DIR or MODELOPT_RULER_DATA_DIR) and fall
back to a user-writable cache location (use platform-appropriate user cache/
config dir), then ensure DATA_DIR, RULER_URLS_FILE and ESSAYS_DIR are resolved
via that helper; create the directory with mkdir(parents=True, exist_ok=True)
and keep existing names for the files so callers using DATA_DIR, _get_data_dir,
RULER_URLS_FILE, and ESSAYS_DIR keep working.

In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 251-271: The validator validate_target_sparse_ratio currently
allows values equal to 1.0 which can cause divide-by-zero in downstream dynamic
thresholding; update the check in validate_target_sparse_ratio to require 0.0 <=
ratio < 1.0 (strictly less than 1.0) and modify the raised ValueError message to
state the allowed range is [0.0, 1.0). Ensure the error includes the phase and
invalid ratio and still enforces that target_sparse_ratio is a dict with only
"prefill" and "decode" keys.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 334-347: The _format_threshold function currently only handles
"dynamic" and "static" types, so thresholds labeled "dynamic_calibrated"
(returned by get_threshold_info()) are rendered as N/A; update _format_threshold
to detect t == "dynamic_calibrated", read per-phase objects from
info.get("phases", {}) (each phase entry contains a scale_factor key), build
parts like f"{phase}={scale_factor:.2f}" (fallback to raw value if missing), and
return the same λ={...} formatted string as for "dynamic" so calibrated
per-phase thresholds are displayed correctly.

In `@tests/_test_utils/torch/sparsity/sparse_attention_common.py`:
- Around line 195-197: The assertion incorrectly passes atol as the third
positional argument to torch.allclose (which is rtol); update the call in the
test where torch.allclose(output_sparse, output_restored, atol) is used so that
the tolerance is passed by name (e.g., torch.allclose(output_sparse,
output_restored, atol=atol)) or otherwise supply both rtol and atol positionally
in the correct order; ensure the variables output_sparse, output_restored and
atol are used so the absolute tolerance is applied as intended.

In `@tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py`:
- Around line 44-51: The docstring for test_attention_sparsity incorrectly
claims it tests calibration scenarios but the test isn't parameterized; either
add parametrization for calibration or update the docstring—here: remove the
misleading "(with and without calibration)" text from the
test_attention_sparsity docstring and also remove the redundant seq_len=128
argument from the run_attention_sparsity_command call (the helper defaults to
128); locate test_attention_sparsity and run_attention_sparsity_command in the
test file to apply these edits.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py`:
- Around line 18-28: This GPU test must skip when the optional "transformers"
package is missing because RulerDatasetBuilder loads tokenizers by name at
runtime; add a guard using pytest.importorskip("transformers") (or equivalent
try/except ImportError and pytest.skip) near the top of the test file (same area
as the existing CUDA guard) so tests referencing RulerDatasetBuilder or passing
tokenizer names like "gpt2" won't raise ImportError at runtime.

In
`@tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py`:
- Around line 67-79: The tests in test_threshold_validation_range call
SparseAttentionAttributeConfig with scalar thresholds but assert the validator
raises "Threshold must be in range"; since the validator first enforces that
threshold is a dict, update the expectations to match the dict-type validation
error instead of range errors—i.e., in test_threshold_validation_range change
the pytest.raises match from "Threshold must be in range" to the message used by
the dict-type check (e.g., "Threshold must be a dict") for the scalar inputs
when constructing SparseAttentionAttributeConfig.
- Around line 99-102: The test test_threshold_validation_type should assert the
actual validator message from SparseAttentionAttributeConfig rather than a
generic Pydantic message: update the pytest.raises match to the custom validator
text "Threshold must be a dict with 'prefill' and/or 'decode' keys" (or change
the invalid input to a type that triggers the default Pydantic type error) so it
aligns with the validate_threshold validator behavior in
SparseAttentionAttributeConfig.
- Around line 34-46: test_valid_config is passing threshold as a scalar while
the validator validate_threshold on SparseAttentionAttributeConfig expects a
dict with 'prefill' and/or 'decode'; update the test to pass a dict (e.g.
threshold={"prefill": 1e-4, "decode": 1e-4} or at least one of those keys) so
the value satisfies validate_threshold and the assertions remain valid for
SparseAttentionAttributeConfig.

In `@tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py`:
- Around line 18-21: Remove the unnecessary test-level skip by deleting the
pytest.importorskip("transformers") line in the test module; the tests for
SparseAttentionStatsManager do not depend on transformers so simply remove that
import-or-skip call near the top of
tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py and keep the
rest of the file (ensuring SparseAttentionStatsManager tests run on minimal
installs).

In `@tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py`:
- Around line 52-79: The test expects the old threshold-info contract from
FlashSkipSoftmax.get_threshold_info (checking type=="dynamic",
threshold_scale_factor/scale_factors, and formula "λ[phase] / length"), but the
implementation now returns calibration_params and target_sparse_ratio with a
different type/payload; update the tests to assert the new contract: call
FlashSkipSoftmax.get_threshold_info(), assert the returned info uses the new
keys (e.g., info["type"] matches the new value, info["calibration_params"]
contains per-phase calibration data, and info["target_sparse_ratio"] or
equivalent exists), replace assertions that reference
threshold_scale_factor/scale_factors and the old formula with checks against
info["calibration_params"] structure and example thresholds derived from
calibration_params (or alternatively restore legacy fields in
FlashSkipSoftmax.get_threshold_info to emit scale_factors for backward
compatibility). Ensure references to FlashSkipSoftmax.get_threshold_info and the
instance attribute threshold_scale_factor are updated/removed across the other
affected test blocks (lines ~136-170 and ~233-269).
🧹 Nitpick comments (14)
setup.py (1)

60-67: Remove duplicate deepspeed entry in hf extras.
It appears twice with the same marker; keep a single entry to avoid redundancy and reduce maintenance noise.

♻️ Proposed cleanup
     "hf": [
         "accelerate>=1.0.0",
         "datasets>=3.0.0",
         "deepspeed>=0.9.6 ; platform_system != 'Darwin' and platform_system != 'Windows'",
         "diffusers>=0.32.2",
         "huggingface_hub>=0.24.0",
         "peft>=0.17.0",
         "transformers>=4.53,<5.0",  # Should match modelopt/torch/__init__.py and tox.ini
-        "deepspeed>=0.9.6 ; platform_system != 'Darwin' and platform_system != 'Windows'",
         "nltk",
         "wonderwords",
     ],
tests/_test_utils/torch/sparsity/sparse_attention_common.py (1)

164-171: Missing documentation for atol parameter.

The docstring lists Args but doesn't document the new atol parameter.

📝 Suggested docstring update
 def save_restore_test(model_cls, device, sparse_config, atol=1e-6):
     """Test save and restore of sparse attention state.

     Args:
         model_cls: Model class to test
         device: Device to run on ('cpu' or 'cuda')
         sparse_config: Sparse attention configuration
+        atol: Absolute tolerance for output comparison (default: 1e-6)
     """
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1)

253-279: Consider adding a test for decode threshold retrieval.

The test validates that get_threshold_info() returns type: "static" and value: 0.005 (the prefill threshold), but it doesn't verify the decode threshold. Consider adding an assertion or separate test to ensure both phase thresholds are accessible via the threshold info API, especially since the PR introduces per-phase threshold handling.

examples/llm_eval/sparse_attention_utils.py (2)

63-71: Backend override may discard other top-level config keys.

When the backend override is applied, line 71 creates a new config dict containing only "sparse_cfg". If the original mtsa_cfg contained other top-level keys (e.g., "export_format" from SparseAttentionConfig), they would be silently discarded.

♻️ Suggested fix to preserve other config keys
-            mtsa_cfg = {"sparse_cfg": modified_sparse_cfg}
+            mtsa_cfg = {**mtsa_cfg, "sparse_cfg": modified_sparse_cfg}

21-28: Consider documenting supported wrapper types.

The _extract_model helper checks for gpt2 and model attributes to extract the underlying model. Adding a brief comment listing the known wrapper types (e.g., HFLM, EvalModel) would help maintainability.

modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1)

60-76: Docstring example uses outdated code block syntax.

The docstring uses .. code-block::python (missing space before python) which may cause rendering issues in some documentation generators.

♻️ Fix code-block directive syntax
-            .. code-block::python
+            .. code-block:: python
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)

57-65: Docstring should match mask value semantics.
Implementations use torch.finfo(...).min, not literal -inf. Updating the doc avoids confusion.

✏️ Suggested doc tweak
-        Returns:
-            Masked attention scores with sparse elements set to -inf
+        Returns:
+            Masked attention scores with sparse elements set to dtype minimum
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)

244-270: Guard export against missing target_sparse_ratio.
If calibration_params exists but target_sparse_ratio is None, the exported config may be incomplete for downstream consumers.

💡 Suggested guard
-            if calibration_params is not None:
+            if calibration_params is not None and target_sparse_ratio is not None:
                 return {
                     "calibration_params": calibration_params,
                     "target_sparse_ratio": target_sparse_ratio,
                 }
modelopt/torch/sparsity/attention_sparsity/stats_manager.py (2)

71-74: Count unexpected phases as unknown to keep stats consistent.
Currently, unexpected phase values are ignored, which can make phase distributions not sum to total calls.

🧮 Suggested fallback
-        if phase in self.aggregated_stats["phase_counts"]:
-            self.aggregated_stats["phase_counts"][phase] += 1
+        if phase in self.aggregated_stats["phase_counts"]:
+            self.aggregated_stats["phase_counts"][phase] += 1
+        else:
+            self.aggregated_stats["phase_counts"]["unknown"] += 1

145-147: Avoid exposing mutable internal stats list.
Returning the internal list allows external mutation; a shallow copy is safer.

🛡️ Suggested defensive copy
-        if phase is None:
-            return self.per_sample_stats
-        return [s for s in self.per_sample_stats if s.get("phase") == phase]
+        if phase is None:
+            return list(self.per_sample_stats)
+        return [s.copy() for s in self.per_sample_stats if s.get("phase") == phase]
tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (1)

135-140: No-op forward_loop means calibration paths aren’t exercised.
Several tests define a pass loop, which can make calibration a no-op and allow regressions to slip through. Consider at least one minimal forward to collect stats.

💡 Minimal forward loop example
-        def forward_loop(model):
-            # Simple forward loop for calibration
-            pass
+        def forward_loop(model):
+            model.eval()
+            with torch.no_grad():
+                sample = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=8).cuda()
+                model(sample)
modelopt/torch/sparsity/attention_sparsity/sparse_attention.py (1)

204-215: Guard calibration-mode check for methods that don’t define _calibration_mode.

Directly reading a private attribute risks AttributeError if another sparse method hasn’t been updated yet. A getattr(..., False) keeps this robust without changing behavior.

♻️ Suggested tweak
-            if not self._sparse_method_instance._calibration_mode:
+            if not getattr(self._sparse_method_instance, "_calibration_mode", False):
                 input = self._sparse_method_instance.apply_sparsity(input, sparse_mask)
modelopt/torch/sparsity/attention_sparsity/config.py (1)

365-406: Consider matching both "attn" and "attention" module names in presets.

"*attn*" won’t match layers that use the full "attention" name. Adding a second pattern avoids silently skipping those models in the preset configs.

♻️ Example expansion (apply similarly to SKIP_SOFTMAX_CALIB)
 SKIP_SOFTMAX_DEFAULT = {
     "sparse_cfg": {
         "*attn*": {
             "method": "flash_skip_softmax",
             "threshold": {
                 "prefill": 1e-3,  # More aggressive during prefill
                 "decode": 1e-4,  # Conservative during decode
             },
             "br": 128,  # Flash Attention block rows
             "bc": 128,  # Flash Attention block columns
             "backend": "pytorch",  # Only pytorch backend supported
             "collect_stats": True,
             "enable": True,
         },
+        "*attention*": {
+            "method": "flash_skip_softmax",
+            "threshold": {
+                "prefill": 1e-3,
+                "decode": 1e-4,
+            },
+            "br": 128,
+            "bc": 128,
+            "backend": "pytorch",
+            "collect_stats": True,
+            "enable": True,
+        },
         "default": {"enable": False},
     },
 }
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)

25-28: Guard optional transformers import to avoid import-time failures.

If transformers is an optional dependency (e.g., HF extras), importing this module will raise on base installs. Consider lazy-importing AutoTokenizer inside the forward-loop creators (or wrapping with a clear error). Please verify the dependency expectations for base installs.

🔧 Example approach
-from transformers import AutoTokenizer
+try:
+    from transformers import AutoTokenizer
+except Exception as exc:  # transformers optional
+    AutoTokenizer = None
+    _TRANSFORMERS_IMPORT_ERROR = exc
 def create_calibration_forward_loop(...):
+    if AutoTokenizer is None:
+        raise ModuleNotFoundError(
+            "transformers is required for calibration; install nvidia-modelopt[hf]."
+        ) from _TRANSFORMERS_IMPORT_ERROR
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
 def create_decode_calibration_forward_loop(...):
+    if AutoTokenizer is None:
+        raise ModuleNotFoundError(
+            "transformers is required for calibration; install nvidia-modelopt[hf]."
+        ) from _TRANSFORMERS_IMPORT_ERROR
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

Comment on lines +33 to +45
# Download essays from GitHub URLs
echo -n "Downloading essays"
count=0
while IFS= read -r url || [ -n "$url" ]; do
    if [[ "${url}" == https://github.com*.txt ]]; then
        filename=$(basename "${url}")
        filepath="${ESSAYS_DIR}/${filename}"
        if [ ! -f "${filepath}" ]; then
            raw_url="${url/github.com/raw.githubusercontent.com}"
            raw_url="${raw_url/\/raw\//\/}"
            curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "."
            count=$((count + 1))
        fi
Contributor

⚠️ Potential issue | 🟡 Minor

Only count successful downloads (and surface failures).
Line 43 increments count even if curl fails, so the final count/progress can be inaccurate and failures remain silent.

🛠️ Proposed fix
-            curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "."
-            count=$((count + 1))
+            if curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null; then
+                echo -n "."
+                count=$((count + 1))
+            else
+                echo "Failed to download ${raw_url}" >&2
+            fi

Comment on lines +9 to +22
```python
import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

# Load your model
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    attn_implementation="eager",  # Required for sparse attention
    torch_dtype=torch.bfloat16,
)

# Apply sparse attention
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
```
Contributor

⚠️ Potential issue | 🟡 Minor

Add missing imports in the quick example.

The code example is missing necessary imports for AutoModelForCausalLM and torch, which would cause the example to fail if copied as-is.

📝 Suggested fix
 ```python
+import torch
+from transformers import AutoModelForCausalLM
+
 import modelopt.torch.sparsity.attention_sparsity as mtsa
 from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

Comment on lines +135 to +187
def create_calibration_forward_loop(
    calibration_data: list[dict[str, Any]],
    tokenizer_name_or_path: str,
    batch_size: int = 1,
    chunk_size: int = 2048,
) -> Callable:
    """Create forward loop for calibration.

    Args:
        calibration_data: List of samples with 'input' and 'length' fields
        tokenizer_name_or_path: HuggingFace tokenizer path
        batch_size: Batch size (currently unused, always 1)
        chunk_size: Chunk size for chunked prefill to avoid OOM. Set to -1 to disable.

    Returns:
        Forward loop function that takes model as argument
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    def forward_loop(model: nn.Module) -> None:
        device = next(model.parameters()).device

        for sample in calibration_data:
            inputs = tokenizer(
                sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"]
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            input_ids = inputs["input_ids"].to(device)
            seq_len = input_ids.shape[1]

            with torch.no_grad():
                if chunk_size > 0 and seq_len > chunk_size:
                    # Chunked prefill to avoid OOM with long sequences
                    past_key_values = None
                    for start_idx in range(0, seq_len, chunk_size):
                        end_idx = min(start_idx + chunk_size, seq_len)
                        chunk_input_ids = input_ids[:, start_idx:end_idx]

                        outputs = model(
                            chunk_input_ids,
                            past_key_values=past_key_values,
                            use_cache=True,
                        )
                        past_key_values = outputs.past_key_values

                    # Clean up KV cache
                    del past_key_values
                    torch.cuda.empty_cache()
                else:
                    # Full prefill without chunking
                    model(input_ids, use_cache=False)

    return forward_loop
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

torch.cuda.empty_cache() behavior without CUDA available

💡 Result:

torch.cuda.empty_cache() is only meaningful once PyTorch’s CUDA runtime + caching allocator are initialized; otherwise it won’t be able to “free GPU cache”.

What happens when “CUDA isn’t available” depends on why it isn’t available:

  1. CPU-only PyTorch build (not compiled with CUDA)
  • Calling torch.cuda.empty_cache() will trigger CUDA lazy-init and typically raises AssertionError: Torch not compiled with CUDA enabled (coming from torch.cuda._lazy_init). [3]
  2. CUDA build, but CUDA can't be initialized at runtime (e.g., no driver / driver error / WSL/NVML issues)
  • torch.cuda.is_available() may be False, and attempting CUDA init (which empty_cache() effectively forces) can raise a RuntimeError from cudaGetDeviceCount() / initialization. [4]
  3. CUDA build + working CUDA, but you just haven't used CUDA yet
  • torch.cuda.empty_cache() may still initialize CUDA (creating a context) and then release unoccupied cached blocks held by PyTorch's caching allocator. It does not free memory held by live tensors. [1][2]

Practical pattern:

import torch

if torch.cuda.is_available() and torch.cuda.is_initialized():
    torch.cuda.empty_cache()

(Use is_available() to avoid CPU-only / broken-CUDA cases, and is_initialized() if you specifically want to avoid initializing CUDA just to call empty_cache().) [1][2]


🏁 Script executed:

# Check if there are other CUDA-related calls in the file
rg -A 2 -B 2 "torch.cuda" modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Repository: NVIDIA/Model-Optimizer

Length of output: 384


🏁 Script executed:

# Search for CPU-related checks or CUDA availability guards in the codebase
rg "torch.cuda.is_available|cuda.is_available" --type py -A 2 | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 4501


🏁 Script executed:

# Check if this calibration module has any documentation about CPU support
rg -B 5 -A 10 "def create_calibration_forward_loop" modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Repository: NVIDIA/Model-Optimizer

Length of output: 526


Guard CUDA cache clearing on CPU-only builds.

torch.cuda.empty_cache() raises AssertionError on CPU-only PyTorch builds. There are two unguarded calls in this function (after chunked and full prefill paths); both should be guarded with if torch.cuda.is_available():. This pattern is already established elsewhere in the codebase.

🛠️ Suggested change
-                    torch.cuda.empty_cache()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()

Apply the same guard to the second instance after the full prefill path as well.


Comment on lines +192 to +255
def create_decode_calibration_forward_loop(
    calibration_data: list[dict[str, Any]],
    tokenizer_name_or_path: str,
    num_decode_tokens: int = 10,
) -> Callable:
    """Create forward loop for decode phase calibration.

    Uses flash attention for fast prefill, then switches to eager attention
    for decode token generation with softmax hook measurement.

    Args:
        calibration_data: List of samples with 'input' and 'length' fields
        tokenizer_name_or_path: HuggingFace tokenizer path
        num_decode_tokens: Number of decode tokens to generate per sample

    Returns:
        Forward loop function that takes model as argument
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    def forward_loop(model: nn.Module) -> None:
        device = next(model.parameters()).device

        for sample in calibration_data:
            inputs = tokenizer(
                sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"]
            )
            input_ids = inputs["input_ids"].to(device)

            # Save original attention implementation
            original_attn_impl = getattr(model.config, "_attn_implementation", "eager")

            with torch.no_grad():
                try:
                    # Step 1: Fast prefill with flash attention (no measurement)
                    model.config._attn_implementation = "flash_attention_2"
                    outputs = model(input_ids, use_cache=True)
                    past_key_values = outputs.past_key_values

                    # Step 2: Switch to eager for decode (enables softmax hook)
                    model.config._attn_implementation = "eager"

                    # Step 3: Manual decode loop for explicit control over token generation
                    # model.generate() method is not used here because it doesn't allow explicit control over KV cache
                    # Get the last token's logits and sample next token
                    next_token = outputs.logits[:, -1:, :].argmax(dim=-1)

                    for _ in range(num_decode_tokens):
                        outputs = model(
                            next_token,
                            past_key_values=past_key_values,
                            use_cache=True,
                        )
                        past_key_values = outputs.past_key_values
                        next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
                finally:
                    # Restore original attention implementation
                    model.config._attn_implementation = original_attn_impl

            # Clean up
            del past_key_values
            torch.cuda.empty_cache()

    return forward_loop

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py | sed -n '192,257p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3361


🏁 Script executed:

rg "torch\.cuda" modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Repository: NVIDIA/Model-Optimizer

Length of output: 147


🏁 Script executed:

rg "torch\.cuda\.is_available" --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 2206


🏁 Script executed:

rg "empty_cache" --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1942


🏁 Script executed:

# Check the context around the first empty_cache call in calibrate.py
sed -n '150,180p' modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1436


🏁 Script executed:

# Check if there are guards around empty_cache elsewhere in the same file
rg -B3 -A3 "torch\.cuda\.empty_cache\(\)" modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Repository: NVIDIA/Model-Optimizer

Length of output: 441


🏁 Script executed:

# Check how empty_cache is used in similar calibration contexts
rg -B5 -A5 "torch\.cuda\.empty_cache\(\)" modelopt/torch/utils/dataset_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 489


🏁 Script executed:

# Check the module's scope and whether it requires CUDA
head -50 modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1945


Guard torch.cuda.empty_cache() to support CPU-only execution.

The torch.cuda.empty_cache() call at line 255 will crash on CPU-only builds since the model device can be CPU. This must be guarded with torch.cuda.is_available() to ensure the decode calibration function works across all deployment scenarios.

Suggested change
             del past_key_values
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
🤖 Prompt for AI Agents
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 192 - 257, The forward_loop in create_decode_calibration_forward_loop
calls torch.cuda.empty_cache() unconditionally which will fail on CPU-only runs;
change it to only call empty_cache when CUDA is available (e.g., guard with
torch.cuda.is_available() or check device.type == "cuda") after deleting
past_key_values so CPU-only execution doesn't crash while preserving GPU cleanup
behavior.
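A sketch of the device-based variant mentioned in the prompt, assuming the model's parameters live on a single device (the helper name is illustrative):

import torch
import torch.nn as nn

def cleanup_after_sample(model: nn.Module) -> None:
    """Free cached GPU memory only if the model actually lives on a CUDA device."""
    device = next(model.parameters()).device
    if device.type == "cuda":
        torch.cuda.empty_cache()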

Comment on lines +249 to +276
    def _enable_calibration_mode(self, modules: list[nn.Module]):
        """Enable calibration mode on sparse attention modules."""
        for idx, module in enumerate(modules):
            # Create stats manager if needed
            if not module._stats_manager:
                module._stats_manager = SparseAttentionStatsManager(
                    module_name=f"sparse_attn_{idx}", enabled=True
                )
            else:
                # Re-enable if disabled
                module._stats_manager.enabled = True

            # Enable calibration mode with fresh stats
            module._stats_manager.set_calibration_mode(enabled=True, reset_history=True)
            module._sparse_method_instance.set_calibration_mode(True)

    def _disable_calibration_mode(self, modules: list[nn.Module]):
        """Disable calibration mode (but keep stats enabled if collect_stats=True)."""
        for module in modules:
            if module._stats_manager:
                module._stats_manager.set_calibration_mode(enabled=False)

            module._sparse_method_instance.set_calibration_mode(False)


⚠️ Potential issue | 🟠 Major

Restore stats-manager state after calibration to avoid permanent overhead.

_enable_calibration_mode creates/enables a stats manager even when collect_stats was off, but _disable_calibration_mode never restores the previous state. That leaves stats collection enabled after calibration, changing runtime behavior and adding overhead. Consider stashing and restoring the prior stats-manager state.

🛠 Suggested fix
     def _enable_calibration_mode(self, modules: list[nn.Module]):
         """Enable calibration mode on sparse attention modules."""
         for idx, module in enumerate(modules):
+            # Preserve previous stats-manager state for restoration
+            if not hasattr(module, "_calib_prev_stats_state"):
+                module._calib_prev_stats_state = (
+                    module._stats_manager,
+                    module._stats_manager.enabled if module._stats_manager else None,
+                )
             # Create stats manager if needed
             if not module._stats_manager:
                 module._stats_manager = SparseAttentionStatsManager(
                     module_name=f"sparse_attn_{idx}", enabled=True
                 )
@@
     def _disable_calibration_mode(self, modules: list[nn.Module]):
         """Disable calibration mode (but keep stats enabled if collect_stats=True)."""
         for module in modules:
             if module._stats_manager:
                 module._stats_manager.set_calibration_mode(enabled=False)

             module._sparse_method_instance.set_calibration_mode(False)
+
+            prev_mgr, prev_enabled = getattr(module, "_calib_prev_stats_state", (None, None))
+            if prev_mgr is None:
+                module._stats_manager = None
+            else:
+                module._stats_manager = prev_mgr
+                if prev_enabled is not None:
+                    module._stats_manager.enabled = prev_enabled
+            if hasattr(module, "_calib_prev_stats_state"):
+                delattr(module, "_calib_prev_stats_state")
🤖 Prompt for AI Agents
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 249 - 272, Enable calibration should stash the stats-manager's prior
existence and enabled state on the module so we can restore it later; modify
_enable_calibration_mode to record whether module._stats_manager existed and its
previous enabled flag (e.g., store tuple in module._prev_stats_manager_state or
similar) before creating/enabling a SparseAttentionStatsManager and calling
set_calibration_mode(True), and set a marker if you created the manager; then
modify _disable_calibration_mode to consult that stash and restore the original
state: if the manager was newly created by the calibrator, remove it (or set to
None), otherwise restore module._stats_manager.enabled to the previous value and
call set_calibration_mode(False); use the existing symbols
_enable_calibration_mode, _disable_calibration_mode, module._stats_manager,
SparseAttentionStatsManager, and set_calibration_mode to locate and implement
these changes.
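A compact sketch of the stash/restore idea in isolation; the attribute names (_stats_manager, _calib_prev_stats_state) follow the suggestion above and may differ from the repository's final implementation:

def enable_with_stash(module, make_manager):
    # Remember whether a manager already existed and whether it was enabled.
    prev = getattr(module, "_stats_manager", None)
    module._calib_prev_stats_state = (prev, prev.enabled if prev is not None else None)
    if prev is None:
        module._stats_manager = make_manager()
    module._stats_manager.enabled = True

def disable_with_restore(module):
    prev_mgr, prev_enabled = getattr(module, "_calib_prev_stats_state", (None, None))
    if prev_mgr is None:
        # The calibrator created the manager; drop it again to restore the original state.
        module._stats_manager = None
    else:
        module._stats_manager = prev_mgr
        module._stats_manager.enabled = prev_enabled
    if hasattr(module, "_calib_prev_stats_state"):
        delattr(module, "_calib_prev_stats_state")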

Comment on lines 34 to 46
    def test_valid_config(self):
        """Test creating valid config."""
        config = SparseAttentionAttributeConfig(
            method="flash_skip_softmax",
            threshold=1e-4,
            br=128,
            bc=128,
            enable=True,
        )
        assert config.method == "flash_skip_softmax"
        assert config.threshold == 1e-4
        assert config.br == 128
        assert config.bc == 128

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate the config.py file and examine the threshold validator
fd config.py -x cat -n {} | head -150

Repository: NVIDIA/Model-Optimizer

Length of output: 7093


🏁 Script executed:

# Search for the validate_threshold validator
rg -n "validate_threshold|threshold" . -A 3 -B 1 --type py | grep -E "(config\.py|sparse)" | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 10117


🏁 Script executed:

# Look at the specific lines in config.py mentioned in the review
find . -name "config.py" -path "*/sparse*" -o -name "config.py" -path "*/attention*" | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 120


🏁 Script executed:

# Examine the config.py file to see the threshold field definition and validator
cat ./modelopt/torch/sparsity/attention_sparsity/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 15398


Test passes scalar threshold but validator requires dict format.

The test_valid_config test passes threshold=1e-4 (a scalar float), but the validate_threshold validator in config.py (lines 121-142) explicitly requires a dict with 'prefill' and/or 'decode' keys. The validator checks isinstance(v, dict) and raises ValueError if the input is not a dict. This test will fail with a ValidationError at runtime.

🤖 Prompt for AI Agents
In `@tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py`
around lines 34 - 46, test_valid_config is passing threshold as a scalar while
the validator validate_threshold on SparseAttentionAttributeConfig expects a
dict with 'prefill' and/or 'decode'; update the test to pass a dict (e.g.
threshold={"prefill": 1e-4, "decode": 1e-4} or at least one of those keys) so
the value satisfies validate_threshold and the assertions remain valid for
SparseAttentionAttributeConfig.
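A hedged sketch of the dict-threshold variant the prompt describes; the import path is assumed to match the existing test module, and the exact accepted keys depend on the validator in config.py:

import pytest

pytest.importorskip("transformers")

from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig

def test_valid_config_dict_threshold():
    # threshold is given per phase, as the validator reportedly requires.
    config = SparseAttentionAttributeConfig(
        method="flash_skip_softmax",
        threshold={"prefill": 1e-3, "decode": 1e-4},
        br=128,
        bc=128,
        enable=True,
    )
    assert config.threshold == {"prefill": 1e-3, "decode": 1e-4}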

Comment on lines 67 to 79
def test_threshold_validation_range(self):
"""Test threshold must be in range (0, 1)."""
with pytest.raises(ValidationError, match="Threshold must be in range"):
SparseAttentionAttributeConfig(threshold=0)

with pytest.raises(ValidationError, match="Threshold must be in range"):
SparseAttentionAttributeConfig(threshold=-0.1)

with pytest.raises(ValidationError, match="Threshold must be in range"):
SparseAttentionAttributeConfig(threshold=1.0)

with pytest.raises(ValidationError, match="Threshold must be in range"):
SparseAttentionAttributeConfig(threshold=1.5)

⚠️ Potential issue | 🟠 Major

Threshold validation tests pass scalar values but validator expects dict.

Similar to the earlier issue, these tests pass scalar threshold values (0, -0.1, 1.0, 1.5) and expect range-validation errors, but the validator first rejects non-dict inputs before any range check runs. The tests match on "Threshold must be in range", which will never be raised because the dict type check fails first.

🤖 Prompt for AI Agents
In `@tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py`
around lines 67 - 79, The tests in test_threshold_validation_range call
SparseAttentionAttributeConfig with scalar thresholds but assert the validator
raises "Threshold must be in range"; since the validator first enforces that
threshold is a dict, update the expectations to match the dict-type validation
error instead of range errors—i.e., in test_threshold_validation_range change
the pytest.raises match from "Threshold must be in range" to the message used by
the dict-type check (e.g., "Threshold must be a dict") for the scalar inputs
when constructing SparseAttentionAttributeConfig.
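If per-phase range checking still applies inside the dict, the updated tests might look like this sketch; the error-message strings are assumptions based on the validator described above:

import pytest
from pydantic import ValidationError

from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig

def test_threshold_validation_range_dict():
    # Out-of-range value inside the dict (assumed to still trigger the range check).
    with pytest.raises(ValidationError, match="Threshold must be in range"):
        SparseAttentionAttributeConfig(threshold={"prefill": 1.5})

    # Scalar input (assumed to trigger the dict-type check first).
    with pytest.raises(ValidationError, match="Threshold must be a dict"):
        SparseAttentionAttributeConfig(threshold=0)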

Comment on lines 99 to 111
    def test_threshold_validation_type(self):
        """Test threshold type validation."""
        with pytest.raises(ValidationError, match="Input should be a valid"):
            SparseAttentionAttributeConfig(threshold="invalid")

⚠️ Potential issue | 🟡 Minor

Test expects wrong error message for string threshold.

The test passes threshold="invalid" and expects a match on "Input should be a valid". However, based on the custom validate_threshold validator, the error message should be "Threshold must be a dict with 'prefill' and/or 'decode' keys". The current match pattern likely relies on Pydantic's default type validation running before the custom validator, which may not be the case.

🤖 Prompt for AI Agents
In `@tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py`
around lines 99 - 102, The test test_threshold_validation_type should assert the
actual validator message from SparseAttentionAttributeConfig rather than a
generic Pydantic message: update the pytest.raises match to the custom validator
text "Threshold must be a dict with 'prefill' and/or 'decode' keys" (or change
the invalid input to a type that triggers the default Pydantic type error) so it
aligns with the validate_threshold validator behavior in
SparseAttentionAttributeConfig.
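Under the same assumption about the validator's message, the type test could match the custom error instead; the import path and message string are assumptions, not verified against the repository:

import pytest
from pydantic import ValidationError

from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig

def test_threshold_validation_type_dict_required():
    # The match string follows the custom validate_threshold message described above.
    with pytest.raises(ValidationError, match="Threshold must be a dict"):
        SparseAttentionAttributeConfig(threshold="invalid")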

Comment on lines +18 to +21
import pytest

pytest.importorskip("transformers")


⚠️ Potential issue | 🟡 Minor

❓ Verification inconclusive

Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py

Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

     1	# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
     2	# SPDX-License-Identifier: Apache-2.0
     3	#
     4	# Licensed under the Apache License, Version 2.0 (the "License");
     5	# you may not use this file except in compliance with the License.
     6	# You may obtain a copy of the License at
     7	#
     8	# http://www.apache.org/licenses/LICENSE-2.0
     9	#
    10	# Unless required by applicable law or agreed to in writing, software
    11	# distributed under the License is distributed on an "AS IS" BASIS,
    12	# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13	# See the License for the specific language governing permissions and
    14	# limitations under the License.
    15	
    16	"""Unit tests for SparseAttentionStatsManager."""
    17	
    18	import pytest
    19	
    20	pytest.importorskip("transformers")
    21	
    22	from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager
    23	
    24	
    25	class TestStatsManagerInitialization:
    26	    """Test stats manager initialization."""
    27	
    28	    def test_initialization_defaults(self):
    29	        """Test default initialization."""
    30	        manager = SparseAttentionStatsManager(module_name="test_module")
    31	
    32	        assert manager.module_name == "test_module"
    33	        assert manager.enabled is True
    34	        assert manager.calibration_mode is False
    35	        assert manager.aggregated_stats["total_calls"] == 0
    36	        assert manager.aggregated_stats["total_blocks"] == 0
    37	        assert manager.aggregated_stats["sparse_blocks"] == 0
    38	        assert manager.per_sample_stats == []
    39	
    40	    def test_initialization_disabled(self):
    41	        """Test initialization with disabled stats."""
    42	        manager = SparseAttentionStatsManager(module_name="test_module", enabled=False)
    43	
    44	        assert manager.enabled is False
    45	        assert manager.calibration_mode is False
    46	
    47	    def test_initialization_custom_name(self):
    48	        """Test initialization with custom module name."""
    49	        manager = SparseAttentionStatsManager(module_name="custom.attention.module")
    50	
    51	        assert manager.module_name == "custom.attention.module"
    52	
    53	
    54	class TestStatsCollection:
    55	    """Test statistics collection functionality."""
    56	
    57	    def test_collect_stats_enabled(self):
    58	        """Test collecting stats when enabled."""
    59	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
    60	
    61	        stats = {
    62	            "sparsity": 0.5,
    63	            "phase": "prefill",
    64	            "total_blocks": 100,
    65	            "sparse_blocks": 50,
    66	            "sample_length": 1024,
    67	        }
    68	
    69	        manager.collect(stats)
    70	
    71	        assert manager.aggregated_stats["total_calls"] == 1
    72	        assert manager.aggregated_stats["total_blocks"] == 100
    73	        assert manager.aggregated_stats["sparse_blocks"] == 50
    74	        assert manager.aggregated_stats["phase_counts"]["prefill"] == 1
    75	        assert manager.aggregated_stats["phase_counts"]["decode"] == 0
    76	
    77	    def test_collect_stats_disabled(self):
    78	        """Test that collect() is no-op when disabled."""
    79	        manager = SparseAttentionStatsManager(module_name="test", enabled=False)
    80	
    81	        stats = {
    82	            "sparsity": 0.5,
    83	            "phase": "prefill",
    84	            "total_blocks": 100,
    85	            "sparse_blocks": 50,
    86	        }
    87	
    88	        manager.collect(stats)
    89	
    90	        # Should remain at initial values
    91	        assert manager.aggregated_stats["total_calls"] == 0
    92	        assert manager.aggregated_stats["total_blocks"] == 0
    93	        assert manager.aggregated_stats["sparse_blocks"] == 0
    94	
    95	    def test_collect_multiple_calls(self):
    96	        """Test accumulation over multiple collect calls."""
    97	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
    98	
    99	        # Collect multiple times
   100	        for i in range(5):
   101	            stats = {
   102	                "sparsity": 0.5,
   103	                "phase": "prefill",
   104	                "total_blocks": 100,
   105	                "sparse_blocks": 50,
   106	            }
   107	            manager.collect(stats)
   108	
   109	        assert manager.aggregated_stats["total_calls"] == 5
   110	        assert manager.aggregated_stats["total_blocks"] == 500
   111	        assert manager.aggregated_stats["sparse_blocks"] == 250
   112	        assert manager.aggregated_stats["phase_counts"]["prefill"] == 5
   113	
   114	    def test_collect_different_phases(self):
   115	        """Test phase counting."""
   116	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   117	
   118	        # Collect prefill stats
   119	        manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50})
   120	        manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50})
   121	
   122	        # Collect decode stats
   123	        manager.collect({"phase": "decode", "total_blocks": 10, "sparse_blocks": 5})
   124	
   125	        assert manager.aggregated_stats["phase_counts"]["prefill"] == 2
   126	        assert manager.aggregated_stats["phase_counts"]["decode"] == 1
   127	        assert manager.aggregated_stats["phase_counts"]["unknown"] == 0
   128	
   129	
   130	class TestCalibrationMode:
   131	    """Test calibration mode functionality."""
   132	
   133	    def test_calibration_mode_per_sample_collection(self):
   134	        """Test that calibration mode stores per-sample stats."""
   135	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   136	
   137	        # Enable calibration mode
   138	        manager.set_calibration_mode(enabled=True)
   139	
   140	        stats = {
   141	            "sparsity": 0.5,
   142	            "phase": "prefill",
   143	            "total_blocks": 100,
   144	            "sparse_blocks": 50,
   145	            "sample_length": 1024,
   146	        }
   147	
   148	        manager.collect(stats)
   149	
   150	        # Should store in per_sample_stats
   151	        assert len(manager.per_sample_stats) == 1
   152	        assert manager.per_sample_stats[0]["module"] == "test"
   153	        assert manager.per_sample_stats[0]["sparsity"] == 0.5
   154	        assert manager.per_sample_stats[0]["sample_length"] == 1024
   155	        assert manager.per_sample_stats[0]["phase"] == "prefill"
   156	
   157	    def test_calibration_mode_off(self):
   158	        """Test that per-sample stats are not collected when calibration mode is off."""
   159	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   160	        # Calibration mode is off by default
   161	
   162	        stats = {"sparsity": 0.5, "phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}
   163	
   164	        manager.collect(stats)
   165	
   166	        # Should NOT store in per_sample_stats
   167	        assert len(manager.per_sample_stats) == 0
   168	
   169	        # But should still aggregate
   170	        assert manager.aggregated_stats["total_calls"] == 1
   171	
   172	    def test_set_calibration_mode_with_reset(self):
   173	        """Test set_calibration_mode with reset_history=True."""
   174	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   175	
   176	        # Collect some stats in calibration mode
   177	        manager.set_calibration_mode(enabled=True)
   178	        manager.collect(
   179	            {
   180	                "sparsity": 0.5,
   181	                "phase": "prefill",
   182	                "total_blocks": 100,
   183	                "sparse_blocks": 50,
   184	                "sample_length": 1024,
   185	            }
   186	        )
   187	        assert len(manager.per_sample_stats) == 1
   188	
   189	        # Re-enable with reset
   190	        manager.set_calibration_mode(enabled=True, reset_history=True)
   191	        assert len(manager.per_sample_stats) == 0  # Should be cleared
   192	
   193	    def test_set_calibration_mode_without_reset(self):
   194	        """Test set_calibration_mode with reset_history=False."""
   195	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   196	
   197	        # Collect some stats
   198	        manager.set_calibration_mode(enabled=True)
   199	        manager.collect(
   200	            {
   201	                "sparsity": 0.5,
   202	                "phase": "prefill",
   203	                "total_blocks": 100,
   204	                "sparse_blocks": 50,
   205	                "sample_length": 1024,
   206	            }
   207	        )
   208	        assert len(manager.per_sample_stats) == 1
   209	
   210	        # Disable without reset
   211	        manager.set_calibration_mode(enabled=False, reset_history=False)
   212	        assert len(manager.per_sample_stats) == 1  # Should be preserved
   213	
   214	
   215	class TestGetSummary:
   216	    """Test get_summary() functionality."""
   217	
   218	    def test_get_summary_with_data(self):
   219	        """Test get_summary returns correct averages."""
   220	        manager = SparseAttentionStatsManager(module_name="test_module", enabled=True)
   221	
   222	        # Collect stats
   223	        manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 30})
   224	        manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50})
   225	
   226	        summary = manager.get_summary()
   227	
   228	        assert summary["module"] == "test_module"
   229	        assert summary["total_calls"] == 2
   230	        # Average sparsity: (30+50) / (100+100) = 80/200 = 0.4
   231	        assert summary["average_sparsity"] == 0.4
   232	        assert summary["phase_distribution"]["prefill"] == 2
   233	
   234	    def test_get_summary_no_data(self):
   235	        """Test get_summary with no collected data."""
   236	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   237	
   238	        summary = manager.get_summary()
   239	
   240	        assert summary["module"] == "test"
   241	        assert summary["total_calls"] == 0
   242	        assert summary["average_sparsity"] == 0.0
   243	        assert summary["phase_distribution"]["prefill"] == 0
   244	
   245	    def test_get_summary_zero_blocks(self):
   246	        """Test get_summary when total_blocks is zero."""
   247	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   248	
   249	        # Collect stats with zero blocks
   250	        manager.collect({"phase": "prefill", "total_blocks": 0, "sparse_blocks": 0})
   251	
   252	        summary = manager.get_summary()
   253	
   254	        assert summary["average_sparsity"] == 0.0  # Should handle division by zero
   255	
   256	
   257	class TestGetCalibrationStats:
   258	    """Test get_calibration_stats() functionality."""
   259	
   260	    def test_get_calibration_stats(self):
   261	        """Test retrieving per-sample calibration stats."""
   262	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   263	        manager.set_calibration_mode(enabled=True)
   264	
   265	        # Collect multiple samples
   266	        for i in range(3):
   267	            manager.collect(
   268	                {
   269	                    "sparsity": 0.3 + i * 0.1,
   270	                    "phase": "prefill",
   271	                    "total_blocks": 100,
   272	                    "sparse_blocks": 30,
   273	                    "sample_length": 1024 + i * 512,
   274	                }
   275	            )
   276	
   277	        calib_stats = manager.get_calibration_stats()
   278	
   279	        assert len(calib_stats) == 3
   280	        assert calib_stats[0]["sparsity"] == 0.3
   281	        assert calib_stats[1]["sparsity"] == 0.4
   282	        assert calib_stats[2]["sparsity"] == 0.5
   283	
   284	    def test_get_calibration_stats_empty(self):
   285	        """Test get_calibration_stats when no calibration data."""
   286	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   287	
   288	        calib_stats = manager.get_calibration_stats()
   289	
   290	        assert calib_stats == []
   291	
   292	
   293	class TestReset:
   294	    """Test reset functionality."""
   295	
   296	    def test_reset(self):
   297	        """Test reset() clears all statistics."""
   298	        manager = SparseAttentionStatsManager(module_name="test", enabled=True)
   299	        manager.set_calibration_mode(enabled=True)
   300	
   301	        # Collect some stats
   302	        manager.collect(
   303	            {
   304	                "sparsity": 0.5,
   305	                "phase": "prefill",
   306	                "total_blocks": 100,
   307	                "sparse_blocks": 50,
   308	                "sample_length": 1024,
   309	            }
   310	        )
   311	        manager.collect(
   312	            {
   313	                "sparsity": 0.3,
   314	                "phase": "decode",
   315	                "total_blocks": 10,
   316	                "sparse_blocks": 3,
   317	                "sample_length": 128,
   318	            }
   319	        )
   320	
   321	        # Verify stats exist
   322	        assert manager.aggregated_stats["total_calls"] == 2
   323	        assert len(manager.per_sample_stats) == 2
   324	
   325	        # Reset
   326	        manager.reset()
   327	
   328	        # All stats should be cleared
   329	        assert manager.aggregated_stats["total_calls"] == 0
   330	        assert manager.aggregated_stats["total_blocks"] == 0
   331	        assert manager.aggregated_stats["sparse_blocks"] == 0
   332	        assert manager.per_sample_stats == []
   333	        assert manager.aggregated_stats["phase_counts"]["prefill"] == 0
   334	        assert manager.aggregated_stats["phase_counts"]["decode"] == 0

Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/stats_manager.py | head -100

Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

     1	# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
     2	# SPDX-License-Identifier: Apache-2.0
     3	#
     4	# Licensed under the Apache License, Version 2.0 (the "License");
     5	# you may not use this file except in compliance with the License.
     6	# You may obtain a copy of the License at
     7	#
     8	# http://www.apache.org/licenses/LICENSE-2.0
     9	#
    10	# Unless required by applicable law or agreed to in writing, software
    11	# distributed under the License is distributed on an "AS IS" BASIS,
    12	# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13	# See the License for the specific language governing permissions and
    14	# limitations under the License.
    15	
    16	"""Statistics manager for sparse attention modules."""
    17	
    18	
    19	class SparseAttentionStatsManager:
    20	    """Centralized statistics manager for sparse attention.
    21	
    22	    This class is the single source of truth for all statistics collection
    23	    in sparse attention modules. It handles both runtime aggregation and
    24	    per-sample calibration statistics.
    25	
    26	    Design principles:
    27	    - Single responsibility: only stats management
    28	    - No computation: receives pre-computed stats from methods
    29	    - Optional: can be None if stats collection disabled
    30	    - Zero overhead when disabled
    31	    """
    32	
    33	    def __init__(self, module_name: str, enabled: bool = True):
    34	        """Initialize stats manager.
    35	
    36	        Args:
    37	            module_name: Name of the module this manager is attached to
    38	            enabled: Whether stats collection is enabled
    39	        """
    40	        self.module_name = module_name
    41	        self.enabled = enabled
    42	        self.calibration_mode = False
    43	
    44	        # Aggregated stats (running totals across all forward passes)
    45	        self.aggregated_stats: dict = {
    46	            "total_calls": 0,
    47	            "total_blocks": 0,
    48	            "sparse_blocks": 0,
    49	            "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0},
    50	        }
    51	
    52	        # Per-sample stats (only populated during calibration)
    53	        self.per_sample_stats: list[dict] = []
    54	
    55	    def collect(self, stats: dict):
    56	        """Collect statistics from a single forward pass.
    57	
    58	        Args:
    59	            stats: Dictionary containing statistics from method computation.
    60	                Expected keys: sparsity, phase, total_blocks, sparse_blocks,
    61	                sample_length (optional)
    62	        """
    63	        if not self.enabled:
    64	            return
    65	
    66	        # Update aggregated stats
    67	        self.aggregated_stats["total_calls"] += 1
    68	        self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0)
    69	        self.aggregated_stats["sparse_blocks"] += stats.get("sparse_blocks", 0)
    70	
    71	        phase = stats.get("phase", "unknown")
    72	        if phase in self.aggregated_stats["phase_counts"]:
    73	            self.aggregated_stats["phase_counts"][phase] += 1
    74	
    75	        # In calibration mode, store per-sample stats
    76	        if self.calibration_mode:
    77	            self.per_sample_stats.append(
    78	                {
    79	                    "module": self.module_name,
    80	                    "sparsity": stats.get("sparsity", 0.0),
    81	                    "sample_length": stats.get("sample_length", 0),
    82	                    "phase": phase,
    83	                }
    84	            )
    85	
    86	    def get_summary(self) -> dict:
    87	        """Get aggregated statistics summary.
    88	
    89	        Returns:
    90	            Dictionary with module name, total calls, average sparsity,
    91	            and phase distribution.
    92	        """
    93	        total_blocks = self.aggregated_stats["total_blocks"]
    94	        if total_blocks > 0:
    95	            avg_sparsity = self.aggregated_stats["sparse_blocks"] / total_blocks
    96	        else:
    97	            avg_sparsity = 0.0
    98	
    99	        return {
   100	            "module": self.module_name,

Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/stats_manager.py | tail -50

Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

    98	
    99	        return {
   100	            "module": self.module_name,
   101	            "total_calls": self.aggregated_stats["total_calls"],
   102	            "average_sparsity": avg_sparsity,
   103	            "phase_distribution": self.aggregated_stats["phase_counts"].copy(),
   104	        }
   105	
   106	    def set_calibration_mode(self, enabled: bool, reset_history: bool = True):
   107	        """Enable or disable calibration mode.
   108	
   109	        In calibration mode, per-sample statistics are stored for detailed
   110	        analysis. Otherwise, only aggregated stats are maintained.
   111	
   112	        Args:
   113	            enabled: Whether to enable calibration mode
   114	            reset_history: Whether to clear per_sample_stats when enabling
   115	        """
   116	        self.calibration_mode = enabled
   117	        if enabled and reset_history:
   118	            self.per_sample_stats = []
   119	
   120	    def reset(self):
   121	        """Reset all statistics to initial state."""
   122	        self.aggregated_stats = {
   123	            "total_calls": 0,
   124	            "total_blocks": 0,
   125	            "sparse_blocks": 0,
   126	            "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0},
   127	        }
   128	        self.per_sample_stats = []
   129	
   130	    def get_calibration_stats(self, phase: str | None = None) -> list[dict]:
   131	        """Get per-sample calibration statistics, optionally filtered by phase.
   132	
   133	        Note: Returns historical stats collected while calibration_mode was enabled.
   134	        Stats remain accessible even after calibration_mode is disabled.
   135	        New stats are only collected when calibration_mode is True.
   136	
   137	        Args:
   138	            phase: Optional phase to filter by ('prefill' or 'decode').
   139	                   If None, returns all stats.
   140	
   141	        Returns:
   142	            List of per-sample statistics dictionaries.
   143	            Empty list if no stats were collected or no stats match the phase.
   144	        """
   145	        if phase is None:
   146	            return self.per_sample_stats
   147	        return [s for s in self.per_sample_stats if s.get("phase") == phase]

Script executed:

# Check if transformers is used anywhere in this test file
rg "transformers" tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py

Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

pytest.importorskip("transformers")

Script executed:

# Check if transformers is imported or used in the stats_manager.py itself
rg "transformers" modelopt/torch/sparsity/attention_sparsity/

Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

modelopt/torch/sparsity/attention_sparsity/model_sparsify.py:                    from transformers import AutoModelForCausalLM
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py:import transformers
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py:        if isinstance(model, transformers.PreTrainedModel):
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py:    if not isinstance(model, transformers.PreTrainedModel):
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py:from transformers import AutoTokenizer
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py:            from transformers import AutoTokenizer

Script executed:

# Check if there are other test files in the same directory and their importorskip patterns
find tests/unit/torch/sparsity/attention_sparsity/ -name "*.py" -type f | xargs -I {} sh -c 'echo "=== {} ===" && head -25 {}'

Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

=== tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for sparse attention mode registry."""

import pytest

pytest.importorskip("transformers")

from modelopt.torch.opt.mode import _ModeRegistryCls
from modelopt.torch.sparsity.attention_sparsity.mode import SparseAttentionModeRegistry


=== tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for FlashSkipSoftmax method internals."""

import pytest
import torch

pytest.importorskip("transformers")

from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax


=== tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for sparse attention calibration."""

import pytest

pytest.importorskip("transformers")

import numpy as np
from _test_utils.torch.sparsity.sparse_attention_common import (
    SimpleAttentionModel,
    SimpleTransformerEncoder,
=== tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test sparse attention configuration validation."""

import pytest
from pydantic import ValidationError

pytest.importorskip("transformers")

from modelopt.torch.sparsity.attention_sparsity.config import (
    SKIP_SOFTMAX_DEFAULT,
    FlashSkipSoftmaxConfig,
=== tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for sparse attention conversion and replacement."""

import pytest

pytest.importorskip("transformers")

import torch.nn as nn
from _test_utils.torch.sparsity.sparse_attention_common import (
    FLASH_SKIP_SOFTMAX_DEFAULT_CFG,
    SimpleAttentionModel,
=== tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for SparseAttentionStatsManager."""

import pytest

pytest.importorskip("transformers")

from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager


class TestStatsManagerInitialization:
=== tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for threshold calibration functionality."""

import pytest

pytest.importorskip("transformers")

from _test_utils.torch.sparsity.sparse_attention_common import SimpleAttentionModel

from modelopt.torch.sparsity.attention_sparsity import sparsify
from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax

Remove unnecessary transformers import skip.

The pytest.importorskip("transformers") is unnecessary—SparseAttentionStatsManager and its tests have no dependency on transformers. This skip will prevent test execution on minimal installs without providing any benefit. Remove it.

Suggested change
import pytest

-pytest.importorskip("transformers")

from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import pytest

from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager
🤖 Prompt for AI Agents
In `@tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py` around
lines 18 - 21, Remove the unnecessary test-level skip by deleting the
pytest.importorskip("transformers") line in the test module; the tests for
SparseAttentionStatsManager do not depend on transformers so simply remove that
import-or-skip call near the top of
tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py and keep the
rest of the file (ensuring SparseAttentionStatsManager tests run on minimal
installs).

Comment on lines 52 to 84
    def test_dynamic_calibrated_threshold(self):
        """Test threshold info for calibrated dynamic threshold."""
        method = FlashSkipSoftmax(
            method_config={
                "threshold": {"prefill": 0.001, "decode": 0.0001},
                "br": 128,
                "bc": 128,
                "backend": "pytorch",
                "is_causal": True,
            }
        )

        # Simulate calibration setting per-phase scale factors
        method.threshold_scale_factor = {"prefill": 437.5, "decode": 500.0}

        info = method.get_threshold_info()

        assert info["type"] == "dynamic"
        assert info["scale_factors"] == {"prefill": 437.5, "decode": 500.0}
        assert info["formula"] == "λ[phase] / length"
        assert "phases" in info
        assert "prefill" in info["phases"]
        assert "decode" in info["phases"]
        # Check example thresholds for prefill
        prefill_examples = info["phases"]["prefill"]["example_thresholds"]
        assert abs(prefill_examples[1024] - 437.5 / 1024) < 1e-6
        assert abs(prefill_examples[2048] - 437.5 / 2048) < 1e-6


⚠️ Potential issue | 🟠 Major

Dynamic threshold tests target an outdated threshold-info contract.

FlashSkipSoftmax.get_threshold_info() now reports calibration via calibration_params/target_sparse_ratio and a different type payload. These tests still set threshold_scale_factor and assert type == "dynamic" with scale_factors and "λ=" formatting, so they’ll fail unless the contract is aligned. Please update the tests to match the new payload (or restore legacy fields in the method).

🛠 Example update for dynamic-threshold expectations
-        # Simulate calibration setting per-phase scale factors
-        method.threshold_scale_factor = {"prefill": 437.5, "decode": 500.0}
+        # Simulate calibration params + target sparsity
+        method.calibration_params = {
+            "prefill": {"k": 437.5, "p": 1.0},
+            "decode": {"k": 500.0, "p": 1.0},
+        }
+        method.target_sparse_ratio = {"prefill": 0.0, "decode": 0.0}

-        assert info["type"] == "dynamic"
-        assert info["scale_factors"] == {"prefill": 437.5, "decode": 500.0}
-        assert info["formula"] == "λ[phase] / length"
+        assert info["type"] == "dynamic_calibrated"
+        assert info["phases"]["prefill"]["scale_factor"] == pytest.approx(437.5)
+        assert info["phases"]["decode"]["scale_factor"] == pytest.approx(500.0)
+        assert "formula" in info

Also applies to: 136-170, 233-269

🤖 Prompt for AI Agents
In `@tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py` around
lines 52 - 79, The test expects the old threshold-info contract from
FlashSkipSoftmax.get_threshold_info (checking type=="dynamic",
threshold_scale_factor/scale_factors, and formula "λ[phase] / length"), but the
implementation now returns calibration_params and target_sparse_ratio with a
different type/payload; update the tests to assert the new contract: call
FlashSkipSoftmax.get_threshold_info(), assert the returned info uses the new
keys (e.g., info["type"] matches the new value, info["calibration_params"]
contains per-phase calibration data, and info["target_sparse_ratio"] or
equivalent exists), replace assertions that reference
threshold_scale_factor/scale_factors and the old formula with checks against
info["calibration_params"] structure and example thresholds derived from
calibration_params (or alternatively restore legacy fields in
FlashSkipSoftmax.get_threshold_info to emit scale_factors for backward
compatibility). Ensure references to FlashSkipSoftmax.get_threshold_info and the
instance attribute threshold_scale_factor are updated/removed across the other
affected test blocks (lines ~136-170 and ~233-269).
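For reference, a minimal sketch of how per-phase example thresholds would be derived from fitted (k, p) parameters under the inverse-power model used in this PR; the function name and call sites are illustrative, not the library API:

def example_threshold(k: float, p: float, target_sparsity: float, seqlen: int) -> float:
    # threshold = k / (1 - target_sparsity)^p / seqlen
    return k / (1.0 - target_sparsity) ** p / seqlen

# With p = 1.0 and target_sparsity = 0.0 this reduces to k / seqlen,
# matching the 437.5 / 1024 style assertions in the existing tests.
print(example_threshold(437.5, 1.0, 0.0, 1024))  # ≈ 0.4272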

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 33d0025 to 31e8b2d Compare January 28, 2026 02:46
@kaix-nv kaix-nv requested a review from rohansjoshi January 28, 2026 02:47
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 31e8b2d to 2fc1734 Compare January 28, 2026 07:05
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 2fc1734 to 0553ec6 Compare January 30, 2026 05:34