[OMNIML-3994] Make sure all weight quantizers have _amax#1560
📝 Walkthrough
This PR removes the lazy weight calibration system for NVFP4 quantizers by deleting the bootstrap helper, refactoring
Changes: Lazy weight calibration removal
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
/claude review |
Actionable comments posted: 0
🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)
1328-1333: ⚡ Quick win: Remove the unreachable `else` path in `postprocess`.

At Line 1389-Line 1391, `postprocess()` is only called when `module.awq_lite.is_enabled` is true, so the `else` block at Line 1328-Line 1333 is dead code.

♻️ Proposed simplification

    - if module.awq_lite.is_enabled:
    -     apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
    - else:
    -     # ``max_calibrate`` already set ``_amax`` for this module via the
    -     # always-on ``weight_only_quantize`` step; AWQ is just disabled
    -     # because no cache tokens flowed through it. Falling back to
    -     # neutral per-tensor max calibration is the right thing — and
    -     # already done by ``max_calibrate`` — so just warn.
    -     warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
    + apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)

As per coding guidelines "Remove dead code including unused imports, unreachable branches, and obsolete helpers".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/quantization/model_calib.py` around lines 1328 - 1333, The else branch in postprocess is unreachable because postprocess() is only invoked when module.awq_lite.is_enabled is true; remove the entire else block (the warnings.warn call and its comment) from the postprocess function to eliminate dead code and keep only the enabled-path logic (references: postprocess, module.awq_lite.is_enabled, the warnings.warn line that mentions "awq_lite: Disabling for {name}...").
📒 Files selected for processing (4)
- modelopt/torch/export/quant_utils.py
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/export/test_export_weight_gpu.py
- tests/unit/torch/quantization/plugins/test_fused_experts.py
💤 Files with no reviewable changes (1)
- modelopt/torch/export/quant_utils.py
Claude Review Summary
Findings: 0 CRITICAL, 0 IMPORTANT, 1 SUGGESTION
Overall assessment
This is a clean, well-scoped refactor. The core change, making `max_calibrate` always run `weight_only_quantize` before the optional `forward_loop`, establishes a stronger invariant (every weight quantizer has `_amax` after `max_calibrate`, regardless of MoE routing) and lets several downstream band-aids be deleted:
- `_bootstrap_uncalibrated_weight_quantizers` in `mse_calibrate`
- per-module `max_calibrate` fallback in `awq_lite.postprocess`'s disabled branch
- `_ensure_weight_quantizer_calibrated` lazy calibration during HF export

The change holds together end-to-end:
- For `mse_calibrate` / `local_hessian_calibrate` / `gptq` / `svdquant`, the pre-existing call to `max_calibrate` now naturally covers dead experts.
- The new `delattr(self.weight_quantizer, "_amax")` inside `awq_lite.forward` correctly reverts to the dynamic `_get_amax` path so the per-alpha sweep recomputes `max(|weight * pre_quant_scale|)`. For `NVFP4StaticQuantizer`, `_fake_quantize` falls back to `super()._fake_quantize` when `self.amax is None`, so dynamic recomputation works there too.
- `apply_pre_quant_scale_and_smooth` (called in the enabled-postprocess path) and the disabled-module branch of the outer postprocess loop both re-populate `_amax` afterward.
- Export call sites (`get_weight_scaling_factor`, `get_weight_scaling_factor_2`) rely on the new invariant; this is safe because all `mtq.quantize` algorithms route through `max_calibrate`.
Test changes are consistent with the refactor (renamed dead-expert test + removal of the lazy-calibration test that no longer has a behavior to exercise).
Notable observation
- One SUGGESTION (inline): the `else` arm in `awq_lite.postprocess` is now unreachable, because the outer loop only calls `postprocess` when `module.awq_lite.is_enabled` is True. The arm could be deleted to remove a misleading comment and the duplicated warning.
Risk
Low. Backward-compatible per the PR description, and the new invariant (_amax always populated by max_calibrate) is strictly stronger than before.
LGTM.
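To make the `delattr` point in the review concrete, here is a minimal pure-Python sketch. `ToyQuantizer`, `effective_amax`, and `sweep` are hypothetical stand-ins rather than ModelOpt's actual classes; the only behavior assumed from the review is that a cached `_amax` takes precedence over the dynamic `max(|weight * pre_quant_scale|)` computation.

```python
class ToyQuantizer:
    """Hypothetical stand-in: uses a cached ``_amax`` if present, else recomputes."""

    def calibrate(self, values):
        # Mimics an up-front weight-only calibration pass that caches amax.
        self._amax = max(abs(v) for v in values)

    def effective_amax(self, values):
        # A cached value wins; without one, amax is derived from the input.
        cached = getattr(self, "_amax", None)
        return cached if cached is not None else max(abs(v) for v in values)


def sweep(quantizer, weight, scales, clear_cached):
    """Return the amax used for each candidate pre-quant scale."""
    if clear_cached and hasattr(quantizer, "_amax"):
        delattr(quantizer, "_amax")  # revert to the dynamic path
    return [quantizer.effective_amax([w * s for w in weight]) for s in scales]


weight = [0.5, -2.0, 1.0]
scales = [1.0, 0.5, 2.0]

stale = ToyQuantizer()
stale.calibrate(weight)
stale_amax = sweep(stale, weight, scales, clear_cached=False)

fresh = ToyQuantizer()
fresh.calibrate(weight)
fresh_amax = sweep(fresh, weight, scales, clear_cached=True)

print(stale_amax)  # every entry is the cached, unscaled value
print(fresh_amax)  # entries track the scaled weight
```

In this toy, skipping the `delattr` leaves every sweep step using the same stale value, which is the silent failure mode a dedicated test would guard against.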
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## main #1560 +/- ##
==========================================
+ Coverage 76.65% 77.17% +0.51%
==========================================
Files 478 478
Lines 52408 52388 -20
==========================================
+ Hits 40172 40428 +256
+ Misses 12236 11960 -276
60737c0 to 238546a
Actionable comments posted: 0
cjluo-nv
left a comment
Bot review — DM the bot to share feedback.
Nice simplification: removing the three band-aid paths in favor of populating `_amax` once in `max_calibrate` is the right direction, and the test rename + assertion update on `_TinyMoEModel` look correct. Flagging for human sign-off on a few points:
- AWQ lite behavior change is the most fragile piece and lacks a dedicated test. The new `delattr(self.weight_quantizer, "_amax")` inside `awq_lite`'s `forward` (model_calib.py L1209-1215) is the load-bearing change that keeps the per-alpha dynamic-amax path working now that `weight_only_quantize` populates `_amax` up front. The PR rationale ("postprocess overwrites for enabled modules, disabled modules early-return") is plausible: `apply_pre_quant_scale_and_smooth` → `_apply_weight_pre_quant_scale` repopulates for enabled modules, and the new `assert module.awq_lite.is_enabled` plus inline `max_calibrate` fallback handles disabled ones. However, there is no new test exercising AWQ with dead/uncalibrated experts. The existing renamed test only covers plain `max_calibrate`. Worth either adding an AWQ-MoE-with-dead-experts test or explicitly listing which existing GPU test exercises this combination.
- Export-time safety net removed without a deprecation path. `_ensure_weight_quantizer_calibrated` previously warned and re-derived `_amax` from the weight if export saw an uncalibrated NVFP4 quantizer. With this PR, the same situation will dereference `weight_quantizer._amax.float() / 448.0` (quant_utils.py L292, L324) and crash with `AttributeError: 'NoneType' object has no attribute 'float'`. Anyone loading a partial checkpoint and exporting without re-running `mtq.quantize` / `max_calibrate` will hit this. The PR claims "backward compatible"; that's true for the documented happy path, but the failure mode for misuse changes from "warn + recover" to "hard crash". Worth confirming this is intentional and ideally raising a clearer error than `AttributeError` in `get_weight_scaling_factor[_2]` when `_amax is None`.
- Deleted `test_export_nvfp4_static_weight_dynamic_vs_static_match` was checking two things, not just lazy calibration: (a) dynamic vs static NVFP4 export produce matching weight/scales, and (b) lazy fill from weights when `_amax` is cleared. (b) is dead, agreed, but (a) is still a useful invariant. Consider keeping a trimmed version that exercises (a) without the manual `delattr` setup.
- `weight_only_quantize` now runs on every `max_calibrate` call, including the recursive ones in `_apply_weight_pre_quant_scale`, `smoothquant.postprocess`, the awq_lite disabled-fallback, and `awq` SequentialQuantizer recalibration. For these, the lambda already calibrates from `module.weight`, so we now do that work twice on the same tensor. Functionally a no-op (idempotent), but the PR doesn't note the duplicated work and large-model calibration time may tick up slightly.
None of these are blockers — the structural cleanup is good and the diff is net -200 lines. Asking a human to confirm the AWQ test coverage and the export-error UX before approval.
Additional comments (outside the PR diff):
`modelopt/torch/export/quant_utils.py:292` (bot comment):
With `_ensure_weight_quantizer_calibrated` gone, this `weight_quantizer._amax.float() / 448.0` (and the analogous one in `get_weight_scaling_factor_2`) will raise `AttributeError: 'NoneType' object has no attribute 'float'` instead of warning + recomputing if a user exports a checkpoint that wasn't fully calibrated. Worth at least adding an explicit `if weight_quantizer._amax is None: raise RuntimeError("Weight quantizer ... has no _amax; run max_calibrate before export")` so the failure points at the cause, not at a generic NoneType deref.
`tests/gpu/torch/export/test_export_weight_gpu.py:124` (bot comment):
The deleted test was checking two things: (a) dynamic vs static NVFP4 export produces matching weight + scales, and (b) lazy fill from `_ensure_weight_quantizer_calibrated`. (b) is dead, agreed. But (a), that the static-quantizer export path matches the dynamic one when both are properly calibrated, is still a useful invariant and isn't covered elsewhere as far as I can see. Consider keeping a trimmed version of the test that drops the `delattr` setup but still asserts the dynamic-vs-static export equivalence.
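The explicit `_amax is None` check requested for the export path could look roughly like the sketch below. It is not the code in `quant_utils.py`: `ToyWeightQuantizer`, `scaling_factor_2`, and `SCALE_DIVISOR` are hypothetical names, and only the `448.0` divisor and the wording of the error message are taken from the review comment.

```python
SCALE_DIVISOR = 448.0  # constant quoted in the review comment


class ToyWeightQuantizer:
    """Hypothetical stand-in exposing only an ``_amax`` attribute."""

    def __init__(self, amax=None):
        self._amax = amax


def scaling_factor_2(weight_quantizer, name="weight_quantizer"):
    """Sketch of a scale lookup that fails with an actionable message."""
    amax = getattr(weight_quantizer, "_amax", None)
    if amax is None:
        # Point at the cause instead of surfacing a generic attribute error.
        raise RuntimeError(f"{name} has no _amax; run max_calibrate before export.")
    return float(amax) / SCALE_DIVISOR


print(scaling_factor_2(ToyWeightQuantizer(896.0)))

try:
    scaling_factor_2(ToyWeightQuantizer(), name="experts.3.weight_quantizer")
except RuntimeError as exc:
    print(exc)
```

Using `getattr` with a default covers both cases the thread mentions, an attribute that was deleted and one that is present but `None`, without changing the happy path.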
cjluo-nv
left a comment
Bot review — DM the bot to share feedback.
Re-review: the structural cleanup is the right direction and the unreachable `else` in `awq_lite.postprocess` plus the docstring nit on `max_calibrate` look addressed (💬 author replied "Done" on the duplicated-calibration note, and the `else` arm has been replaced with an `assert module.awq_lite.is_enabled`). However, several substantive points from cjluo-nv's previous human review still appear unresolved, so flagging for human sign-off:
- AWQ dead-expert test still missing. The `delattr(self.weight_quantizer, "_amax")` inside `awq_lite.forward` (model_calib.py L1209-1215) is the load-bearing change that keeps the per-alpha dynamic-amax path working now that `weight_only_quantize` populates `_amax` up front. The renamed `test_max_calibrate_populates_dead_expert_quantizers` only exercises plain `max_calibrate`, not `awq_lite` with dead/uncalibrated experts. If anyone refactors `forward` and forgets the `delattr`, per-alpha amax will be silently wrong with no test to catch it. Author hasn't responded to this point.
- Export-time `AttributeError` UX not fixed. With `_ensure_weight_quantizer_calibrated` gone, `weight_quantizer._amax.float() / 448.0` in `get_weight_scaling_factor[_2]` (quant_utils.py:288, :324 in the post-diff file) will now raise `AttributeError: 'NoneType' object has no attribute 'float'` instead of the previous warn+recover for a checkpoint exported without re-running `max_calibrate`. The previous review asked for an explicit `if weight_quantizer._amax is None: raise RuntimeError("... run max_calibrate before export")` so the failure points at the cause rather than a generic NoneType deref. Not addressed in the current diff.
- Dropped dynamic-vs-static NVFP4 export equivalence assertion. The deleted `test_export_nvfp4_static_weight_dynamic_vs_static_match` was checking two things: (a) dynamic vs static NVFP4 export produce matching weight/scales; (b) lazy fill from weights when `_amax` is cleared. (b) is dead code now, agreed, but (a) is still a useful invariant and isn't covered elsewhere as far as I can see. Previous review suggested keeping a trimmed version that drops the `delattr` setup but still asserts (a). Not addressed.

Net direction is good (-200 lines, simpler invariant), but please confirm test coverage for AWQ dead experts and decide whether to (1) add an explicit error in `get_weight_scaling_factor[_2]` and (2) keep a trimmed dynamic-vs-static equivalence test before approving.
238546a to e10bfce
    # For dead experts, bootstrap reads max(|weight|). Sanity-check it matches
    # the actual weight tensor's per-row max (axis=0 reduces over hidden_dim).
    # For dead experts, ``_amax`` comes purely from ``weight_only_quantize``
Curious, what is a dead expert?
A dead expert is one that no calibration data passes through. Its `weight_quantizer` is therefore never run, so the `weight_quantizer` ends up with no `_amax`.
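A toy illustration of that answer, in plain Python rather than ModelOpt code: the `route` function, the four-expert layout, and the `calibrate` helper are all assumptions made for the example. It shows how a forward-only calibration leaves unrouted experts without an amax, and how a weight-only pass over every expert closes the gap.

```python
def route(token, num_experts):
    """Toy top-1 router (an assumption): pick an expert from the token value."""
    return token % num_experts


def calibrate(expert_weights, tokens, weight_only_first):
    """Return {expert_index: amax} for experts whose weight quantizer ran."""
    amax = {}
    if weight_only_first:
        # Weight-only pass: visit every expert, no calibration data needed.
        for idx, weight in enumerate(expert_weights):
            amax[idx] = max(abs(w) for w in weight)
    for token in tokens:
        # Forward pass: only the routed expert's weight quantizer runs.
        idx = route(token, len(expert_weights))
        amax[idx] = max(abs(w) for w in expert_weights[idx])
    return amax


expert_weights = [[0.1, -0.4], [0.9, 0.2], [-1.5, 0.3], [0.7, -0.6]]
tokens = [0, 4, 1, 5]  # these tokens only ever reach experts 0 and 1

forward_only = calibrate(expert_weights, tokens, weight_only_first=False)
with_weight_pass = calibrate(expert_weights, tokens, weight_only_first=True)

dead = sorted(set(range(len(expert_weights))) - set(forward_only))
print(dead)                      # experts that saw no calibration data
print(sorted(with_weight_pass))  # experts that have an amax after the weight-only pass
```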
e10bfce to 3b89895
🧹 Nitpick comments (1)
tests/unit/torch/quantization/test_calib.py (1)
419-420: ⚡ Quick win: Avoid `.item()` in tensor assertions for torch test paths.

Use tensor-native comparison to avoid Python scalar extraction in this path.

Suggested change

    - assert dead_q._amax.abs().max().item() == pytest.approx(
    -     original_uncalibrated_weight.abs().max().item(), abs=1e-6
    - )
    + torch.testing.assert_close(
    +     dead_q._amax.abs().max(),
    +     original_uncalibrated_weight.abs().max(),
    +     atol=1e-6,
    +     rtol=0.0,
    + )

As per coding guidelines `**/torch/**/*.py`: "Keep tensor work on the GPU and avoid unnecessary CPU-GPU syncs; avoid Python scalar extraction operators like `tensor.item()`, `float(tensor)`, or `min(tensor)`".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/quantization/test_calib.py` around lines 419 - 420, Replace the scalar .item()-based assertion with a tensor-native comparison to avoid CPU/GPU sync: compare the two tensors dead_q._amax.abs().max() and original_uncalibrated_weight.abs().max() directly using a tensor-aware equality check (e.g. torch.allclose or torch.testing.assert_close) with atol=1e-6 so the test keeps tensors on device and does not extract Python scalars.
📒 Files selected for processing (4)
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/export/test_export_weight_gpu.py
- tests/unit/torch/quantization/plugins/test_fused_experts.py
- tests/unit/torch/quantization/test_calib.py
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/quantization/model_calib.py
- tests/unit/torch/quantization/plugins/test_fused_experts.py
max_calibrate now always runs weight_only_quantize before the optional
forward_loop, so every weight quantizer gets _amax regardless of MoE
routing. Weight quantizers disabled by the caller (e.g. awq_lite, which
runs max_calibrate with weight quantizers disabled) are skipped by
weight_only_quantize, so the AWQ dynamic-amax path is unaffected.
With _amax guaranteed after calibration, remove two now-redundant
band-aids:
- _bootstrap_uncalibrated_weight_quantizers (re-ran weight calibration
for experts skipped by partial MoE routing); superseded by the
always-on weight_only_quantize.
- _ensure_weight_quantizer_calibrated and its helpers in export (lazy
weight calibration at scale-factor extraction time), plus the GPU
test that only exercised that lazy path.
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
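The commit message's point about disabled weight quantizers can be sketched with a small, self-contained toy. Nothing here is ModelOpt API: `ToyQuantizer`, the `disabled` context manager, and `weight_only_pass` are hypothetical, and the sketch only assumes that the weight-only pass skips any quantizer the caller has disabled.

```python
from contextlib import contextmanager


class ToyQuantizer:
    """Hypothetical quantizer with an enable flag and an optional amax."""

    def __init__(self):
        self.enabled = True
        self.amax = None


@contextmanager
def disabled(quantizers):
    """Temporarily disable quantizers, restoring their prior state on exit."""
    previous = [q.enabled for q in quantizers]
    for q in quantizers:
        q.enabled = False
    try:
        yield
    finally:
        for q, was_enabled in zip(quantizers, previous):
            q.enabled = was_enabled


def weight_only_pass(weights, quantizers):
    """Populate amax from the weight for every enabled quantizer."""
    for weight, q in zip(weights, quantizers):
        if q.enabled:
            q.amax = max(abs(w) for w in weight)


weights = [[0.5, -2.0], [1.0, 3.0]]
quantizers = [ToyQuantizer(), ToyQuantizer()]

with disabled(quantizers):
    weight_only_pass(weights, quantizers)  # skipped: nothing is enabled
skipped = [q.amax for q in quantizers]

weight_only_pass(weights, quantizers)  # enabled again: amax is populated
populated = [q.amax for q in quantizers]

print(skipped)
print(populated)
```

The design choice illustrated is that the caller, not the calibration routine, decides which quantizers take part, so a flow that wants dynamic amax can opt out without a special case inside the weight-only pass.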
3b89895 to 311f3a5
🧹 Nitpick comments (3)
modelopt/torch/quantization/model_calib.py (3)
109-110: 💤 Low value: Consider inlining this trivial helper.

`_collect_weight_stats` wraps a single call to `quantizer(weight)` and is used only once (line 543). Inlining `partial(TensorQuantizer.__call__, weight_quantizer, weight)` or a lambda directly at the call site would reduce indirection without losing clarity.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/quantization/model_calib.py` around lines 109 - 110, The helper function _collect_weight_stats is a trivial wrapper that only calls quantizer(weight); remove the _collect_weight_stats definition and replace its single use with an inline call (e.g., a lambda or partial of TensorQuantizer.__call__ bound to weight_quantizer and weight) at the call site where weight_quantizer and weight are available so there is no indirection; ensure you update any imports/usages referencing _collect_weight_stats to use the inline callable (e.g., lambda w: weight_quantizer(w) or partial(TensorQuantizer.__call__, weight_quantizer, weight)).
183-183: 💤 Low value: Parameter name `model` is misleading.

This function is invoked at line 542 with a `TensorQuantizer`, not a full model. Renaming the parameter to `module` would better reflect that it accepts any `nn.Module` (including individual quantizers), improving readability.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/quantization/model_calib.py` at line 183, The parameter name in _run_and_load_max_stats is misleading: rename the parameter from model to module in the function signature, update its type hint (keep as nn.Module) and all internal references, and update every call site (including the call that passes a TensorQuantizer) to use the new name; also update any docstring/comments referencing the old parameter name to reflect that this accepts any nn.Module (e.g., individual quantizers).
1263-1270: ⚡ Quick win: Clarify the double `finish_stats_collection` pattern.

`max_calibrate` internally calls `finish_stats_collection` (line 277), and line 1270 calls it again. The intent appears to be processing input quantizers inside the context (while weight quantizers are disabled), then processing weight quantizers outside the context, but this is subtle and undocumented. Consider adding a brief comment explaining why `finish_stats_collection` is invoked twice and what each call processes.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/quantization/model_calib.py` around lines 1263 - 1270, The code calls finish_stats_collection twice: once indirectly via max_calibrate (which performs a finish_stats_collection for the quantizers active inside the set_quantizer_by_cfg_context where weight quantizers are disabled) and then again explicitly after exiting the context to finalize the weight quantizers; add a short clarifying comment above this block (referencing set_quantizer_by_cfg_context, max_calibrate, and finish_stats_collection) that states the first finish_stats_collection handles input-quantizer stats and distributed sync while weight quantizers are disabled, and the second call finalizes weight-quantizer stats after the context is exited.
📒 Files selected for processing (3)
- modelopt/torch/export/quant_utils.py
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/export/test_export_weight_gpu.py
💤 Files with no reviewable changes (1)
- modelopt/torch/export/quant_utils.py
What does this PR do?
Type of change: refactor
Usage
Same as before.
Testing
Ran unit tests locally.
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information