[OMNIML-2850] [3/n] Adds sparse attention calibration #538
base: main
Conversation
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #538      +/-   ##
==========================================
- Coverage   73.38%   73.25%   -0.13%
==========================================
  Files         193      199       +6
  Lines       19893    20713     +820
==========================================
+ Hits        14598    15174     +576
- Misses       5295     5539     +244
==========================================
```

☔ View full report in Codecov by Sentry.
@kaix-nv GitHub is showing 7000+ lines of code as part of this PR. Is that accurate?
examples/llm_eval/mmlu.py
Outdated
```python
# Force eager attention if sparse attention is requested
if sparse_cfg:
    kwargs["attn_implementation"] = "eager"
    warnings.warn(
        "Sparse attention requires attn_implementation='eager'. "
        "Forcing eager attention implementation."
    )
```
Can we move this to sparsity/plugins/huggingface? We should detect if this is a HF model and, if yes, apply this there (see
AutoQuantizeGradientSearcher.register_custom_support(
I've moved this check to huggingface.
Same as https://github.com/NVIDIA/Model-Optimizer/pull/538/files#r2646356349, and avoid repeated attention modification.
I've removed this check.
```python
from .dataset import RulerDatasetBuilder
```

```python
def _extract_tokenizer_from_model(model: nn.Module) -> str:
```
This only works with Hugging Face transformers, right?
Yes, Megatron support (e.g., adding forward_loop) will be added in future PRs.
Can we move this to sparsity/plugins? Otherwise this change will make the transformers library a required dependency of ModelOpt.
Please use local imports of third party libraries wherever necessary
I’ve switched to local imports. Since most of the dataset construction logic can be reused by Megatron, I’d prefer to keep this file under the calibration folder.
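To illustrate the local-import pattern being discussed, here is a minimal sketch (the helper name is hypothetical; the point is the import inside the function body):

```python
def _load_tokenizer(name_or_path: str):
    """Hypothetical helper showing a lazy, function-local import."""
    # Local import keeps transformers out of ModelOpt's hard dependencies.
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained(name_or_path)
```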
```python
# For causal attention, only count lower triangle blocks (including diagonal)
num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2
total_valid_blocks = batch_size * num_heads * num_causal_blocks
density = float(block_mask.sum()) / total_valid_blocks
```
Can we keep this as a torch tensor? A float(tensor_in_gpu) causes an unnecessary CPU-GPU sync.
Good catch. It’s been fixed.
```python
    density = float(block_mask.sum()) / total_valid_blocks
    total_blocks = num_causal_blocks
else:
    density = float(block_mask.sum() / block_mask.numel())
```
It’s been fixed.
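For reference, a small sketch of the tensor-only variant discussed in this thread; the shapes are made up and the surrounding code is assumed:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumed block-mask shape: (batch, heads, block_rows, block_cols).
block_mask = torch.rand(2, 8, 16, 16, device=device) > 0.5
total_valid_blocks = block_mask.numel()

# float(gpu_tensor) forces an immediate CPU-GPU sync:
# density = float(block_mask.sum()) / total_valid_blocks

# Keeping the result as a tensor defers the sync until the value is actually read:
density = block_mask.sum() / total_valid_blocks  # stays on `device`
```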
jy-yuan left a comment
Reviewed the PR, and the calibration logic looks correct (dynamic thresholding via regression and phase separation seems solid). I tested the code and verified that the unit tests pass locally.
One small observation regarding dependencies: I noticed I had to manually install nltk and wonderwords to run the tests/calibration. It seems they are currently added to dev-test in setup.py, so they aren't included in pip install nvidia-modelopt[all]. If the RULER-based calibration is intended to be a supported feature for users (i.e., when they don't provide a custom forward loop), we might want to consider moving these to a user-facing optional dependency group (like calibration or hf) or catching the ModuleNotFoundError to suggest installation.
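For example, a hedged sketch of the optional-dependency guard suggested here (the helper name is illustrative; the setup.py snippet later in this review lists nltk and wonderwords under the hf extras):

```python
def _require_ruler_deps() -> None:
    """Hypothetical helper: fail with an actionable message if RULER deps are missing."""
    try:
        import nltk  # noqa: F401
        import wonderwords  # noqa: F401
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            f"{e.name} is required for RULER-based calibration. "
            "Install it with: pip install nvidia-modelopt[hf]"
        ) from e
```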
Thanks Jiayi. Good catch. I've moved the dependency as suggested.
📝 Walkthrough
This pull request introduces comprehensive sparse attention calibration support for PyTorch LLMs. It adds a calibration framework with dynamic threshold computation using RULER datasets, updates the sparse attention configuration to support per-phase thresholds (prefill/decode), integrates statistics collection into sparse modules, and extends example workflows with sparse attention options.
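To make the per-phase threshold support concrete, here is a hedged usage sketch based on the SKIP_SOFTMAX_DEFAULT preset and the mtsa.sparsify call quoted later in this review; calibration-specific keys are omitted because their exact names are not shown here:

```python
import modelopt.torch.sparsity.attention_sparsity as mtsa

# Per-phase thresholds, mirroring the SKIP_SOFTMAX_DEFAULT preset quoted below.
sparse_config = {
    "sparse_cfg": {
        "*attn*": {
            "method": "flash_skip_softmax",
            "threshold": {"prefill": 1e-3, "decode": 1e-4},
            "br": 128,
            "bc": 128,
            "enable": True,
        },
        "default": {"enable": False},
    },
}

# model = AutoModelForCausalLM.from_pretrained(..., attn_implementation="eager")
# model = mtsa.sparsify(model, config=sparse_config)
```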
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as User/CLI
    participant ModelOpt as ModelOpt Sparsity
    participant DataBuilder as RULER DatasetBuilder
    participant Calibrator as DynamicThresholdCalibrator
    participant Module as SparseAttentionModule
    participant StatsManager as StatsManager
    User->>ModelOpt: sparsify(model, config with calibration)
    ModelOpt->>Module: apply sparse attention
    ModelOpt->>ModelOpt: extract calibration config
    alt Calibration enabled
        ModelOpt->>DataBuilder: build_calibration_dataset()
        DataBuilder-->>ModelOpt: dataset samples (varying lengths)
        ModelOpt->>Calibrator: calibrate(model, forward_loop, phase="prefill")
        Calibrator->>Module: set_calibration_mode(True)
        Calibrator->>StatsManager: collect(stats)
        StatsManager-->>Calibrator: per-sample sparsity data
        Calibrator->>Calibrator: fit inverse power model (k, p)
        Calibrator-->>ModelOpt: calibration_results (prefill)
        ModelOpt->>Calibrator: calibrate(model, forward_loop, phase="decode")
        Calibrator->>Module: set_calibration_mode(True)
        Calibrator->>StatsManager: collect(stats)
        StatsManager-->>Calibrator: per-sample sparsity data
        Calibrator->>Calibrator: fit inverse power model (k, p)
        Calibrator-->>ModelOpt: calibration_results (decode)
        ModelOpt->>Module: apply calibrated thresholds
        Module->>Module: set threshold_scale_factor (dynamic)
    end
    ModelOpt-->>User: calibrated model with sparse attention
```
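As a rough illustration of the "fit inverse power model (k, p)" step in the diagram, here is a minimal sketch that assumes the calibrator fits a relation of the form y ≈ k * x^(-p) to per-sample statistics via a log-log least-squares fit; the actual quantities and fitting code in the PR may differ:

```python
import numpy as np

def fit_inverse_power(x: np.ndarray, y: np.ndarray) -> tuple[float, float]:
    """Fit y ≈ k * x**(-p) by linear regression in log-log space."""
    slope, intercept = np.polyfit(np.log(x), np.log(y), deg=1)
    return float(np.exp(intercept)), float(-slope)

# Synthetic example: sequence lengths vs. the threshold scale that hit the
# target sparsity for each calibration sample (numbers are made up).
lengths = np.array([1024, 2048, 4096, 8192, 16384], dtype=float)
scales = np.array([2.0e-3, 1.1e-3, 6.0e-4, 3.2e-4, 1.7e-4])
k, p = fit_inverse_power(lengths, scales)
print(f"k={k:.3e}, p={p:.3f}")
```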
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~110 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 18
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
examples/llm_eval/modeling.py (1)
175-194: Add error handling for unsupported attention implementations.
Lines 192–193 and 247–248 pass attn_implementation to from_pretrained(), which has been supported since Transformers v4.36.0 (the project requires ≥4.55.0). However, not all model architectures support all attention backends. If a model doesn't support the requested backend ("eager", "sdpa", or "flash_attention_2"), from_pretrained() may error or silently fall back. Add a try/except with a clear error message indicating which attention implementation is unsupported for the given model.
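A hedged sketch of the suggested guard (the function name and exception types are assumptions, not the PR's actual code):

```python
from transformers import AutoModelForCausalLM

def load_model_or_raise(model_path: str, attn_implementation: str = "eager"):
    """Load a HF model, surfacing a clear error if the attention backend is unsupported."""
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_path, attn_implementation=attn_implementation
        )
    except (ValueError, ImportError) as e:
        raise ValueError(
            f"Attention implementation '{attn_implementation}' is not supported "
            f"for model '{model_path}': {e}"
        ) from e
```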
examples/llm_eval/lm_eval_hf.py (1)
76-106: Early return on pre-quantized models skips sparse attention. If the model is already quantized (lines 76-79), the function returns immediately, bypassing the sparse attention application (lines 99-106). This means users cannot apply sparse attention to pre-quantized models via this path.
If this is intentional, consider documenting this behavior. If sparse attention should still be applicable to pre-quantized models, the early return should be restructured.
🔧 Suggested fix to allow sparse attention on pre-quantized models
```diff
     if is_quantized(model_obj.model):
         # return if model is already quantized
         warnings.warn("Skipping quantization: model is already quantized.")
-        return model_obj
+    elif quant_cfg:
-
-    if quant_cfg:
         if not calib_batch_size:
             calib_batch_size = model_obj.batch_size
         quantize_model(
             model=model_obj,
             quant_cfg=quant_cfg.split(",") if auto_quantize_bits is not None else quant_cfg,
             tokenizer=model_obj.tokenizer,
             batch_size=calib_batch_size,
             calib_size=calib_size,
             auto_quantize_bits=auto_quantize_bits,
             auto_quantize_method=auto_quantize_method,
             auto_quantize_score_size=auto_quantize_score_size,
             test_generated=False,
             compress=compress,
             auto_quantize_checkpoint=auto_quantize_checkpoint,
         )
     if sparse_cfg:
         if is_attn_sparsified(model_obj.model):
             warnings.warn("Skipping sparse attention: model already has sparse attention applied.")
         else:
             sparsify_model(
                 model=model_obj,
                 sparse_cfg=sparse_cfg,
             )
     return model_obj
```
tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py (1)
143-144: Tighten the sparsity assertion to 0 <= stats["sparsity"] <= 1. The formula sparsity = 1.0 - dense_blocks.item() / total_valid_blocks cannot produce negative values, since dense_blocks (a sum of a boolean mask) is always between 0 and total_valid_blocks. The current assertion allowing -1 <= sparsity <= 1 is overly permissive, and the comment about negative sparsity is incorrect. Change to assert 0 <= stats["sparsity"] <= 1.
examples/llm_sparsity/attention_sparsity/hf_sa.py (1)
93-134: Avoid undefined generated_text when CUDA is unavailable. generate_sample_output only sets generated_text inside the CUDA branch; if the helper is reused in a CPU-only context, it will raise UnboundLocalError. Make the GPU requirement explicit (or remove the inner CUDA guard, since main already enforces it).
🛠 Suggested fix
```diff
 def generate_sample_output(model, tokenizer, args):
-    # Load test sample
+    if not torch.cuda.is_available():
+        raise OSError("GPU is required for inference.")
+    # Load test sample
@@
-    if torch.cuda.is_available():
-        inputs = {k: v.cuda() for k, v in inputs.items()}
-
-        # Generate
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=args.max_new_tokens,
-                do_sample=args.do_sample,
-                temperature=args.temperature if args.do_sample else 1.0,
-                pad_token_id=tokenizer.pad_token_id,
-            )
-        input_length = inputs["input_ids"].shape[1]
-        generated_ids = outputs[0][input_length:]
-        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+    inputs = {k: v.cuda() for k, v in inputs.items()}
+    # Generate
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=args.max_new_tokens,
+            do_sample=args.do_sample,
+            temperature=args.temperature if args.do_sample else 1.0,
+            pad_token_id=tokenizer.pad_token_id,
+        )
+    input_length = inputs["input_ids"].shape[1]
+    generated_ids = outputs[0][input_length:]
+    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
```
🤖 Fix all issues with AI agents
In `@examples/llm_sparsity/attention_sparsity/download_ruler_data.sh`:
- Around line 33-45: The loop currently increments count regardless of curl
success; change it so count is incremented only when curl succeeds and failures
are printed—after building raw_url and calling curl in the while loop (the block
using variables url, filename, filepath, raw_url and command curl -fsSL),
capture curl's exit status and only run count=$((count + 1)) when curl exits
successfully, otherwise print a visible failure message including the
url/filepath so failed downloads are surfaced.
In `@examples/llm_sparsity/attention_sparsity/README.md`:
- Around line 9-22: The example in README.md is missing imports required to run
the snippet; add an import for torch and for AutoModelForCausalLM from
transformers at the top so the symbols AutoModelForCausalLM and torch referenced
when loading the model (and later calling mtsa.sparsify with
SKIP_SOFTMAX_DEFAULT) are defined; update the snippet to include these two
imports before importing modelopt.torch.sparsity.attention_sparsity and using
mtsa.sparsify.
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 192-257: The forward_loop in
create_decode_calibration_forward_loop calls torch.cuda.empty_cache()
unconditionally which will fail on CPU-only runs; change it to only call
empty_cache when CUDA is available (e.g., guard with torch.cuda.is_available()
or check device.type == "cuda") after deleting past_key_values so CPU-only
execution doesn't crash while preserving GPU cleanup behavior.
- Around line 135-189: The function create_calibration_forward_loop
(specifically the nested forward_loop) calls torch.cuda.empty_cache() unguarded
in both the chunked prefill branch (after deleting past_key_values) and the full
prefill branch; wrap both calls in a runtime GPU check by only calling
torch.cuda.empty_cache() when torch.cuda.is_available() to avoid AssertionError
on CPU-only PyTorch builds (i.e., add an if torch.cuda.is_available(): guard
around the existing empty_cache() invocations).
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 249-272: Enable calibration should stash the stats-manager's prior
existence and enabled state on the module so we can restore it later; modify
_enable_calibration_mode to record whether module._stats_manager existed and its
previous enabled flag (e.g., store tuple in module._prev_stats_manager_state or
similar) before creating/enabling a SparseAttentionStatsManager and calling
set_calibration_mode(True), and set a marker if you created the manager; then
modify _disable_calibration_mode to consult that stash and restore the original
state: if the manager was newly created by the calibrator, remove it (or set to
None), otherwise restore module._stats_manager.enabled to the previous value and
call set_calibration_mode(False); use the existing symbols
_enable_calibration_mode, _disable_calibration_mode, module._stats_manager,
SparseAttentionStatsManager, and set_calibration_mode to locate and implement
these changes.
In `@modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py`:
- Around line 229-265: The current build_calibration_dataset logic can generate
more than self.total_samples because samples_per_task = max(num_samples //
len(self.subtasks), 1) forces at least one sample per subtask/bin; change the
allocation to cap generation so total produced does not exceed
self.total_samples by computing samples_per_task = num_samples //
len(self.subtasks) and only using the >0 fallback when there is remaining budget
across bins (or by tracking produced_count and breaking when produced_count >=
self.total_samples); update code paths in build_calibration_dataset that use
samples_per_task, the pbar update, and any loop over
self.subtasks/_generate_sample so they stop when the global produced_count
reaches self.total_samples (and ensure pbar.total equals self.total_samples).
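A rough sketch of the capped allocation described above; the attribute and method names follow the prompt, and the real RulerDatasetBuilder internals may differ:

```python
def build_calibration_dataset(self) -> list[dict]:
    """Generate at most self.total_samples samples across self.subtasks."""
    samples: list[dict] = []
    per_task = max(self.total_samples // len(self.subtasks), 1)
    for subtask in self.subtasks:
        for _ in range(per_task):
            if len(samples) >= self.total_samples:  # respect the global budget
                return samples
            samples.append(self._generate_sample(subtask))
    return samples
```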
In `@modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py`:
- Around line 102-117: The hardcoded DATA_DIR (and derived
RULER_URLS_FILE/ESSAYS_DIR) lives under the package tree which can be read-only;
make the data directory user-configurable by changing _get_data_dir to prefer an
environment variable (e.g., RULER_DATA_DIR or MODELOPT_RULER_DATA_DIR) and fall
back to a user-writable cache location (use platform-appropriate user cache/
config dir), then ensure DATA_DIR, RULER_URLS_FILE and ESSAYS_DIR are resolved
via that helper; create the directory with mkdir(parents=True, exist_ok=True)
and keep existing names for the files so callers using DATA_DIR, _get_data_dir,
RULER_URLS_FILE, and ESSAYS_DIR keep working.
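For illustration, a minimal sketch of the environment-variable fallback suggested above (the variable name follows the prompt's suggestion; the cache location and file names are assumptions):

```python
import os
from pathlib import Path

def _get_data_dir() -> Path:
    """Resolve a writable directory for RULER data, preferring an env override."""
    env_dir = os.environ.get("MODELOPT_RULER_DATA_DIR")
    data_dir = Path(env_dir) if env_dir else Path.home() / ".cache" / "modelopt" / "ruler"
    data_dir.mkdir(parents=True, exist_ok=True)
    return data_dir

DATA_DIR = _get_data_dir()
RULER_URLS_FILE = DATA_DIR / "ruler_urls.txt"  # file name assumed
ESSAYS_DIR = DATA_DIR / "essays"
```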
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 251-271: The validator validate_target_sparse_ratio currently
allows values equal to 1.0 which can cause divide-by-zero in downstream dynamic
thresholding; update the check in validate_target_sparse_ratio to require 0.0 <=
ratio < 1.0 (strictly less than 1.0) and modify the raised ValueError message to
state the allowed range is [0.0, 1.0). Ensure the error includes the phase and
invalid ratio and still enforces that target_sparse_ratio is a dict with only
"prefill" and "decode" keys.
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 334-347: The _format_threshold function currently only handles
"dynamic" and "static" types, so thresholds labeled "dynamic_calibrated"
(returned by get_threshold_info()) are rendered as N/A; update _format_threshold
to detect t == "dynamic_calibrated", read per-phase objects from
info.get("phases", {}) (each phase entry contains a scale_factor key), build
parts like f"{phase}={scale_factor:.2f}" (fallback to raw value if missing), and
return the same λ={...} formatted string as for "dynamic" so calibrated
per-phase thresholds are displayed correctly.
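A minimal sketch of the formatting branch described above (the exact payload keys of get_threshold_info() are assumptions based on the prompt):

```python
def _format_threshold(info: dict) -> str:
    """Render threshold info, including calibrated per-phase scale factors."""
    t = info.get("type")
    if t in ("dynamic", "dynamic_calibrated"):
        parts = []
        for phase, entry in info.get("phases", {}).items():
            scale = entry.get("scale_factor") if isinstance(entry, dict) else entry
            parts.append(f"{phase}={scale:.2f}" if isinstance(scale, float) else f"{phase}={scale}")
        return "λ={" + ", ".join(parts) + "}"
    if t == "static":
        return str(info.get("value", "N/A"))
    return "N/A"
```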
In `@tests/_test_utils/torch/sparsity/sparse_attention_common.py`:
- Around line 195-197: The assertion incorrectly passes atol as the third
positional argument to torch.allclose (which is rtol); update the call in the
test where torch.allclose(output_sparse, output_restored, atol) is used so that
the tolerance is passed by name (e.g., torch.allclose(output_sparse,
output_restored, atol=atol)) or otherwise supply both rtol and atol positionally
in the correct order; ensure the variables output_sparse, output_restored and
atol are used so the absolute tolerance is applied as intended.
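A tiny self-contained example of the keyword fix described above (the tensors are synthetic):

```python
import torch

output_sparse = torch.randn(4, 8)
output_restored = output_sparse + 5e-7
atol = 1e-6

# torch.allclose(output_sparse, output_restored, atol) would pass atol as rtol.
assert torch.allclose(output_sparse, output_restored, atol=atol)
```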
In `@tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py`:
- Around line 44-51: The docstring for test_attention_sparsity incorrectly
claims it tests calibration scenarios but the test isn't parameterized; either
add parametrization for calibration or update the docstring—here: remove the
misleading "(with and without calibration)" text from the
test_attention_sparsity docstring and also remove the redundant seq_len=128
argument from the run_attention_sparsity_command call (the helper defaults to
128); locate test_attention_sparsity and run_attention_sparsity_command in the
test file to apply these edits.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py`:
- Around line 18-28: This GPU test must skip when the optional "transformers"
package is missing because RulerDatasetBuilder loads tokenizers by name at
runtime; add a guard using pytest.importorskip("transformers") (or equivalent
try/except ImportError and pytest.skip) near the top of the test file (same area
as the existing CUDA guard) so tests referencing RulerDatasetBuilder or passing
tokenizer names like "gpt2" won't raise ImportError at runtime.
In
`@tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py`:
- Around line 67-79: The tests in test_threshold_validation_range call
SparseAttentionAttributeConfig with scalar thresholds but assert the validator
raises "Threshold must be in range"; since the validator first enforces that
threshold is a dict, update the expectations to match the dict-type validation
error instead of range errors—i.e., in test_threshold_validation_range change
the pytest.raises match from "Threshold must be in range" to the message used by
the dict-type check (e.g., "Threshold must be a dict") for the scalar inputs
when constructing SparseAttentionAttributeConfig.
- Around line 99-102: The test test_threshold_validation_type should assert the
actual validator message from SparseAttentionAttributeConfig rather than a
generic Pydantic message: update the pytest.raises match to the custom validator
text "Threshold must be a dict with 'prefill' and/or 'decode' keys" (or change
the invalid input to a type that triggers the default Pydantic type error) so it
aligns with the validate_threshold validator behavior in
SparseAttentionAttributeConfig.
- Around line 34-46: test_valid_config is passing threshold as a scalar while
the validator validate_threshold on SparseAttentionAttributeConfig expects a
dict with 'prefill' and/or 'decode'; update the test to pass a dict (e.g.
threshold={"prefill": 1e-4, "decode": 1e-4} or at least one of those keys) so
the value satisfies validate_threshold and the assertions remain valid for
SparseAttentionAttributeConfig.
In `@tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py`:
- Around line 18-21: Remove the unnecessary test-level skip by deleting the
pytest.importorskip("transformers") line in the test module; the tests for
SparseAttentionStatsManager do not depend on transformers so simply remove that
import-or-skip call near the top of
tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py and keep the
rest of the file (ensuring SparseAttentionStatsManager tests run on minimal
installs).
In `@tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py`:
- Around line 52-79: The test expects the old threshold-info contract from
FlashSkipSoftmax.get_threshold_info (checking type=="dynamic",
threshold_scale_factor/scale_factors, and formula "λ[phase] / length"), but the
implementation now returns calibration_params and target_sparse_ratio with a
different type/payload; update the tests to assert the new contract: call
FlashSkipSoftmax.get_threshold_info(), assert the returned info uses the new
keys (e.g., info["type"] matches the new value, info["calibration_params"]
contains per-phase calibration data, and info["target_sparse_ratio"] or
equivalent exists), replace assertions that reference
threshold_scale_factor/scale_factors and the old formula with checks against
info["calibration_params"] structure and example thresholds derived from
calibration_params (or alternatively restore legacy fields in
FlashSkipSoftmax.get_threshold_info to emit scale_factors for backward
compatibility). Ensure references to FlashSkipSoftmax.get_threshold_info and the
instance attribute threshold_scale_factor are updated/removed across the other
affected test blocks (lines ~136-170 and ~233-269).
🧹 Nitpick comments (14)
setup.py (1)
60-67: Remove duplicate deepspeed entry in hf extras.
It appears twice with the same marker; keep a single entry to avoid redundancy and reduce maintenance noise.
♻️ Proposed cleanup
```diff
     "hf": [
         "accelerate>=1.0.0",
         "datasets>=3.0.0",
         "deepspeed>=0.9.6 ; platform_system != 'Darwin' and platform_system != 'Windows'",
         "diffusers>=0.32.2",
         "huggingface_hub>=0.24.0",
         "peft>=0.17.0",
         "transformers>=4.53,<5.0",  # Should match modelopt/torch/__init__.py and tox.ini
-        "deepspeed>=0.9.6 ; platform_system != 'Darwin' and platform_system != 'Windows'",
         "nltk",
         "wonderwords",
     ],
```
tests/_test_utils/torch/sparsity/sparse_attention_common.py (1)
164-171: Missing documentation for the atol parameter. The docstring lists Args but doesn't document the new atol parameter.
📝 Suggested docstring update
```diff
 def save_restore_test(model_cls, device, sparse_config, atol=1e-6):
     """Test save and restore of sparse attention state.

     Args:
         model_cls: Model class to test
         device: Device to run on ('cpu' or 'cuda')
         sparse_config: Sparse attention configuration
+        atol: Absolute tolerance for output comparison (default: 1e-6)
     """
```
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1)
253-279: Consider adding a test for decode threshold retrieval. The test validates that get_threshold_info() returns type: "static" and value: 0.005 (the prefill threshold), but it doesn't verify the decode threshold. Consider adding an assertion or a separate test to ensure both phase thresholds are accessible via the threshold info API, especially since the PR introduces per-phase threshold handling.
examples/llm_eval/sparse_attention_utils.py (2)
63-71: Backend override may discard other top-level config keys. When the backend override is applied, line 71 creates a new config dict containing only "sparse_cfg". If the original mtsa_cfg contained other top-level keys (e.g., "export_format" from SparseAttentionConfig), they would be silently discarded.
♻️ Suggested fix to preserve other config keys
```diff
-    mtsa_cfg = {"sparse_cfg": modified_sparse_cfg}
+    mtsa_cfg = {**mtsa_cfg, "sparse_cfg": modified_sparse_cfg}
```
21-28: Consider documenting supported wrapper types. The _extract_model helper checks for gpt2 and model attributes to extract the underlying model. Adding a brief comment listing the known wrapper types (e.g., HFLM, EvalModel) would help maintainability.
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1)
60-76: Docstring example uses outdated code block syntax. The docstring uses .. code-block::python (missing space before python), which may cause rendering issues in some documentation generators.
♻️ Fix code-block directive syntax
```diff
-    .. code-block::python
+    .. code-block:: python
```
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)
57-65: Docstring should match mask value semantics.
Implementations use torch.finfo(...).min, not a literal -inf. Updating the doc avoids confusion.
✏️ Suggested doc tweak
```diff
-    Returns:
-        Masked attention scores with sparse elements set to -inf
+    Returns:
+        Masked attention scores with sparse elements set to dtype minimum
```
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
244-270: Guard export against missing target_sparse_ratio.
If calibration_params exists but target_sparse_ratio is None, the exported config may be incomplete for downstream consumers.
💡 Suggested guard
```diff
-        if calibration_params is not None:
+        if calibration_params is not None and target_sparse_ratio is not None:
             return {
                 "calibration_params": calibration_params,
                 "target_sparse_ratio": target_sparse_ratio,
             }
```
modelopt/torch/sparsity/attention_sparsity/stats_manager.py (2)
71-74: Count unexpected phases as unknown to keep stats consistent.
Currently, unexpected phase values are ignored, which can make phase distributions not sum to total calls.
🧮 Suggested fallback
```diff
-        if phase in self.aggregated_stats["phase_counts"]:
-            self.aggregated_stats["phase_counts"][phase] += 1
+        if phase in self.aggregated_stats["phase_counts"]:
+            self.aggregated_stats["phase_counts"][phase] += 1
+        else:
+            self.aggregated_stats["phase_counts"]["unknown"] += 1
```
145-147: Avoid exposing the mutable internal stats list.
Returning the internal list allows external mutation; a shallow copy is safer.
🛡️ Suggested defensive copy
```diff
-        if phase is None:
-            return self.per_sample_stats
-        return [s for s in self.per_sample_stats if s.get("phase") == phase]
+        if phase is None:
+            return list(self.per_sample_stats)
+        return [s.copy() for s in self.per_sample_stats if s.get("phase") == phase]
```
tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (1)
135-140: A no-op forward_loop means calibration paths aren't exercised.
Several tests define a pass loop, which can make calibration a no-op and allow regressions to slip through. Consider at least one minimal forward pass to collect stats.
💡 Minimal forward loop example
```diff
-    def forward_loop(model):
-        # Simple forward loop for calibration
-        pass
+    def forward_loop(model):
+        model.eval()
+        with torch.no_grad():
+            sample = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=8).cuda()
+            model(sample)
```
modelopt/torch/sparsity/attention_sparsity/sparse_attention.py (1)
204-215: Guard the calibration-mode check for methods that don't define _calibration_mode. Directly reading a private attribute risks AttributeError if another sparse method hasn't been updated yet. A getattr(..., False) keeps this robust without changing behavior.
♻️ Suggested tweak
```diff
-        if not self._sparse_method_instance._calibration_mode:
+        if not getattr(self._sparse_method_instance, "_calibration_mode", False):
             input = self._sparse_method_instance.apply_sparsity(input, sparse_mask)
```
modelopt/torch/sparsity/attention_sparsity/config.py (1)
365-406: Consider matching both "attn" and "attention" module names in presets.
"*attn*" won't match layers that use the full "attention" name. Adding a second pattern avoids silently skipping those models in the preset configs.
♻️ Example expansion (apply similarly to SKIP_SOFTMAX_CALIB)
```diff
 SKIP_SOFTMAX_DEFAULT = {
     "sparse_cfg": {
         "*attn*": {
             "method": "flash_skip_softmax",
             "threshold": {
                 "prefill": 1e-3,  # More aggressive during prefill
                 "decode": 1e-4,  # Conservative during decode
             },
             "br": 128,  # Flash Attention block rows
             "bc": 128,  # Flash Attention block columns
             "backend": "pytorch",  # Only pytorch backend supported
             "collect_stats": True,
             "enable": True,
         },
+        "*attention*": {
+            "method": "flash_skip_softmax",
+            "threshold": {
+                "prefill": 1e-3,
+                "decode": 1e-4,
+            },
+            "br": 128,
+            "bc": 128,
+            "backend": "pytorch",
+            "collect_stats": True,
+            "enable": True,
+        },
         "default": {"enable": False},
     },
 }
```
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)
25-28: Guard the optional transformers import to avoid import-time failures. If transformers is an optional dependency (e.g., HF extras), importing this module will raise on base installs. Consider lazy-importing AutoTokenizer inside the forward-loop creators (or wrapping the import with a clear error). Please verify the dependency expectations for base installs.
🔧 Example approach
```diff
-from transformers import AutoTokenizer
+try:
+    from transformers import AutoTokenizer
+except Exception as exc:  # transformers optional
+    AutoTokenizer = None
+    _TRANSFORMERS_IMPORT_ERROR = exc

 def create_calibration_forward_loop(...):
+    if AutoTokenizer is None:
+        raise ModuleNotFoundError(
+            "transformers is required for calibration; install nvidia-modelopt[hf]."
+        ) from _TRANSFORMERS_IMPORT_ERROR
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

 def create_decode_calibration_forward_loop(...):
+    if AutoTokenizer is None:
+        raise ModuleNotFoundError(
+            "transformers is required for calibration; install nvidia-modelopt[hf]."
+        ) from _TRANSFORMERS_IMPORT_ERROR
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
```
```bash
# Download essays from GitHub URLs
echo -n "Downloading essays"
count=0
while IFS= read -r url || [ -n "$url" ]; do
    if [[ "${url}" == https://github.com*.txt ]]; then
        filename=$(basename "${url}")
        filepath="${ESSAYS_DIR}/${filename}"
        if [ ! -f "${filepath}" ]; then
            raw_url="${url/github.com/raw.githubusercontent.com}"
            raw_url="${raw_url/\/raw\//\/}"
            curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "."
            count=$((count + 1))
        fi
```
Only count successful downloads (and surface failures).
Line 43 increments count even if curl fails, so the final count/progress can be inaccurate and failures remain silent.
🛠️ Proposed fix
- curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "."
- count=$((count + 1))
+ if curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null; then
+ echo -n "."
+ count=$((count + 1))
+ else
+ echo "Failed to download ${raw_url}" >&2
+ fi🤖 Prompt for AI Agents
In `@examples/llm_sparsity/attention_sparsity/download_ruler_data.sh` around lines
33 - 45, The loop currently increments count regardless of curl success; change
it so count is incremented only when curl succeeds and failures are
printed—after building raw_url and calling curl in the while loop (the block
using variables url, filename, filepath, raw_url and command curl -fsSL),
capture curl's exit status and only run count=$((count + 1)) when curl exits
successfully, otherwise print a visible failure message including the
url/filepath so failed downloads are surfaced.
```python
import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

# Load your model
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    attn_implementation="eager",  # Required for sparse attention
    torch_dtype=torch.bfloat16,
)

# Apply sparse attention
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
```
Add missing imports in the quick example.
The code example is missing necessary imports for AutoModelForCausalLM and torch, which would cause the example to fail if copied as-is.
📝 Suggested fix
```python
+import torch
+from transformers import AutoModelForCausalLM
+
 import modelopt.torch.sparsity.attention_sparsity as mtsa
 from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT
```
```python
def create_calibration_forward_loop(
    calibration_data: list[dict[str, Any]],
    tokenizer_name_or_path: str,
    batch_size: int = 1,
    chunk_size: int = 2048,
) -> Callable:
    """Create forward loop for calibration.

    Args:
        calibration_data: List of samples with 'input' and 'length' fields
        tokenizer_name_or_path: HuggingFace tokenizer path
        batch_size: Batch size (currently unused, always 1)
        chunk_size: Chunk size for chunked prefill to avoid OOM. Set to -1 to disable.

    Returns:
        Forward loop function that takes model as argument
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    def forward_loop(model: nn.Module) -> None:
        device = next(model.parameters()).device

        for sample in calibration_data:
            inputs = tokenizer(
                sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"]
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            input_ids = inputs["input_ids"].to(device)
            seq_len = input_ids.shape[1]

            with torch.no_grad():
                if chunk_size > 0 and seq_len > chunk_size:
                    # Chunked prefill to avoid OOM with long sequences
                    past_key_values = None
                    for start_idx in range(0, seq_len, chunk_size):
                        end_idx = min(start_idx + chunk_size, seq_len)
                        chunk_input_ids = input_ids[:, start_idx:end_idx]

                        outputs = model(
                            chunk_input_ids,
                            past_key_values=past_key_values,
                            use_cache=True,
                        )
                        past_key_values = outputs.past_key_values

                    # Clean up KV cache
                    del past_key_values
                    torch.cuda.empty_cache()
                else:
                    # Full prefill without chunking
                    model(input_ids, use_cache=False)

    return forward_loop
```
🧩 Analysis chain
🌐 Web query:
torch.cuda.empty_cache() behavior without CUDA available
💡 Result:
torch.cuda.empty_cache() is only meaningful once PyTorch’s CUDA runtime + caching allocator are initialized; otherwise it won’t be able to “free GPU cache”.
What happens when “CUDA isn’t available” depends on why it isn’t available:
- CPU-only PyTorch build (not compiled with CUDA)
- Calling torch.cuda.empty_cache() will trigger CUDA lazy-init and typically raises AssertionError: Torch not compiled with CUDA enabled (coming from torch.cuda._lazy_init). [3]
- CUDA build, but CUDA can’t be initialized at runtime (e.g., no driver / driver error / WSL/NVML issues)
torch.cuda.is_available() may be False, and attempting CUDA init (which empty_cache() effectively forces) can raise a RuntimeError from cudaGetDeviceCount() / initialization. [4]
- CUDA build + working CUDA, but you just haven’t used CUDA yet
torch.cuda.empty_cache() may still initialize CUDA (creating a context) and then release unoccupied cached blocks held by PyTorch’s caching allocator. It does not free memory held by live tensors. [1][2]
Practical pattern:
```python
import torch

if torch.cuda.is_available() and torch.cuda.is_initialized():
    torch.cuda.empty_cache()
```
(Use is_available() to avoid CPU-only / broken-CUDA cases, and is_initialized() if you specifically want to avoid initializing CUDA just to call empty_cache().) [1][2]
🏁 Script executed:
# Check if there are other CUDA-related calls in the file
rg -A 2 -B 2 "torch.cuda" modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.pyRepository: NVIDIA/Model-Optimizer
Length of output: 384
🏁 Script executed:
# Search for CPU-related checks or CUDA availability guards in the codebase
rg "torch.cuda.is_available|cuda.is_available" --type py -A 2 | head -50Repository: NVIDIA/Model-Optimizer
Length of output: 4501
🏁 Script executed:
# Check if this calibration module has any documentation about CPU support
rg -B 5 -A 10 "def create_calibration_forward_loop" modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.pyRepository: NVIDIA/Model-Optimizer
Length of output: 526
Guard CUDA cache clearing on CPU-only builds.
torch.cuda.empty_cache() raises AssertionError on CPU-only PyTorch builds. There are two unguarded calls in this function (after chunked and full prefill paths); both should be guarded with if torch.cuda.is_available():. This pattern is already established elsewhere in the codebase.
🛠️ Suggested change
- torch.cuda.empty_cache()
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()

Apply the same guard to the second instance after the full prefill path as well.
```python
def create_decode_calibration_forward_loop(
    calibration_data: list[dict[str, Any]],
    tokenizer_name_or_path: str,
    num_decode_tokens: int = 10,
) -> Callable:
    """Create forward loop for decode phase calibration.

    Uses flash attention for fast prefill, then switches to eager attention
    for decode token generation with softmax hook measurement.

    Args:
        calibration_data: List of samples with 'input' and 'length' fields
        tokenizer_name_or_path: HuggingFace tokenizer path
        num_decode_tokens: Number of decode tokens to generate per sample

    Returns:
        Forward loop function that takes model as argument
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    def forward_loop(model: nn.Module) -> None:
        device = next(model.parameters()).device

        for sample in calibration_data:
            inputs = tokenizer(
                sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"]
            )
            input_ids = inputs["input_ids"].to(device)

            # Save original attention implementation
            original_attn_impl = getattr(model.config, "_attn_implementation", "eager")

            with torch.no_grad():
                try:
                    # Step 1: Fast prefill with flash attention (no measurement)
                    model.config._attn_implementation = "flash_attention_2"
                    outputs = model(input_ids, use_cache=True)
                    past_key_values = outputs.past_key_values

                    # Step 2: Switch to eager for decode (enables softmax hook)
                    model.config._attn_implementation = "eager"

                    # Step 3: Manual decode loop for explicit control over token generation
                    # model.generate() method is not used here because it doesn't allow explicit control over KV cache
                    # Get the last token's logits and sample next token
                    next_token = outputs.logits[:, -1:, :].argmax(dim=-1)

                    for _ in range(num_decode_tokens):
                        outputs = model(
                            next_token,
                            past_key_values=past_key_values,
                            use_cache=True,
                        )
                        past_key_values = outputs.past_key_values
                        next_token = outputs.logits[:, -1:, :].argmax(dim=-1)
                finally:
                    # Restore original attention implementation
                    model.config._attn_implementation = original_attn_impl

            # Clean up
            del past_key_values
            torch.cuda.empty_cache()

    return forward_loop
```
🧩 Analysis chain
🏁 Script executed:
cat -n modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py | sed -n '192,257p'Repository: NVIDIA/Model-Optimizer
Length of output: 3361
🏁 Script executed:
rg "torch\.cuda" modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.pyRepository: NVIDIA/Model-Optimizer
Length of output: 147
🏁 Script executed:
rg "torch\.cuda\.is_available" --type py | head -20Repository: NVIDIA/Model-Optimizer
Length of output: 2206
🏁 Script executed:
rg "empty_cache" --type py | head -20Repository: NVIDIA/Model-Optimizer
Length of output: 1942
🏁 Script executed:
# Check the context around the first empty_cache call in calibrate.py
sed -n '150,180p' modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.pyRepository: NVIDIA/Model-Optimizer
Length of output: 1436
🏁 Script executed:
# Check if there are guards around empty_cache elsewhere in the same file
rg -B3 -A3 "torch\.cuda\.empty_cache\(\)" modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.pyRepository: NVIDIA/Model-Optimizer
Length of output: 441
🏁 Script executed:
# Check how empty_cache is used in similar calibration contexts
rg -B5 -A5 "torch\.cuda\.empty_cache\(\)" modelopt/torch/utils/dataset_utils.pyRepository: NVIDIA/Model-Optimizer
Length of output: 489
🏁 Script executed:
# Check the module's scope and whether it requires CUDA
head -50 modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.pyRepository: NVIDIA/Model-Optimizer
Length of output: 1945
Guard torch.cuda.empty_cache() to support CPU-only execution.
The torch.cuda.empty_cache() call at line 255 will crash on CPU-only builds since the model device can be CPU. This must be guarded with torch.cuda.is_available() to ensure the decode calibration function works across all deployment scenarios.
Suggested change
del past_key_values
- torch.cuda.empty_cache()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
```python
    def _enable_calibration_mode(self, modules: list[nn.Module]):
        """Enable calibration mode on sparse attention modules."""
        for idx, module in enumerate(modules):
            # Create stats manager if needed
            if not module._stats_manager:
                module._stats_manager = SparseAttentionStatsManager(
                    module_name=f"sparse_attn_{idx}", enabled=True
                )
            else:
                # Re-enable if disabled
                module._stats_manager.enabled = True

            # Enable calibration mode with fresh stats
            module._stats_manager.set_calibration_mode(enabled=True, reset_history=True)
            module._sparse_method_instance.set_calibration_mode(True)

    def _disable_calibration_mode(self, modules: list[nn.Module]):
        """Disable calibration mode (but keep stats enabled if collect_stats=True)."""
        for module in modules:
            if module._stats_manager:
                module._stats_manager.set_calibration_mode(enabled=False)

            module._sparse_method_instance.set_calibration_mode(False)
```
Restore stats-manager state after calibration to avoid permanent overhead.
_enable_calibration_mode creates/enables a stats manager even when collect_stats was off, but _disable_calibration_mode never restores the previous state. That leaves stats collection enabled after calibration, changing runtime behavior and adding overhead. Consider stashing and restoring the prior stats-manager state.
🛠 Suggested fix
def _enable_calibration_mode(self, modules: list[nn.Module]):
"""Enable calibration mode on sparse attention modules."""
for idx, module in enumerate(modules):
+ # Preserve previous stats-manager state for restoration
+ if not hasattr(module, "_calib_prev_stats_state"):
+ module._calib_prev_stats_state = (
+ module._stats_manager,
+ module._stats_manager.enabled if module._stats_manager else None,
+ )
# Create stats manager if needed
if not module._stats_manager:
module._stats_manager = SparseAttentionStatsManager(
module_name=f"sparse_attn_{idx}", enabled=True
)
@@
def _disable_calibration_mode(self, modules: list[nn.Module]):
"""Disable calibration mode (but keep stats enabled if collect_stats=True)."""
for module in modules:
if module._stats_manager:
module._stats_manager.set_calibration_mode(enabled=False)
module._sparse_method_instance.set_calibration_mode(False)
+
+ prev_mgr, prev_enabled = getattr(module, "_calib_prev_stats_state", (None, None))
+ if prev_mgr is None:
+ module._stats_manager = None
+ else:
+ module._stats_manager = prev_mgr
+ if prev_enabled is not None:
+ module._stats_manager.enabled = prev_enabled
+ if hasattr(module, "_calib_prev_stats_state"):
+ delattr(module, "_calib_prev_stats_state")🤖 Prompt for AI Agents
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 249 - 272, Enable calibration should stash the stats-manager's prior
existence and enabled state on the module so we can restore it later; modify
_enable_calibration_mode to record whether module._stats_manager existed and its
previous enabled flag (e.g., store tuple in module._prev_stats_manager_state or
similar) before creating/enabling a SparseAttentionStatsManager and calling
set_calibration_mode(True), and set a marker if you created the manager; then
modify _disable_calibration_mode to consult that stash and restore the original
state: if the manager was newly created by the calibrator, remove it (or set to
None), otherwise restore module._stats_manager.enabled to the previous value and
call set_calibration_mode(False); use the existing symbols
_enable_calibration_mode, _disable_calibration_mode, module._stats_manager,
SparseAttentionStatsManager, and set_calibration_mode to locate and implement
these changes.
```python
    def test_valid_config(self):
        """Test creating valid config."""
        config = SparseAttentionAttributeConfig(
            method="flash_skip_softmax",
            threshold=1e-4,
            br=128,
            bc=128,
            enable=True,
        )
        assert config.method == "flash_skip_softmax"
        assert config.threshold == 1e-4
        assert config.br == 128
        assert config.bc == 128
```
🧩 Analysis chain
🏁 Script executed:
# First, locate the config.py file and examine the threshold validator
fd config.py -x cat -n {} | head -150Repository: NVIDIA/Model-Optimizer
Length of output: 7093
🏁 Script executed:
# Search for the validate_threshold validator
rg -n "validate_threshold|threshold" . -A 3 -B 1 --type py | grep -E "(config\.py|sparse)" | head -80Repository: NVIDIA/Model-Optimizer
Length of output: 10117
🏁 Script executed:
# Look at the specific lines in config.py mentioned in the review
find . -name "config.py" -path "*/sparse*" -o -name "config.py" -path "*/attention*" | head -5Repository: NVIDIA/Model-Optimizer
Length of output: 120
🏁 Script executed:
# Examine the config.py file to see the threshold field definition and validator
cat ./modelopt/torch/sparsity/attention_sparsity/config.pyRepository: NVIDIA/Model-Optimizer
Length of output: 15398
Test passes scalar threshold but validator requires dict format.
The test_valid_config test passes threshold=1e-4 (a scalar float), but the validate_threshold validator in config.py (lines 121-142) explicitly requires a dict with 'prefill' and/or 'decode' keys. The validator checks isinstance(v, dict) and raises ValueError if the input is not a dict. This test will fail with a ValidationError at runtime.
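A sketch of how the updated test might look, assuming SparseAttentionAttributeConfig is imported as in the existing module and the validator accepts per-phase dicts:

```python
    def test_valid_config(self):
        """Test creating a valid config with per-phase thresholds."""
        config = SparseAttentionAttributeConfig(
            method="flash_skip_softmax",
            threshold={"prefill": 1e-4, "decode": 1e-4},  # dict form required by the validator
            br=128,
            bc=128,
            enable=True,
        )
        assert config.threshold == {"prefill": 1e-4, "decode": 1e-4}
```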
```python
    def test_threshold_validation_range(self):
        """Test threshold must be in range (0, 1)."""
        with pytest.raises(ValidationError, match="Threshold must be in range"):
            SparseAttentionAttributeConfig(threshold=0)

        with pytest.raises(ValidationError, match="Threshold must be in range"):
            SparseAttentionAttributeConfig(threshold=-0.1)

        with pytest.raises(ValidationError, match="Threshold must be in range"):
            SparseAttentionAttributeConfig(threshold=1.0)

        with pytest.raises(ValidationError, match="Threshold must be in range"):
            SparseAttentionAttributeConfig(threshold=1.5)
```
Threshold validation tests pass scalar values but validator expects dict.
Similar to the earlier issue, these tests pass scalar threshold values (0, -0.1, 1.0, 1.5) expecting range validation errors, but the validator should first fail with "Threshold must be a dict" before checking the range. The test match strings expect "Threshold must be in range" which won't be reached if the dict type check fails first.
```python
    def test_threshold_validation_type(self):
        """Test threshold type validation."""
        with pytest.raises(ValidationError, match="Input should be a valid"):
            SparseAttentionAttributeConfig(threshold="invalid")
```
Test expects wrong error message for string threshold.
The test passes threshold="invalid" and expects a match on "Input should be a valid". However, based on the custom validate_threshold validator, the error message should be "Threshold must be a dict with 'prefill' and/or 'decode' keys". The current match pattern likely relies on Pydantic's default type validation running before the custom validator, which may not be the case.
```python
import pytest

pytest.importorskip("transformers")
```
❓ Verification inconclusive
Script executed:
cat -n tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.pyRepository: NVIDIA/Model-Optimizer
Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
1 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 # SPDX-License-Identifier: Apache-2.0
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 """Unit tests for SparseAttentionStatsManager."""
17
18 import pytest
19
20 pytest.importorskip("transformers")
21
22 from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager
23
24
25 class TestStatsManagerInitialization:
26 """Test stats manager initialization."""
27
28 def test_initialization_defaults(self):
29 """Test default initialization."""
30 manager = SparseAttentionStatsManager(module_name="test_module")
31
32 assert manager.module_name == "test_module"
33 assert manager.enabled is True
34 assert manager.calibration_mode is False
35 assert manager.aggregated_stats["total_calls"] == 0
36 assert manager.aggregated_stats["total_blocks"] == 0
37 assert manager.aggregated_stats["sparse_blocks"] == 0
38 assert manager.per_sample_stats == []
39
40 def test_initialization_disabled(self):
41 """Test initialization with disabled stats."""
42 manager = SparseAttentionStatsManager(module_name="test_module", enabled=False)
43
44 assert manager.enabled is False
45 assert manager.calibration_mode is False
46
47 def test_initialization_custom_name(self):
48 """Test initialization with custom module name."""
49 manager = SparseAttentionStatsManager(module_name="custom.attention.module")
50
51 assert manager.module_name == "custom.attention.module"
52
53
54 class TestStatsCollection:
55 """Test statistics collection functionality."""
56
57 def test_collect_stats_enabled(self):
58 """Test collecting stats when enabled."""
59 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
60
61 stats = {
62 "sparsity": 0.5,
63 "phase": "prefill",
64 "total_blocks": 100,
65 "sparse_blocks": 50,
66 "sample_length": 1024,
67 }
68
69 manager.collect(stats)
70
71 assert manager.aggregated_stats["total_calls"] == 1
72 assert manager.aggregated_stats["total_blocks"] == 100
73 assert manager.aggregated_stats["sparse_blocks"] == 50
74 assert manager.aggregated_stats["phase_counts"]["prefill"] == 1
75 assert manager.aggregated_stats["phase_counts"]["decode"] == 0
76
77 def test_collect_stats_disabled(self):
78 """Test that collect() is no-op when disabled."""
79 manager = SparseAttentionStatsManager(module_name="test", enabled=False)
80
81 stats = {
82 "sparsity": 0.5,
83 "phase": "prefill",
84 "total_blocks": 100,
85 "sparse_blocks": 50,
86 }
87
88 manager.collect(stats)
89
90 # Should remain at initial values
91 assert manager.aggregated_stats["total_calls"] == 0
92 assert manager.aggregated_stats["total_blocks"] == 0
93 assert manager.aggregated_stats["sparse_blocks"] == 0
94
95 def test_collect_multiple_calls(self):
96 """Test accumulation over multiple collect calls."""
97 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
98
99 # Collect multiple times
100 for i in range(5):
101 stats = {
102 "sparsity": 0.5,
103 "phase": "prefill",
104 "total_blocks": 100,
105 "sparse_blocks": 50,
106 }
107 manager.collect(stats)
108
109 assert manager.aggregated_stats["total_calls"] == 5
110 assert manager.aggregated_stats["total_blocks"] == 500
111 assert manager.aggregated_stats["sparse_blocks"] == 250
112 assert manager.aggregated_stats["phase_counts"]["prefill"] == 5
113
114 def test_collect_different_phases(self):
115 """Test phase counting."""
116 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
117
118 # Collect prefill stats
119 manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50})
120 manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50})
121
122 # Collect decode stats
123 manager.collect({"phase": "decode", "total_blocks": 10, "sparse_blocks": 5})
124
125 assert manager.aggregated_stats["phase_counts"]["prefill"] == 2
126 assert manager.aggregated_stats["phase_counts"]["decode"] == 1
127 assert manager.aggregated_stats["phase_counts"]["unknown"] == 0
128
129
130 class TestCalibrationMode:
131 """Test calibration mode functionality."""
132
133 def test_calibration_mode_per_sample_collection(self):
134 """Test that calibration mode stores per-sample stats."""
135 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
136
137 # Enable calibration mode
138 manager.set_calibration_mode(enabled=True)
139
140 stats = {
141 "sparsity": 0.5,
142 "phase": "prefill",
143 "total_blocks": 100,
144 "sparse_blocks": 50,
145 "sample_length": 1024,
146 }
147
148 manager.collect(stats)
149
150 # Should store in per_sample_stats
151 assert len(manager.per_sample_stats) == 1
152 assert manager.per_sample_stats[0]["module"] == "test"
153 assert manager.per_sample_stats[0]["sparsity"] == 0.5
154 assert manager.per_sample_stats[0]["sample_length"] == 1024
155 assert manager.per_sample_stats[0]["phase"] == "prefill"
156
157 def test_calibration_mode_off(self):
158 """Test that per-sample stats are not collected when calibration mode is off."""
159 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
160 # Calibration mode is off by default
161
162 stats = {"sparsity": 0.5, "phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}
163
164 manager.collect(stats)
165
166 # Should NOT store in per_sample_stats
167 assert len(manager.per_sample_stats) == 0
168
169 # But should still aggregate
170 assert manager.aggregated_stats["total_calls"] == 1
171
172 def test_set_calibration_mode_with_reset(self):
173 """Test set_calibration_mode with reset_history=True."""
174 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
175
176 # Collect some stats in calibration mode
177 manager.set_calibration_mode(enabled=True)
178 manager.collect(
179 {
180 "sparsity": 0.5,
181 "phase": "prefill",
182 "total_blocks": 100,
183 "sparse_blocks": 50,
184 "sample_length": 1024,
185 }
186 )
187 assert len(manager.per_sample_stats) == 1
188
189 # Re-enable with reset
190 manager.set_calibration_mode(enabled=True, reset_history=True)
191 assert len(manager.per_sample_stats) == 0 # Should be cleared
192
193 def test_set_calibration_mode_without_reset(self):
194 """Test set_calibration_mode with reset_history=False."""
195 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
196
197 # Collect some stats
198 manager.set_calibration_mode(enabled=True)
199 manager.collect(
200 {
201 "sparsity": 0.5,
202 "phase": "prefill",
203 "total_blocks": 100,
204 "sparse_blocks": 50,
205 "sample_length": 1024,
206 }
207 )
208 assert len(manager.per_sample_stats) == 1
209
210 # Disable without reset
211 manager.set_calibration_mode(enabled=False, reset_history=False)
212 assert len(manager.per_sample_stats) == 1 # Should be preserved
213
214
215 class TestGetSummary:
216 """Test get_summary() functionality."""
217
218 def test_get_summary_with_data(self):
219 """Test get_summary returns correct averages."""
220 manager = SparseAttentionStatsManager(module_name="test_module", enabled=True)
221
222 # Collect stats
223 manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 30})
224 manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50})
225
226 summary = manager.get_summary()
227
228 assert summary["module"] == "test_module"
229 assert summary["total_calls"] == 2
230 # Average sparsity: (30+50) / (100+100) = 80/200 = 0.4
231 assert summary["average_sparsity"] == 0.4
232 assert summary["phase_distribution"]["prefill"] == 2
233
234 def test_get_summary_no_data(self):
235 """Test get_summary with no collected data."""
236 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
237
238 summary = manager.get_summary()
239
240 assert summary["module"] == "test"
241 assert summary["total_calls"] == 0
242 assert summary["average_sparsity"] == 0.0
243 assert summary["phase_distribution"]["prefill"] == 0
244
245 def test_get_summary_zero_blocks(self):
246 """Test get_summary when total_blocks is zero."""
247 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
248
249 # Collect stats with zero blocks
250 manager.collect({"phase": "prefill", "total_blocks": 0, "sparse_blocks": 0})
251
252 summary = manager.get_summary()
253
254 assert summary["average_sparsity"] == 0.0 # Should handle division by zero
255
256
257 class TestGetCalibrationStats:
258 """Test get_calibration_stats() functionality."""
259
260 def test_get_calibration_stats(self):
261 """Test retrieving per-sample calibration stats."""
262 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
263 manager.set_calibration_mode(enabled=True)
264
265 # Collect multiple samples
266 for i in range(3):
267 manager.collect(
268 {
269 "sparsity": 0.3 + i * 0.1,
270 "phase": "prefill",
271 "total_blocks": 100,
272 "sparse_blocks": 30,
273 "sample_length": 1024 + i * 512,
274 }
275 )
276
277 calib_stats = manager.get_calibration_stats()
278
279 assert len(calib_stats) == 3
280 assert calib_stats[0]["sparsity"] == 0.3
281 assert calib_stats[1]["sparsity"] == 0.4
282 assert calib_stats[2]["sparsity"] == 0.5
283
284 def test_get_calibration_stats_empty(self):
285 """Test get_calibration_stats when no calibration data."""
286 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
287
288 calib_stats = manager.get_calibration_stats()
289
290 assert calib_stats == []
291
292
293 class TestReset:
294 """Test reset functionality."""
295
296 def test_reset(self):
297 """Test reset() clears all statistics."""
298 manager = SparseAttentionStatsManager(module_name="test", enabled=True)
299 manager.set_calibration_mode(enabled=True)
300
301 # Collect some stats
302 manager.collect(
303 {
304 "sparsity": 0.5,
305 "phase": "prefill",
306 "total_blocks": 100,
307 "sparse_blocks": 50,
308 "sample_length": 1024,
309 }
310 )
311 manager.collect(
312 {
313 "sparsity": 0.3,
314 "phase": "decode",
315 "total_blocks": 10,
316 "sparse_blocks": 3,
317 "sample_length": 128,
318 }
319 )
320
321 # Verify stats exist
322 assert manager.aggregated_stats["total_calls"] == 2
323 assert len(manager.per_sample_stats) == 2
324
325 # Reset
326 manager.reset()
327
328 # All stats should be cleared
329 assert manager.aggregated_stats["total_calls"] == 0
330 assert manager.aggregated_stats["total_blocks"] == 0
331 assert manager.aggregated_stats["sparse_blocks"] == 0
332 assert manager.per_sample_stats == []
333 assert manager.aggregated_stats["phase_counts"]["prefill"] == 0
334 assert manager.aggregated_stats["phase_counts"]["decode"] == 0
Script executed:
cat -n modelopt/torch/sparsity/attention_sparsity/stats_manager.py | head -100

Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
1 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 # SPDX-License-Identifier: Apache-2.0
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 """Statistics manager for sparse attention modules."""
17
18
19 class SparseAttentionStatsManager:
20 """Centralized statistics manager for sparse attention.
21
22 This class is the single source of truth for all statistics collection
23 in sparse attention modules. It handles both runtime aggregation and
24 per-sample calibration statistics.
25
26 Design principles:
27 - Single responsibility: only stats management
28 - No computation: receives pre-computed stats from methods
29 - Optional: can be None if stats collection disabled
30 - Zero overhead when disabled
31 """
32
33 def __init__(self, module_name: str, enabled: bool = True):
34 """Initialize stats manager.
35
36 Args:
37 module_name: Name of the module this manager is attached to
38 enabled: Whether stats collection is enabled
39 """
40 self.module_name = module_name
41 self.enabled = enabled
42 self.calibration_mode = False
43
44 # Aggregated stats (running totals across all forward passes)
45 self.aggregated_stats: dict = {
46 "total_calls": 0,
47 "total_blocks": 0,
48 "sparse_blocks": 0,
49 "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0},
50 }
51
52 # Per-sample stats (only populated during calibration)
53 self.per_sample_stats: list[dict] = []
54
55 def collect(self, stats: dict):
56 """Collect statistics from a single forward pass.
57
58 Args:
59 stats: Dictionary containing statistics from method computation.
60 Expected keys: sparsity, phase, total_blocks, sparse_blocks,
61 sample_length (optional)
62 """
63 if not self.enabled:
64 return
65
66 # Update aggregated stats
67 self.aggregated_stats["total_calls"] += 1
68 self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0)
69 self.aggregated_stats["sparse_blocks"] += stats.get("sparse_blocks", 0)
70
71 phase = stats.get("phase", "unknown")
72 if phase in self.aggregated_stats["phase_counts"]:
73 self.aggregated_stats["phase_counts"][phase] += 1
74
75 # In calibration mode, store per-sample stats
76 if self.calibration_mode:
77 self.per_sample_stats.append(
78 {
79 "module": self.module_name,
80 "sparsity": stats.get("sparsity", 0.0),
81 "sample_length": stats.get("sample_length", 0),
82 "phase": phase,
83 }
84 )
85
86 def get_summary(self) -> dict:
87 """Get aggregated statistics summary.
88
89 Returns:
90 Dictionary with module name, total calls, average sparsity,
91 and phase distribution.
92 """
93 total_blocks = self.aggregated_stats["total_blocks"]
94 if total_blocks > 0:
95 avg_sparsity = self.aggregated_stats["sparse_blocks"] / total_blocks
96 else:
97 avg_sparsity = 0.0
98
99 return {
100 "module": self.module_name,
Script executed:
cat -n modelopt/torch/sparsity/attention_sparsity/stats_manager.py | tail -50

Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
98
99 return {
100 "module": self.module_name,
101 "total_calls": self.aggregated_stats["total_calls"],
102 "average_sparsity": avg_sparsity,
103 "phase_distribution": self.aggregated_stats["phase_counts"].copy(),
104 }
105
106 def set_calibration_mode(self, enabled: bool, reset_history: bool = True):
107 """Enable or disable calibration mode.
108
109 In calibration mode, per-sample statistics are stored for detailed
110 analysis. Otherwise, only aggregated stats are maintained.
111
112 Args:
113 enabled: Whether to enable calibration mode
114 reset_history: Whether to clear per_sample_stats when enabling
115 """
116 self.calibration_mode = enabled
117 if enabled and reset_history:
118 self.per_sample_stats = []
119
120 def reset(self):
121 """Reset all statistics to initial state."""
122 self.aggregated_stats = {
123 "total_calls": 0,
124 "total_blocks": 0,
125 "sparse_blocks": 0,
126 "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0},
127 }
128 self.per_sample_stats = []
129
130 def get_calibration_stats(self, phase: str | None = None) -> list[dict]:
131 """Get per-sample calibration statistics, optionally filtered by phase.
132
133 Note: Returns historical stats collected while calibration_mode was enabled.
134 Stats remain accessible even after calibration_mode is disabled.
135 New stats are only collected when calibration_mode is True.
136
137 Args:
138 phase: Optional phase to filter by ('prefill' or 'decode').
139 If None, returns all stats.
140
141 Returns:
142 List of per-sample statistics dictionaries.
143 Empty list if no stats were collected or no stats match the phase.
144 """
145 if phase is None:
146 return self.per_sample_stats
147 return [s for s in self.per_sample_stats if s.get("phase") == phase]
Script executed:
# Check if transformers is used anywhere in this test file
rg "transformers" tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.pyRepository: NVIDIA/Model-Optimizer
Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
pytest.importorskip("transformers")
Script executed:
# Check if transformers is imported or used in the stats_manager.py itself
rg "transformers" modelopt/torch/sparsity/attention_sparsity/Repository: NVIDIA/Model-Optimizer
Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py: from transformers import AutoModelForCausalLM
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py:import transformers
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py: if isinstance(model, transformers.PreTrainedModel):
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py: if not isinstance(model, transformers.PreTrainedModel):
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py:from transformers import AutoTokenizer
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py: from transformers import AutoTokenizer
Script executed:
# Check if there are other test files in the same directory and their importorskip patterns
find tests/unit/torch/sparsity/attention_sparsity/ -name "*.py" -type f | xargs -I {} sh -c 'echo "=== {} ===" && head -25 {}'

Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
=== tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for sparse attention mode registry."""
import pytest
pytest.importorskip("transformers")
from modelopt.torch.opt.mode import _ModeRegistryCls
from modelopt.torch.sparsity.attention_sparsity.mode import SparseAttentionModeRegistry
=== tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for FlashSkipSoftmax method internals."""
import pytest
import torch
pytest.importorskip("transformers")
from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax
=== tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for sparse attention calibration."""
import pytest
pytest.importorskip("transformers")
import numpy as np
from _test_utils.torch.sparsity.sparse_attention_common import (
SimpleAttentionModel,
SimpleTransformerEncoder,
=== tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test sparse attention configuration validation."""
import pytest
from pydantic import ValidationError
pytest.importorskip("transformers")
from modelopt.torch.sparsity.attention_sparsity.config import (
SKIP_SOFTMAX_DEFAULT,
FlashSkipSoftmaxConfig,
=== tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for sparse attention conversion and replacement."""
import pytest
pytest.importorskip("transformers")
import torch.nn as nn
from _test_utils.torch.sparsity.sparse_attention_common import (
FLASH_SKIP_SOFTMAX_DEFAULT_CFG,
SimpleAttentionModel,
=== tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for SparseAttentionStatsManager."""
import pytest
pytest.importorskip("transformers")
from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager
class TestStatsManagerInitialization:
=== tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for threshold calibration functionality."""
import pytest
pytest.importorskip("transformers")
from _test_utils.torch.sparsity.sparse_attention_common import SimpleAttentionModel
from modelopt.torch.sparsity.attention_sparsity import sparsify
from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax
Remove unnecessary transformers import skip.
The pytest.importorskip("transformers") is unnecessary—SparseAttentionStatsManager and its tests have no dependency on transformers. This skip will prevent test execution on minimal installs without providing any benefit. Remove it.
Suggested change
import pytest
-pytest.importorskip("transformers")
from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
import pytest

from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager
🤖 Prompt for AI Agents
In `@tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py` around
lines 18 - 21, Remove the unnecessary test-level skip by deleting the
pytest.importorskip("transformers") line in the test module; the tests for
SparseAttentionStatsManager do not depend on transformers so simply remove that
import-or-skip call near the top of
tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py and keep the
rest of the file (ensuring SparseAttentionStatsManager tests run on minimal
installs).
def test_dynamic_calibrated_threshold(self):
    """Test threshold info for calibrated dynamic threshold."""
    method = FlashSkipSoftmax(
        method_config={
            "threshold": {"prefill": 0.001, "decode": 0.0001},
            "br": 128,
            "bc": 128,
            "backend": "pytorch",
            "is_causal": True,
        }
    )

    # Simulate calibration setting per-phase scale factors
    method.threshold_scale_factor = {"prefill": 437.5, "decode": 500.0}

    info = method.get_threshold_info()

    assert info["type"] == "dynamic"
    assert info["scale_factors"] == {"prefill": 437.5, "decode": 500.0}
    assert info["formula"] == "λ[phase] / length"
    assert "phases" in info
    assert "prefill" in info["phases"]
    assert "decode" in info["phases"]
    # Check example thresholds for prefill
    prefill_examples = info["phases"]["prefill"]["example_thresholds"]
    assert abs(prefill_examples[1024] - 437.5 / 1024) < 1e-6
    assert abs(prefill_examples[2048] - 437.5 / 2048) < 1e-6
Dynamic threshold tests target an outdated threshold-info contract.
FlashSkipSoftmax.get_threshold_info() now reports calibration via calibration_params/target_sparse_ratio and a different type payload. These tests still set threshold_scale_factor and assert type == "dynamic" with scale_factors and "λ=" formatting, so they’ll fail unless the contract is aligned. Please update the tests to match the new payload (or restore legacy fields in the method).
🛠 Example update for dynamic-threshold expectations
- # Simulate calibration setting per-phase scale factors
- method.threshold_scale_factor = {"prefill": 437.5, "decode": 500.0}
+ # Simulate calibration params + target sparsity
+ method.calibration_params = {
+ "prefill": {"k": 437.5, "p": 1.0},
+ "decode": {"k": 500.0, "p": 1.0},
+ }
+ method.target_sparse_ratio = {"prefill": 0.0, "decode": 0.0}
- assert info["type"] == "dynamic"
- assert info["scale_factors"] == {"prefill": 437.5, "decode": 500.0}
- assert info["formula"] == "λ[phase] / length"
+ assert info["type"] == "dynamic_calibrated"
+ assert info["phases"]["prefill"]["scale_factor"] == pytest.approx(437.5)
+ assert info["phases"]["decode"]["scale_factor"] == pytest.approx(500.0)
+ assert "formula" in infoAlso applies to: 136-170, 233-269
🤖 Prompt for AI Agents
In `@tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py` around
lines 52 - 79, The test expects the old threshold-info contract from
FlashSkipSoftmax.get_threshold_info (checking type=="dynamic",
threshold_scale_factor/scale_factors, and formula "λ[phase] / length"), but the
implementation now returns calibration_params and target_sparse_ratio with a
different type/payload; update the tests to assert the new contract: call
FlashSkipSoftmax.get_threshold_info(), assert the returned info uses the new
keys (e.g., info["type"] matches the new value, info["calibration_params"]
contains per-phase calibration data, and info["target_sparse_ratio"] or
equivalent exists), replace assertions that reference
threshold_scale_factor/scale_factors and the old formula with checks against
info["calibration_params"] structure and example thresholds derived from
calibration_params (or alternatively restore legacy fields in
FlashSkipSoftmax.get_threshold_info to emit scale_factors for backward
compatibility). Ensure references to FlashSkipSoftmax.get_threshold_info and the
instance attribute threshold_scale_factor are updated/removed across the other
affected test blocks (lines ~136-170 and ~233-269).
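Putting the example diff above together, the updated test could look like the following sketch. The "dynamic_calibrated" type string, the calibration_params shape, and the per-phase scale_factor key are assumptions taken from that diff, not verified against the current get_threshold_info implementation:

def test_dynamic_calibrated_threshold(self):
    """Test threshold info for calibrated dynamic threshold (new contract)."""
    method = FlashSkipSoftmax(
        method_config={
            "threshold": {"prefill": 0.001, "decode": 0.0001},
            "br": 128,
            "bc": 128,
            "backend": "pytorch",
            "is_causal": True,
        }
    )

    # Simulate calibration params + target sparsity (assumed payload shape).
    method.calibration_params = {
        "prefill": {"k": 437.5, "p": 1.0},
        "decode": {"k": 500.0, "p": 1.0},
    }
    method.target_sparse_ratio = {"prefill": 0.0, "decode": 0.0}

    info = method.get_threshold_info()

    assert info["type"] == "dynamic_calibrated"
    assert info["phases"]["prefill"]["scale_factor"] == pytest.approx(437.5)
    assert info["phases"]["decode"]["scale_factor"] == pytest.approx(500.0)
    assert "formula" in info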
33d0025 to
31e8b2d
Compare
31e8b2d to
2fc1734
Compare
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
2fc1734 to
0553ec6
Compare
What does this PR do?
Type of change: new feature
Overview: ?
Usage
The calibration method
Calibration Algorithm
Why Choose the Inverse Power Model?
The inverse power model better fits the relationship between sparsity ratio and threshold_scale_factor.

Runtime Flexibility
Testing
The calibration results for Qwen/Qwen3-30B-A3B-Thinking-2507 are shown below and are mostly consistent with the ground-truth numbers collected from the kernel side.

Before your PR is "Ready for review"
Additional Information