Skip to content

[6/n] Replace skip-softmax calibration formula#1541

Open
kaix-nv wants to merge 1 commit into
mainfrom
kaix/skip_softmax_calib_formula
Open

[6/n] Replace skip-softmax calibration formula#1541
kaix-nv wants to merge 1 commit into
mainfrom
kaix/skip_softmax_calib_formula

Conversation

@kaix-nv
Copy link
Copy Markdown
Contributor

@kaix-nv kaix-nv commented May 25, 2026

What does this PR do?

Type of change: ?
Replace the old calibration formula from t = a * exp(b * S) / L to t = 1 - exp(-a * (S / (1 - S))^b / seq_k^c. The old calibration breaks at short context. We can clamp at runtime or add 1 - exp(-...) wrapper around t because 1 − exp(−c/L^α) \in (0, 1)  strictly, so we don't need to hardcode where to cutoff.

Usage

Testing

The calibration curve are shown below.

calibration_qwen3_30b_a3b calibration_qwen3_8b calibration_llama_3_1_8b

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Enhanced skip-softmax sparse attention with a new dynamic threshold calibration formula (t = 1 - exp(-a * (S/(1-S))^b / L^c)) enabling runtime sparsity adjustment without recalibration.
  • Refactor

    • Replaced exponential-based threshold computation with calibrated parameters across Triton and LTX attention kernels; removed fit_logspace configuration option.
  • Documentation

    • Updated sparse attention calibration documentation to reflect the new dynamic threshold model and calibration workflow.

Review Change Stack

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 25, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 25, 2026

📝 Walkthrough

Walkthrough

This PR overhauls skip-softmax sparse attention calibration by replacing an exponential scale-factor model with a new dynamic-threshold formula t = 1 - exp(-a * (S/(1-S))^b / L^c) fitted via closed-form least-squares regression. The change removes the fit_logspace configuration option and introduces a length-exponent parameter c throughout calibration, kernel interfaces, and inference methods.

Changes

Skip-Softmax Dynamic Threshold Calibration

Layer / File(s) Summary
Documentation and Configuration Schema
examples/diffusers/README.md, examples/diffusers/sparsity/*, modelopt/torch/sparsity/attention_sparsity/config.py
Updated example documentation to describe the new dynamic threshold formula t = 1 - exp(-a * (S/(1-S))^b / L^c) and removed fit_logspace configuration field from CalibrationConfig.
Calibration Algorithm Rewrite
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
Replaced scipy curve_fit-based fitting with closed-form least-squares regression that collects (t, L, S) triples, transforms coordinates via logit and log, filters invalid data, and computes (a, b, c) coefficients. Returns calibration_type as "dynamic_threshold".
Calibration Orchestration
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
Updated calibrate_sparse_attention to extract and store the new c parameter from per-phase calibration results and updated printed model equation to show the new threshold formula.
Kernel Interface - Diffusers Triton
modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py
Added length_exponent parameter to set_triton_skip_softmax_config() and updated inference-mode threshold computation to use 1 - exp(-scale_factor / (seq_k**length_exponent)).
Kernel Interface - LTX Triton
modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py
Added length_exponent parameter to set_ltx_triton_context(), included it in context getter return tuple, and updated dynamic threshold to 1 - exp(-scale_factor / (seq_k**length_exponent)).
Export and Conversion Schema
modelopt/torch/sparsity/attention_sparsity/conversion.py
Updated export_sparse_attention_config() to emit threshold_scale_factor with per-phase (a, b, c) parameters and new formula. Adjusted _format_threshold() to include c in human-readable output.
FlashSkipSoftmax Method
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
Updated calibrated threshold computation to 1 - exp(-scale / (seq_k**c)) with per-phase parameters; updated get_threshold_info() to report example thresholds and c value.
TritonSkipSoftmaxMethod
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
Updated to compute and pass length_exponent alongside scale_factor to kernel backends. Replaced calibrated scale computation with power-law transform a * (S/(1-S))^b.
vLLM Plugin Runtime Resolution
modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
Updated _resolve_skip_softmax_calibration() to accept c parameter and compute threshold as 1 - exp(-scale / seq_len^c) with (0, 1) range validation.
Unit Tests - Kernel, Calibrator, Methods
tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py, tests/unit/torch/sparsity/attention_sparsity/test_calibrator_fitting.py, tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py, tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py
Updated all tests to validate (a, b, c) parameter recovery, new threshold formulas, and length-exponent context handling. Removed old linear/log-space fitting tests; added single comprehensive parameter recovery test.
Unit Tests - Conversion, Worker, Threshold Info
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py, tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_*.py, tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
Updated integration tests to validate new threshold_scale_factor schema with (a, b, c), phase-aware calibration resolution, (0, 1) validity-range validation, and per-phase threshold info reporting.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically summarizes the main change: replacing the skip-softmax calibration formula with a new mathematical formulation.
Docstring Coverage ✅ Passed Docstring coverage is 80.39% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No critical security anti-patterns found. PR contains only mathematical/algorithmic updates to skip-softmax calibration—no unsafe deserialization, code execution, or dependency additions.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch kaix/skip_softmax_calib_formula

Comment @coderabbitai help to get the list of available commands and usage tips.

…1 - S))^b / L^c)

Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/skip_softmax_calib_formula branch from 14d4b63 to a42daf0 Compare May 25, 2026 01:47
@github-actions
Copy link
Copy Markdown
Contributor

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1541/

Built to branch gh-pages at 2026-05-25 01:50 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 25, 2026

Codecov Report

❌ Patch coverage is 89.88764% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.77%. Comparing base (16a0130) to head (a42daf0).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
.../torch/sparsity/attention_sparsity/plugins/vllm.py 0.00% 5 Missing ⚠️
.../attention_sparsity/methods/triton_skip_softmax.py 66.66% 3 Missing ⚠️
...rsity/attention_sparsity/calibration/calibrator.py 97.82% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1541      +/-   ##
==========================================
- Coverage   76.81%   76.77%   -0.04%     
==========================================
  Files         476      476              
  Lines       51891    51905      +14     
==========================================
- Hits        39860    39852       -8     
- Misses      12031    12053      +22     
Flag Coverage Δ
examples 41.65% <77.52%> (+0.91%) ⬆️
gpu 59.49% <11.23%> (-0.59%) ⬇️
regression 15.19% <1.12%> (+0.06%) ⬆️
unit 52.71% <84.26%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@kaix-nv kaix-nv changed the title Replace skip-softmax calibration with threshold = 1 - exp(-a * (S / (1 - S))^b / L^c) [6/n] Replace skip-softmax calibration with threshold = 1 - exp(-a * (S / (1 - S))^b / L^c) May 25, 2026
@kaix-nv kaix-nv changed the title [6/n] Replace skip-softmax calibration with threshold = 1 - exp(-a * (S / (1 - S))^b / L^c) [6/n] Replace skip-softmax calibration formula May 25, 2026
@kaix-nv kaix-nv marked this pull request as ready for review May 26, 2026 04:27
@kaix-nv kaix-nv requested review from a team as code owners May 26, 2026 04:27
@kaix-nv kaix-nv requested review from kevalmorabia97, realAsma and yeyu-nvidia and removed request for kevalmorabia97, realAsma and yeyu-nvidia May 26, 2026 04:27
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)

129-135: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard calibration-mode cleanup with try/finally.

If forward_loop(model) throws, Line 134 is skipped and modules can remain in calibration mode. Wrap enable/forward/extract/disable in try/finally so state is always restored.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 129 - 135, Wrap the calibration-mode lifecycle in a try/finally so modules
are always restored: keep the call to self._set_thresholds(attention_modules,
self.threshold_trials) outside, then call
self._enable_calibration_mode(attention_modules) and enter torch.no_grad();
inside a try run forward_loop(model) and assign per_sample_stats =
self._extract_calibration_stats(attention_modules, phase=phase); in the finally
always call self._disable_calibration_mode(attention_modules) so calibration
mode is disabled even if forward_loop or _extract_calibration_stats raises; let
exceptions propagate after cleanup.
examples/diffusers/sparsity/wan22_skip_softmax.py (1)

173-177: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update stale “exponential model” wording to “dynamic threshold model.”

Line 176 and Line 279 still describe calibration as exponential, which now conflicts with the updated (a, b, c) formulation used elsewhere in this file.

Also applies to: 279-283

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 173 - 177,
Update the stale wording "exponential model" to "dynamic threshold model"
wherever calibration is described: change the help string on the
parser.add_argument("--calibrate") entry and the later explanatory text that
references the calibration and the (a, b, c) formulation (the block around the
calibration description). Ensure the phrase "exponential model" is replaced with
"dynamic threshold model" and keep mention of the (a, b, c) formulation intact
so the help and comments accurately reflect the current algorithm.
🧹 Nitpick comments (1)
tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py (1)

47-60: ⚡ Quick win

Add a focused negative-path test for invalid length_exponent values.

The new coverage verifies happy-path set/get/reset, but it doesn’t protect the boundary contract for malformed exponents (0, negative, nan, inf).

Suggested test addition
 class TestThreadLocalContext:
@@
     def test_clear_context_resets_all(self, ltx_mod):
@@
         assert length_exponent == 1.0  # default reset value
+
+    `@pytest.mark.parametrize`("bad_exponent", [0.0, -1.0, float("nan"), float("inf")])
+    def test_set_context_rejects_invalid_length_exponent(self, ltx_mod, bad_exponent):
+        with pytest.raises(ValueError):
+            ltx_mod.set_ltx_triton_context(active=True, length_exponent=bad_exponent)

As per coding guidelines, "Tests: add/adjust focused unit tests to cover new logic and guard against regressions."

Also applies to: 75-88

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py`
around lines 47 - 60, Add a focused negative-path unit test in the same file to
validate that set_ltx_triton_context rejects invalid length_exponent inputs:
call set_ltx_triton_context (as used in test_set_context_populates_fields) with
length_exponent values 0, negative (e.g. -1), float('nan') and float('inf') and
assert it raises the expected exception (e.g. ValueError/TypeError) for each
case; locate usage around test_set_context_populates_fields and the helper
_get_ltx_triton_context to mirror style and also add equivalent checks near the
other related tests around lines 75-88 so the boundary contract is protected.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py`:
- Around line 56-57: Validate and sanitize the caller-controlled length_exponent
before writing it to thread-local context: ensure length_exponent is finite (not
NaN/Inf) and > 0 (and optionally clamp to a reasonable max like <= 10 or a
configured upper bound) and if invalid either replace with a safe default (e.g.,
1.0) or raise/return an error; then persist only the sanitized value to the
thread-local storage used by skip_softmax_threshold and the kernel path
(referencing the length_exponent variable and skip_softmax_threshold usage) so
malformed or extreme inputs cannot propagate into the kernel.

In `@modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py`:
- Line 44: In the LTX context setter where length_exponent is cast/storage (the
LTX context setter that assigns length_exponent), add a domain check to validate
the value is finite and > 0 before saving it; if it is not finite or <= 0, raise
a ValueError with a clear message. This prevents producing invalid thresholds
downstream (used later in the threshold computation around the attention
threshold code that reads length_exponent). Ensure you perform this check at the
input boundary (in the setter) so internal code (e.g., the attention threshold
calculation) can assume a valid positive finite length_exponent.

In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 267-287: Add a backward-compatible, ignored alias field named
fit_logspace to the CalibrationConfig Pydantic model so old checkpoints still
load: declare fit_logspace: Optional[bool] = None on the CalibrationConfig class
(which inherits ModeloptBaseConfig), and in a validator or __init__ detect if
fit_logspace is not None and emit a deprecation warning (warnings.warn or
process logger) saying it is ignored; do the same pattern for any other removed
calibration fields referenced in the same block (lines 289-347) to ensure
loading succeeds while preserving behavior.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 159-163: The reported thresholds in get_threshold_info() must be
numerically aligned with the runtime clipping used by
calc_correction_factor_and_p(): apply the same clipping/transform (clip t into
(1e-15, 1-1e-15) and report np.log(t) rather than the raw un-clipped expression)
so the debug output never shows exact 0.0/1.0 while runtime uses clipped values;
update get_threshold_info() (and the analogous block around lines 359-368) to
compute t, then clip using the same bounds used in
calc_correction_factor_and_p(), and then set log_thresholds = [np.log(t)] so
reported thresholds match the runtime path.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py`:
- Around line 86-93: Wrap the computation of threshold in
_resolve_skip_softmax_calibration() (the line computing threshold = 1.0 -
math.exp(-scale / (seq_len**c))) in a try/except that catches OverflowError and
ZeroDivisionError; on exception issue a warnings.warn describing invalid
calibration due to numeric overflow/underflow and return early without setting
skip_softmax_threshold so the caller falls back safely, preserving the existing
defensive clamp/warn logic for non-exceptional out-of-range threshold values.

---

Outside diff comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 173-177: Update the stale wording "exponential model" to "dynamic
threshold model" wherever calibration is described: change the help string on
the parser.add_argument("--calibrate") entry and the later explanatory text that
references the calibration and the (a, b, c) formulation (the block around the
calibration description). Ensure the phrase "exponential model" is replaced with
"dynamic threshold model" and keep mention of the (a, b, c) formulation intact
so the help and comments accurately reflect the current algorithm.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 129-135: Wrap the calibration-mode lifecycle in a try/finally so
modules are always restored: keep the call to
self._set_thresholds(attention_modules, self.threshold_trials) outside, then
call self._enable_calibration_mode(attention_modules) and enter torch.no_grad();
inside a try run forward_loop(model) and assign per_sample_stats =
self._extract_calibration_stats(attention_modules, phase=phase); in the finally
always call self._disable_calibration_mode(attention_modules) so calibration
mode is disabled even if forward_loop or _extract_calibration_stats raises; let
exceptions propagate after cleanup.

---

Nitpick comments:
In `@tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py`:
- Around line 47-60: Add a focused negative-path unit test in the same file to
validate that set_ltx_triton_context rejects invalid length_exponent inputs:
call set_ltx_triton_context (as used in test_set_context_populates_fields) with
length_exponent values 0, negative (e.g. -1), float('nan') and float('inf') and
assert it raises the expected exception (e.g. ValueError/TypeError) for each
case; locate usage around test_set_context_populates_fields and the helper
_get_ltx_triton_context to mirror style and also add equivalent checks near the
other related tests around lines 75-88 so the boundary contract is protected.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b67affa7-74f0-48ad-b41d-b26996391748

📥 Commits

Reviewing files that changed from the base of the PR and between df68ccd and a42daf0.

📒 Files selected for processing (20)
  • examples/diffusers/README.md
  • examples/diffusers/sparsity/README.md
  • examples/diffusers/sparsity/wan22_skip_softmax.py
  • modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py
  • modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
  • tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py
  • tests/unit/torch/sparsity/attention_sparsity/test_calibrator_fitting.py
  • tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.py
  • tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
  • tests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py

Comment on lines +56 to 57
length_exponent: float = 1.0,
measure_sparsity: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate length_exponent before persisting it to thread-local context.

length_exponent is caller-controlled, but at Line 81 it is accepted without bounds/finite checks. Values like <= 0, nan, or inf can produce invalid skip_softmax_threshold at Line 197 and propagate bad inputs to the kernel path.

Proposed fix
 def set_triton_skip_softmax_config(
@@
     length_exponent: float = 1.0,
@@
 ) -> None:
@@
-    _thread_local.length_exponent = float(length_exponent)
+    length_exponent = float(length_exponent)
+    if not math.isfinite(length_exponent) or length_exponent <= 0.0:
+        raise ValueError("length_exponent must be a finite positive float")
+    _thread_local.length_exponent = length_exponent

As per coding guidelines, "Treat model artifacts/config/calibration data as untrusted; validate at interface boundaries and guard for malformed/extreme inputs."

Also applies to: 81-81, 196-197

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py`
around lines 56 - 57, Validate and sanitize the caller-controlled
length_exponent before writing it to thread-local context: ensure
length_exponent is finite (not NaN/Inf) and > 0 (and optionally clamp to a
reasonable max like <= 10 or a configured upper bound) and if invalid either
replace with a safe default (e.g., 1.0) or raise/return an error; then persist
only the sanitized value to the thread-local storage used by
skip_softmax_threshold and the kernel path (referencing the length_exponent
variable and skip_softmax_threshold usage) so malformed or extreme inputs cannot
propagate into the kernel.

calibration_mode: bool = False,
threshold_trials: list[float] | None = None,
scale_factor: float | None = None,
length_exponent: float = 1.0,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add domain checks for length_exponent in the LTX context setter.

At Line 53, length_exponent is cast and stored without validation. Non-finite or non-positive values can yield invalid thresholds in the computation at Line 154.

Proposed fix
 def set_ltx_triton_context(
@@
     length_exponent: float = 1.0,
@@
 ) -> None:
@@
-    _thread_local.length_exponent = float(length_exponent)
+    length_exponent = float(length_exponent)
+    if not math.isfinite(length_exponent) or length_exponent <= 0.0:
+        raise ValueError("length_exponent must be a finite positive float")
+    _thread_local.length_exponent = length_exponent

As per coding guidelines, "Validate external input once at boundaries; internal code can trust those checks."

Also applies to: 53-53, 152-154

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py` at line
44, In the LTX context setter where length_exponent is cast/storage (the LTX
context setter that assigns length_exponent), add a domain check to validate the
value is finite and > 0 before saving it; if it is not finite or <= 0, raise a
ValueError with a clear message. This prevents producing invalid thresholds
downstream (used later in the threshold computation around the attention
threshold code that reads length_exponent). Ensure you perform this check at the
input boundary (in the setter) so internal code (e.g., the attention threshold
calculation) can assume a valid positive finite length_exponent.

Comment on lines 267 to 287
class CalibrationConfig(ModeloptBaseConfig):
"""Configuration for automatic threshold calibration using RULER dataset.

Calibration fits an Exponential model to determine dynamic thresholds that
achieve target sparsity. The model learns parameters a and b per phase:
Calibration determines dynamic thresholds that achieve target sparsity.
Three parameters (a, b, c) are fit per phase via closed-form linear
regression on

scale_factor = a * exp(b * target_sparsity)
log(-log(1 - t)) = log(a) + b * logit(S) - c * log(L)

At inference time, the threshold is computed as:
At inference time, the threshold is computed as

threshold = scale_factor / sequence_length
threshold = 1 - exp(-a * (S / (1 - S))^b / L^c)

Key benefits:
Key properties:
- Bounded in (0, 1) by construction (no runtime clamp needed)
- Correct asymptotes: t->0 as S->0 or L->inf; t->1 as S->1 or L->0
- Target sparsity can be changed at runtime without recalibration
- Threshold automatically adapts to sequence length
- Threshold adapts to sequence length via the L^c term
- Supports independent prefill and decode phase calibration
- Exponential model provides better fit (lower RMSE)
- Closed-form linear fit (np.linalg.lstsq); no nonlinear curve_fit needed
"""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Preserve backward compatibility for removed calibration fields.

CalibrationConfig no longer exposes fit_logspace; without a deprecated-acceptance shim, older saved configs/checkpoints may fail to load. Please keep a backward-compatible field/alias (ignored with warning) for at least one migration window.

As per coding guidelines: “**/config*.py: Preserve config and checkpoint backward compatibility for Pydantic-based ModeloptBaseConfig instances such as QuantizeConfig.”

Also applies to: 289-347

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 267 - 287,
Add a backward-compatible, ignored alias field named fit_logspace to the
CalibrationConfig Pydantic model so old checkpoints still load: declare
fit_logspace: Optional[bool] = None on the CalibrationConfig class (which
inherits ModeloptBaseConfig), and in a validator or __init__ detect if
fit_logspace is not None and emit a deprecation warning (warnings.warn or
process logger) saying it is ignored; do the same pattern for any other removed
calibration fields referenced in the same block (lines 289-347) to ensure
loading succeeds while preserving behavior.

Comment on lines +159 to +163
s_clipped = float(np.clip(target_sparsity, 1e-6, 1.0 - 1e-6))
scale = a * (s_clipped / (1.0 - s_clipped)) ** b
t = 1.0 - np.exp(-scale / (float(seq_k) ** c))
t = float(np.clip(t, 1e-15, 1.0 - 1e-15))
log_thresholds = [np.log(t)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Keep threshold reporting numerically aligned with the runtime path.

calc_correction_factor_and_p() clamps calibrated thresholds into (1e-15, 1 - 1e-15), but get_threshold_info() recomputes the raw expression. For extreme (a, b, c, S, L) values that can report exact 0.0/1.0 even though the runtime path never uses those values, which makes threshold debugging misleading.

Proposed fix
-            t = 1.0 - np.exp(-scale / (float(seq_k) ** c))
+            t = -np.expm1(-scale / (float(seq_k) ** c))
             t = float(np.clip(t, 1e-15, 1.0 - 1e-15))
             log_thresholds = [np.log(t)]
-                        length: float(1.0 - np.exp(-scale / (length**c)))
+                        length: float(
+                            np.clip(-np.expm1(-scale / (length**c)), 1e-15, 1.0 - 1e-15)
+                        )
                         for length in example_lengths

Also applies to: 359-368

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 159 - 163, The reported thresholds in get_threshold_info() must be
numerically aligned with the runtime clipping used by
calc_correction_factor_and_p(): apply the same clipping/transform (clip t into
(1e-15, 1-1e-15) and report np.log(t) rather than the raw un-clipped expression)
so the debug output never shows exact 0.0/1.0 while runtime uses clipped values;
update get_threshold_info() (and the analogous block around lines 359-368) to
compute t, then clip using the same bounds used in
calc_correction_factor_and_p(), and then set log_thresholds = [np.log(t)] so
reported thresholds match the runtime path.

Comment on lines +86 to 93
s_clipped = min(max(target, 1e-6), 1.0 - 1e-6)
scale = a * (s_clipped / (1.0 - s_clipped)) ** b
threshold = 1.0 - math.exp(-scale / (seq_len**c))
# Sanity check: the envelope guarantees threshold ∈ (0, 1) for finite
# positive scale and seq_len, but clamp defensively against pathological
# parameter values (e.g., scale=0 -> threshold=0, which disables sparsity).
if not (0.0 < threshold < 1.0):
warnings.warn(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py"
echo "== File: $FILE =="
rg -n "threshold\s*=|seq_len\*\*c|OverflowError|warnings\.warn|_resolve_skip_softmax_calibration|calibration" "$FILE" || true
echo
echo "== Context around lines 70-120 =="
nl -ba "$FILE" | sed -n '60,130p'
echo
echo "== Context around lines 220-270 =="
nl -ba "$FILE" | sed -n '210,270p'

Repository: NVIDIA/Model-Optimizer

Length of output: 796


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py"
echo "== File: $FILE =="
rg -n "def _resolve_skip_softmax_calibration|except|OverflowError|seq_len\*\*c|math\.exp|warnings\.warn" "$FILE" || true
echo

echo "== Lines 1-180 =="
awk 'NR>=1 && NR<=180 {printf "%5d:%s\n", NR, $0}' "$FILE"

echo
echo "== Lines 180-280 =="
awk 'NR>=180 && NR<=280 {printf "%5d:%s\n", NR, $0}' "$FILE"

Repository: NVIDIA/Model-Optimizer

Length of output: 13617


🏁 Script executed:

python3 - <<'PY'
import math

def test(seq_len, a, s_clipped, b, c):
    try:
        scale = a * (s_clipped / (1.0 - s_clipped)) ** b
        threshold = 1.0 - math.exp(-scale / (seq_len**c))
        return ("ok", scale, seq_len**c, threshold)
    except Exception as e:
        return ("err", type(e).__name__, str(e))

cases = [
    # Extreme positive c
    (2048, 1.0, 0.5, 10.0, 100.0),
    (2048, 1.0, 0.5, 10.0, 300.0),
    (2048, 1.0, 0.5, 10.0, 1000.0),
    # Extreme negative c
    (2048, 1.0, 0.5, 10.0, -100.0),
    (2048, 1.0, 0.5, 10.0, -300.0),
    # Extreme b (affects scale)
    (2048, 1.0, 0.5, 300.0, 1.0),
    (2048, 1.0, 0.5, 1000.0, 1.0),
    # Very small s_clipped / near 1
    (2048, 1.0, 1e-6, 100.0, 10.0),
    (2048, 1.0, 1-1e-6, 100.0, 10.0),
]

for i, (seq_len,a,s_clipped,b,c) in enumerate(cases, 1):
    res = test(seq_len,a,s_clipped,b,c)
    print(i, res)
PY

Repository: NVIDIA/Model-Optimizer

Length of output: 584


Guard skip-softmax calibration threshold math against OverflowError/ZeroDivisionError

In _resolve_skip_softmax_calibration() (lines 86-93), threshold = 1.0 - math.exp(-scale / (seq_len**c)) is computed without handling numeric overflow/underflow: extreme checkpoint-derived c can make seq_len**c overflow or underflow to 0, raising OverflowError/ZeroDivisionError and aborting the launch before the existing (0, 1) warning/clamp path runs. Catch these exceptions and treat the calibration as invalid (warn + return without setting skip_softmax_threshold) so the code falls back safely.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py` around lines 86 -
93, Wrap the computation of threshold in _resolve_skip_softmax_calibration()
(the line computing threshold = 1.0 - math.exp(-scale / (seq_len**c))) in a
try/except that catches OverflowError and ZeroDivisionError; on exception issue
a warnings.warn describing invalid calibration due to numeric overflow/underflow
and return early without setting skip_softmax_threshold so the caller falls back
safely, preserving the existing defensive clamp/warn logic for non-exceptional
out-of-range threshold values.

@kaix-nv kaix-nv requested a review from rohansjoshi May 28, 2026 19:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant