[6/n] Replace skip-softmax calibration formula#1541
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
📝 WalkthroughWalkthroughThis PR overhauls skip-softmax sparse attention calibration by replacing an exponential scale-factor model with a new dynamic-threshold formula ChangesSkip-Softmax Dynamic Threshold Calibration
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes 🚥 Pre-merge checks | ✅ 6✅ Passed checks (6 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
…1 - S))^b / L^c) Signed-off-by: Kai Xu <kaix@nvidia.com>
14d4b63 to
a42daf0
Compare
|
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1541 +/- ##
==========================================
- Coverage 76.81% 76.77% -0.04%
==========================================
Files 476 476
Lines 51891 51905 +14
==========================================
- Hits 39860 39852 -8
- Misses 12031 12053 +22
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
129-135:⚠️ Potential issue | 🟠 Major | ⚡ Quick winGuard calibration-mode cleanup with
try/finally.If
forward_loop(model)throws, Line 134 is skipped and modules can remain in calibration mode. Wrap enable/forward/extract/disable intry/finallyso state is always restored.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around lines 129 - 135, Wrap the calibration-mode lifecycle in a try/finally so modules are always restored: keep the call to self._set_thresholds(attention_modules, self.threshold_trials) outside, then call self._enable_calibration_mode(attention_modules) and enter torch.no_grad(); inside a try run forward_loop(model) and assign per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase); in the finally always call self._disable_calibration_mode(attention_modules) so calibration mode is disabled even if forward_loop or _extract_calibration_stats raises; let exceptions propagate after cleanup.examples/diffusers/sparsity/wan22_skip_softmax.py (1)
173-177:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winUpdate stale “exponential model” wording to “dynamic threshold model.”
Line 176 and Line 279 still describe calibration as exponential, which now conflicts with the updated
(a, b, c)formulation used elsewhere in this file.Also applies to: 279-283
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 173 - 177, Update the stale wording "exponential model" to "dynamic threshold model" wherever calibration is described: change the help string on the parser.add_argument("--calibrate") entry and the later explanatory text that references the calibration and the (a, b, c) formulation (the block around the calibration description). Ensure the phrase "exponential model" is replaced with "dynamic threshold model" and keep mention of the (a, b, c) formulation intact so the help and comments accurately reflect the current algorithm.
🧹 Nitpick comments (1)
tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py (1)
47-60: ⚡ Quick winAdd a focused negative-path test for invalid
length_exponentvalues.The new coverage verifies happy-path set/get/reset, but it doesn’t protect the boundary contract for malformed exponents (
0, negative,nan,inf).Suggested test addition
class TestThreadLocalContext: @@ def test_clear_context_resets_all(self, ltx_mod): @@ assert length_exponent == 1.0 # default reset value + + `@pytest.mark.parametrize`("bad_exponent", [0.0, -1.0, float("nan"), float("inf")]) + def test_set_context_rejects_invalid_length_exponent(self, ltx_mod, bad_exponent): + with pytest.raises(ValueError): + ltx_mod.set_ltx_triton_context(active=True, length_exponent=bad_exponent)As per coding guidelines, "Tests: add/adjust focused unit tests to cover new logic and guard against regressions."
Also applies to: 75-88
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py` around lines 47 - 60, Add a focused negative-path unit test in the same file to validate that set_ltx_triton_context rejects invalid length_exponent inputs: call set_ltx_triton_context (as used in test_set_context_populates_fields) with length_exponent values 0, negative (e.g. -1), float('nan') and float('inf') and assert it raises the expected exception (e.g. ValueError/TypeError) for each case; locate usage around test_set_context_populates_fields and the helper _get_ltx_triton_context to mirror style and also add equivalent checks near the other related tests around lines 75-88 so the boundary contract is protected.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py`:
- Around line 56-57: Validate and sanitize the caller-controlled length_exponent
before writing it to thread-local context: ensure length_exponent is finite (not
NaN/Inf) and > 0 (and optionally clamp to a reasonable max like <= 10 or a
configured upper bound) and if invalid either replace with a safe default (e.g.,
1.0) or raise/return an error; then persist only the sanitized value to the
thread-local storage used by skip_softmax_threshold and the kernel path
(referencing the length_exponent variable and skip_softmax_threshold usage) so
malformed or extreme inputs cannot propagate into the kernel.
In `@modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py`:
- Line 44: In the LTX context setter where length_exponent is cast/storage (the
LTX context setter that assigns length_exponent), add a domain check to validate
the value is finite and > 0 before saving it; if it is not finite or <= 0, raise
a ValueError with a clear message. This prevents producing invalid thresholds
downstream (used later in the threshold computation around the attention
threshold code that reads length_exponent). Ensure you perform this check at the
input boundary (in the setter) so internal code (e.g., the attention threshold
calculation) can assume a valid positive finite length_exponent.
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 267-287: Add a backward-compatible, ignored alias field named
fit_logspace to the CalibrationConfig Pydantic model so old checkpoints still
load: declare fit_logspace: Optional[bool] = None on the CalibrationConfig class
(which inherits ModeloptBaseConfig), and in a validator or __init__ detect if
fit_logspace is not None and emit a deprecation warning (warnings.warn or
process logger) saying it is ignored; do the same pattern for any other removed
calibration fields referenced in the same block (lines 289-347) to ensure
loading succeeds while preserving behavior.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 159-163: The reported thresholds in get_threshold_info() must be
numerically aligned with the runtime clipping used by
calc_correction_factor_and_p(): apply the same clipping/transform (clip t into
(1e-15, 1-1e-15) and report np.log(t) rather than the raw un-clipped expression)
so the debug output never shows exact 0.0/1.0 while runtime uses clipped values;
update get_threshold_info() (and the analogous block around lines 359-368) to
compute t, then clip using the same bounds used in
calc_correction_factor_and_p(), and then set log_thresholds = [np.log(t)] so
reported thresholds match the runtime path.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py`:
- Around line 86-93: Wrap the computation of threshold in
_resolve_skip_softmax_calibration() (the line computing threshold = 1.0 -
math.exp(-scale / (seq_len**c))) in a try/except that catches OverflowError and
ZeroDivisionError; on exception issue a warnings.warn describing invalid
calibration due to numeric overflow/underflow and return early without setting
skip_softmax_threshold so the caller falls back safely, preserving the existing
defensive clamp/warn logic for non-exceptional out-of-range threshold values.
---
Outside diff comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 173-177: Update the stale wording "exponential model" to "dynamic
threshold model" wherever calibration is described: change the help string on
the parser.add_argument("--calibrate") entry and the later explanatory text that
references the calibration and the (a, b, c) formulation (the block around the
calibration description). Ensure the phrase "exponential model" is replaced with
"dynamic threshold model" and keep mention of the (a, b, c) formulation intact
so the help and comments accurately reflect the current algorithm.
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 129-135: Wrap the calibration-mode lifecycle in a try/finally so
modules are always restored: keep the call to
self._set_thresholds(attention_modules, self.threshold_trials) outside, then
call self._enable_calibration_mode(attention_modules) and enter torch.no_grad();
inside a try run forward_loop(model) and assign per_sample_stats =
self._extract_calibration_stats(attention_modules, phase=phase); in the finally
always call self._disable_calibration_mode(attention_modules) so calibration
mode is disabled even if forward_loop or _extract_calibration_stats raises; let
exceptions propagate after cleanup.
---
Nitpick comments:
In `@tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py`:
- Around line 47-60: Add a focused negative-path unit test in the same file to
validate that set_ltx_triton_context rejects invalid length_exponent inputs:
call set_ltx_triton_context (as used in test_set_context_populates_fields) with
length_exponent values 0, negative (e.g. -1), float('nan') and float('inf') and
assert it raises the expected exception (e.g. ValueError/TypeError) for each
case; locate usage around test_set_context_populates_fields and the helper
_get_ltx_triton_context to mirror style and also add equivalent checks near the
other related tests around lines 75-88 so the boundary contract is protected.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b67affa7-74f0-48ad-b41d-b26996391748
📒 Files selected for processing (20)
examples/diffusers/README.mdexamples/diffusers/sparsity/README.mdexamples/diffusers/sparsity/wan22_skip_softmax.pymodelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.pymodelopt/torch/kernels/sparsity/attention/ltx_triton_attention.pymodelopt/torch/sparsity/attention_sparsity/calibration/calibrate.pymodelopt/torch/sparsity/attention_sparsity/calibration/calibrator.pymodelopt/torch/sparsity/attention_sparsity/config.pymodelopt/torch/sparsity/attention_sparsity/conversion.pymodelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.pymodelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.pymodelopt/torch/sparsity/attention_sparsity/plugins/vllm.pytests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.pytests/unit/torch/sparsity/attention_sparsity/test_calibrator_fitting.pytests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.pytests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.pytests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.pytests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_worker.pytests/unit/torch/sparsity/attention_sparsity/test_threshold_info.pytests/unit/torch/sparsity/attention_sparsity/test_triton_skip_softmax.py
| length_exponent: float = 1.0, | ||
| measure_sparsity: bool = False, |
There was a problem hiding this comment.
Validate length_exponent before persisting it to thread-local context.
length_exponent is caller-controlled, but at Line 81 it is accepted without bounds/finite checks. Values like <= 0, nan, or inf can produce invalid skip_softmax_threshold at Line 197 and propagate bad inputs to the kernel path.
Proposed fix
def set_triton_skip_softmax_config(
@@
length_exponent: float = 1.0,
@@
) -> None:
@@
- _thread_local.length_exponent = float(length_exponent)
+ length_exponent = float(length_exponent)
+ if not math.isfinite(length_exponent) or length_exponent <= 0.0:
+ raise ValueError("length_exponent must be a finite positive float")
+ _thread_local.length_exponent = length_exponentAs per coding guidelines, "Treat model artifacts/config/calibration data as untrusted; validate at interface boundaries and guard for malformed/extreme inputs."
Also applies to: 81-81, 196-197
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py`
around lines 56 - 57, Validate and sanitize the caller-controlled
length_exponent before writing it to thread-local context: ensure
length_exponent is finite (not NaN/Inf) and > 0 (and optionally clamp to a
reasonable max like <= 10 or a configured upper bound) and if invalid either
replace with a safe default (e.g., 1.0) or raise/return an error; then persist
only the sanitized value to the thread-local storage used by
skip_softmax_threshold and the kernel path (referencing the length_exponent
variable and skip_softmax_threshold usage) so malformed or extreme inputs cannot
propagate into the kernel.
| calibration_mode: bool = False, | ||
| threshold_trials: list[float] | None = None, | ||
| scale_factor: float | None = None, | ||
| length_exponent: float = 1.0, |
There was a problem hiding this comment.
Add domain checks for length_exponent in the LTX context setter.
At Line 53, length_exponent is cast and stored without validation. Non-finite or non-positive values can yield invalid thresholds in the computation at Line 154.
Proposed fix
def set_ltx_triton_context(
@@
length_exponent: float = 1.0,
@@
) -> None:
@@
- _thread_local.length_exponent = float(length_exponent)
+ length_exponent = float(length_exponent)
+ if not math.isfinite(length_exponent) or length_exponent <= 0.0:
+ raise ValueError("length_exponent must be a finite positive float")
+ _thread_local.length_exponent = length_exponentAs per coding guidelines, "Validate external input once at boundaries; internal code can trust those checks."
Also applies to: 53-53, 152-154
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py` at line
44, In the LTX context setter where length_exponent is cast/storage (the LTX
context setter that assigns length_exponent), add a domain check to validate the
value is finite and > 0 before saving it; if it is not finite or <= 0, raise a
ValueError with a clear message. This prevents producing invalid thresholds
downstream (used later in the threshold computation around the attention
threshold code that reads length_exponent). Ensure you perform this check at the
input boundary (in the setter) so internal code (e.g., the attention threshold
calculation) can assume a valid positive finite length_exponent.
| class CalibrationConfig(ModeloptBaseConfig): | ||
| """Configuration for automatic threshold calibration using RULER dataset. | ||
|
|
||
| Calibration fits an Exponential model to determine dynamic thresholds that | ||
| achieve target sparsity. The model learns parameters a and b per phase: | ||
| Calibration determines dynamic thresholds that achieve target sparsity. | ||
| Three parameters (a, b, c) are fit per phase via closed-form linear | ||
| regression on | ||
|
|
||
| scale_factor = a * exp(b * target_sparsity) | ||
| log(-log(1 - t)) = log(a) + b * logit(S) - c * log(L) | ||
|
|
||
| At inference time, the threshold is computed as: | ||
| At inference time, the threshold is computed as | ||
|
|
||
| threshold = scale_factor / sequence_length | ||
| threshold = 1 - exp(-a * (S / (1 - S))^b / L^c) | ||
|
|
||
| Key benefits: | ||
| Key properties: | ||
| - Bounded in (0, 1) by construction (no runtime clamp needed) | ||
| - Correct asymptotes: t->0 as S->0 or L->inf; t->1 as S->1 or L->0 | ||
| - Target sparsity can be changed at runtime without recalibration | ||
| - Threshold automatically adapts to sequence length | ||
| - Threshold adapts to sequence length via the L^c term | ||
| - Supports independent prefill and decode phase calibration | ||
| - Exponential model provides better fit (lower RMSE) | ||
| - Closed-form linear fit (np.linalg.lstsq); no nonlinear curve_fit needed | ||
| """ |
There was a problem hiding this comment.
Preserve backward compatibility for removed calibration fields.
CalibrationConfig no longer exposes fit_logspace; without a deprecated-acceptance shim, older saved configs/checkpoints may fail to load. Please keep a backward-compatible field/alias (ignored with warning) for at least one migration window.
As per coding guidelines: “**/config*.py: Preserve config and checkpoint backward compatibility for Pydantic-based ModeloptBaseConfig instances such as QuantizeConfig.”
Also applies to: 289-347
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 267 - 287,
Add a backward-compatible, ignored alias field named fit_logspace to the
CalibrationConfig Pydantic model so old checkpoints still load: declare
fit_logspace: Optional[bool] = None on the CalibrationConfig class (which
inherits ModeloptBaseConfig), and in a validator or __init__ detect if
fit_logspace is not None and emit a deprecation warning (warnings.warn or
process logger) saying it is ignored; do the same pattern for any other removed
calibration fields referenced in the same block (lines 289-347) to ensure
loading succeeds while preserving behavior.
| s_clipped = float(np.clip(target_sparsity, 1e-6, 1.0 - 1e-6)) | ||
| scale = a * (s_clipped / (1.0 - s_clipped)) ** b | ||
| t = 1.0 - np.exp(-scale / (float(seq_k) ** c)) | ||
| t = float(np.clip(t, 1e-15, 1.0 - 1e-15)) | ||
| log_thresholds = [np.log(t)] |
There was a problem hiding this comment.
Keep threshold reporting numerically aligned with the runtime path.
calc_correction_factor_and_p() clamps calibrated thresholds into (1e-15, 1 - 1e-15), but get_threshold_info() recomputes the raw expression. For extreme (a, b, c, S, L) values that can report exact 0.0/1.0 even though the runtime path never uses those values, which makes threshold debugging misleading.
Proposed fix
- t = 1.0 - np.exp(-scale / (float(seq_k) ** c))
+ t = -np.expm1(-scale / (float(seq_k) ** c))
t = float(np.clip(t, 1e-15, 1.0 - 1e-15))
log_thresholds = [np.log(t)]- length: float(1.0 - np.exp(-scale / (length**c)))
+ length: float(
+ np.clip(-np.expm1(-scale / (length**c)), 1e-15, 1.0 - 1e-15)
+ )
for length in example_lengthsAlso applies to: 359-368
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 159 - 163, The reported thresholds in get_threshold_info() must be
numerically aligned with the runtime clipping used by
calc_correction_factor_and_p(): apply the same clipping/transform (clip t into
(1e-15, 1-1e-15) and report np.log(t) rather than the raw un-clipped expression)
so the debug output never shows exact 0.0/1.0 while runtime uses clipped values;
update get_threshold_info() (and the analogous block around lines 359-368) to
compute t, then clip using the same bounds used in
calc_correction_factor_and_p(), and then set log_thresholds = [np.log(t)] so
reported thresholds match the runtime path.
| s_clipped = min(max(target, 1e-6), 1.0 - 1e-6) | ||
| scale = a * (s_clipped / (1.0 - s_clipped)) ** b | ||
| threshold = 1.0 - math.exp(-scale / (seq_len**c)) | ||
| # Sanity check: the envelope guarantees threshold ∈ (0, 1) for finite | ||
| # positive scale and seq_len, but clamp defensively against pathological | ||
| # parameter values (e.g., scale=0 -> threshold=0, which disables sparsity). | ||
| if not (0.0 < threshold < 1.0): | ||
| warnings.warn( |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py"
echo "== File: $FILE =="
rg -n "threshold\s*=|seq_len\*\*c|OverflowError|warnings\.warn|_resolve_skip_softmax_calibration|calibration" "$FILE" || true
echo
echo "== Context around lines 70-120 =="
nl -ba "$FILE" | sed -n '60,130p'
echo
echo "== Context around lines 220-270 =="
nl -ba "$FILE" | sed -n '210,270p'Repository: NVIDIA/Model-Optimizer
Length of output: 796
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py"
echo "== File: $FILE =="
rg -n "def _resolve_skip_softmax_calibration|except|OverflowError|seq_len\*\*c|math\.exp|warnings\.warn" "$FILE" || true
echo
echo "== Lines 1-180 =="
awk 'NR>=1 && NR<=180 {printf "%5d:%s\n", NR, $0}' "$FILE"
echo
echo "== Lines 180-280 =="
awk 'NR>=180 && NR<=280 {printf "%5d:%s\n", NR, $0}' "$FILE"Repository: NVIDIA/Model-Optimizer
Length of output: 13617
🏁 Script executed:
python3 - <<'PY'
import math
def test(seq_len, a, s_clipped, b, c):
try:
scale = a * (s_clipped / (1.0 - s_clipped)) ** b
threshold = 1.0 - math.exp(-scale / (seq_len**c))
return ("ok", scale, seq_len**c, threshold)
except Exception as e:
return ("err", type(e).__name__, str(e))
cases = [
# Extreme positive c
(2048, 1.0, 0.5, 10.0, 100.0),
(2048, 1.0, 0.5, 10.0, 300.0),
(2048, 1.0, 0.5, 10.0, 1000.0),
# Extreme negative c
(2048, 1.0, 0.5, 10.0, -100.0),
(2048, 1.0, 0.5, 10.0, -300.0),
# Extreme b (affects scale)
(2048, 1.0, 0.5, 300.0, 1.0),
(2048, 1.0, 0.5, 1000.0, 1.0),
# Very small s_clipped / near 1
(2048, 1.0, 1e-6, 100.0, 10.0),
(2048, 1.0, 1-1e-6, 100.0, 10.0),
]
for i, (seq_len,a,s_clipped,b,c) in enumerate(cases, 1):
res = test(seq_len,a,s_clipped,b,c)
print(i, res)
PYRepository: NVIDIA/Model-Optimizer
Length of output: 584
Guard skip-softmax calibration threshold math against OverflowError/ZeroDivisionError
In _resolve_skip_softmax_calibration() (lines 86-93), threshold = 1.0 - math.exp(-scale / (seq_len**c)) is computed without handling numeric overflow/underflow: extreme checkpoint-derived c can make seq_len**c overflow or underflow to 0, raising OverflowError/ZeroDivisionError and aborting the launch before the existing (0, 1) warning/clamp path runs. Catch these exceptions and treat the calibration as invalid (warn + return without setting skip_softmax_threshold) so the code falls back safely.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py` around lines 86 -
93, Wrap the computation of threshold in _resolve_skip_softmax_calibration()
(the line computing threshold = 1.0 - math.exp(-scale / (seq_len**c))) in a
try/except that catches OverflowError and ZeroDivisionError; on exception issue
a warnings.warn describing invalid calibration due to numeric overflow/underflow
and return early without setting skip_softmax_threshold so the caller falls back
safely, preserving the existing defensive clamp/warn logic for non-exceptional
out-of-range threshold values.
What does this PR do?
Type of change: ?
Replace the old calibration formula from
t = a * exp(b * S) / Ltot = 1 - exp(-a * (S / (1 - S))^b / seq_k^c. The old calibration breaks at short context. We can clamp at runtime or add1 - exp(-...)wrapper around t because1 − exp(−c/L^α) \in (0, 1)strictly, so we don't need to hardcode where to cutoff.Usage
Testing
The calibration curve are shown below.
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: ✅ / ❌ / N/AAdditional Information
Summary by CodeRabbit
New Features
t = 1 - exp(-a * (S/(1-S))^b / L^c)) enabling runtime sparsity adjustment without recalibration.Refactor
fit_logspaceconfiguration option.Documentation