[6078291][OMNIML-3716] Add ViT FP8/NVFP4 recipes + Torch-TRT example, wire softmax_quantizer in _QuantAttention#1569
Conversation
…_quantizer in _QuantAttention
* modelopt_recipes/huggingface/vit/ptq/{fp8,nvfp4}.yaml -- self-contained
ViT-tuned PTQ recipes targeting HuggingFace ViTForImageClassification.
Encoder Linear weights/inputs quantized; attention Q/K/V BMMs, softmax,
and per-block LayerNorm outputs at FP8; patch-embed nn.Conv2d, classifier,
and the final vit.layernorm left FP16. NVFP4 variant runs encoder Linears
in W4A4 NVFP4 (E2M1, block 16, FP8 scales) with AWQ-lite calibration.
* examples/torch_trt/ -- end-to-end Torch-TensorRT deployment example
(load HF model -> calibrate from tiny-imagenet -> mtq.quantize ->
torch_tensorrt.compile(ir="dynamo") -> benchmark). Defaults to
google/vit-large-patch16-224; --model_id + --recipe retarget any
HF model + ModelOpt PTQ recipe.
* modelopt/torch/quantization/plugins/huggingface.py -- inside
_QuantAttention._quantized_attention, the non-kitchen branch now
temporarily replaces torch.nn.functional.softmax with a wrapper that
pipes the softmax output through self.softmax_quantizer. Previously
the slot was created on every registered attention class but only
consumed by the optional Kitchen MXFP8 path, so FP8 / NVFP4 recipes
that enabled *softmax_quantizer saw it stay uncalibrated (amax=None)
and emitted no Q/DQ around softmax during ONNX / Torch-TRT export.
Short-circuits to the unwrapped call when the quantizer is disabled
(zero-overhead). SDPA-fused softmax inside the C++ kernel is unaffected.
ImageNet-1k full-50k validation accuracy on google/vit-base-patch16-224
(batch=128, 49920/50000 samples):
FP16 baseline: Top-1 81.769% Top-5 96.124%
FP8 modelopt.onnx CLI: Top-1 81.707% Top-5 96.110% (-0.062 pp)
FP8 torch path (this PR): Top-1 81.637% Top-5 96.140% (-0.132 pp)
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
…omparison * examples/torch_trt/quantize_and_compile_vit.py -> torch_tensorrt_ptq.py * Drop the latency / speedup benchmarking comparison from the script and README; the script now only verifies that the compiled-model argmax matches the fake-quant argmax on a sample input. Accuracy comparison belongs in a separate harness, not in a "quantize + compile" example. Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
📝 WalkthroughWalkthroughThis PR introduces an end-to-end Torch-TensorRT deployment example for quantized ViT models, including a bug fix for softmax quantization in HuggingFace attention, two FP8/NVFP4 PTQ recipes, the example script with calibration and compilation, documentation, and integration tests. ChangesViT PTQ Torch-TensorRT Deployment
🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
|
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1569 +/- ##
==========================================
+ Coverage 69.43% 74.56% +5.13%
==========================================
Files 477 478 +1
Lines 51977 55404 +3427
==========================================
+ Hits 36090 41312 +5222
+ Misses 15887 14092 -1795
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
* tests/examples/torch_trt/test_torch_tensorrt_ptq.py -- mirrors the
tests/examples/torch_onnx/test_torch_quant_to_onnx.py pattern: invokes
the example via run_example_command, parametrizes over (fp8, nvfp4),
uses a 1-layer ViT config (--no_pretrained + --model_kwargs) so the
test completes in ~30 s per parametrized case. Two variants:
- test_torch_tensorrt_ptq[precision] -- full e2e through
torch_tensorrt.compile (importorskip on torch_tensorrt).
- test_torch_tensorrt_ptq_skip_trt[precision] -- quantize-only
smoke test, useful on hosts without torch_tensorrt installed.
* examples/torch_trt/torch_tensorrt_ptq.py:
- Add --no_pretrained + --model_kwargs flags (mirroring torch_onnx)
so the same script doubles as the test entry point.
- Force aten.cat.default into PyTorch fallback inside
compile_with_torch_tensorrt -- torch_tensorrt 2.10's cat converter
chokes on the HF ViT cls-token + patch-embedding concat (BF16:
"Got unsupported ScalarType BFloat16"; FP16: rank-(-1) TRT tensor
that crashes the downstream `embeddings + position_embeddings`
add). The cat is a tiny [1,1,H] + [1,N,H] op that runs once per
forward, so PyTorch fallback costs essentially nothing.
Verified locally: pytest tests/examples/torch_trt/test_torch_tensorrt_ptq.py
-> 4 passed in 103 s on RTX 6000 Ada.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Drops `test_torch_tensorrt_ptq_skip_trt` -- the full `test_torch_tensorrt_ptq` variant already exercises the same mtq.quantize path and goes further (torch_tensorrt.compile). The skip-variant added duplicate CI runtime without unique coverage. Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
There was a problem hiding this comment.
🧹 Nitpick comments (1)
examples/torch_trt/torch_tensorrt_ptq.py (1)
277-294: ⚡ Quick winArgmax "match" results are informational only — the example never fails on a mismatch.
fq_match/trt_matchare computed and printed but never used to exit non-zero. The e2e test docstring (test_torch_tensorrt_ptq.py, Lines 44-46) claims the CLI "exits non-zero ... if the compiled-model argmax doesn't match the fake-quant argmax", but as written the test only validates that the pipeline runs to completion. Either enforce the check here or correct the test docstring.Note: also worth deciding intent — the comparison is against
baseline_pred(BF16), while the docstring/README phrase it as matching the fake-quant argmax. If you do enforce, be cautious: a tiny--no_pretrainedmodel under NVFP4 can legitimately flip argmax on random input, so a hard gate may be flaky for that path.♻️ One option: gate the run on mismatch
trt_match = (trt_pred == baseline_pred).all().item() print(f"TRT argmax class: {trt_pred.tolist()} (matches baseline: {trt_match})") + if not trt_match: + raise SystemExit("Torch-TensorRT argmax did not match the BF16 baseline.")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/torch_trt/torch_tensorrt_ptq.py` around lines 277 - 294, The script currently computes fq_match and trt_match but doesn't fail on mismatch; update the run to enforce non-zero exit when mismatches occur: after computing fq_match/trt_match, if fq_match is False or trt_match is False, log an error with context (include fq_pred, trt_pred, baseline_pred) and call sys.exit(1). Locate the checks around fq_pred/baseline_pred and trt_pred (symbols: fq_pred, fq_match, trt_pred, trt_match, baseline_pred, ViTLogitsWrapper, compile_with_torch_tensorrt) and add the exit-on-mismatch logic (or, if you prefer the other approach, instead update the test/docstring to accurately state that mismatches are informational rather than failing).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@examples/torch_trt/torch_tensorrt_ptq.py`:
- Around line 277-294: The script currently computes fq_match and trt_match but
doesn't fail on mismatch; update the run to enforce non-zero exit when
mismatches occur: after computing fq_match/trt_match, if fq_match is False or
trt_match is False, log an error with context (include fq_pred, trt_pred,
baseline_pred) and call sys.exit(1). Locate the checks around
fq_pred/baseline_pred and trt_pred (symbols: fq_pred, fq_match, trt_pred,
trt_match, baseline_pred, ViTLogitsWrapper, compile_with_torch_tensorrt) and add
the exit-on-mismatch logic (or, if you prefer the other approach, instead update
the test/docstring to accurately state that mismatches are informational rather
than failing).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: aea38a8d-a5f8-44cf-8b02-9e7a700116ba
📒 Files selected for processing (8)
CHANGELOG.rstexamples/torch_trt/README.mdexamples/torch_trt/requirements.txtexamples/torch_trt/torch_tensorrt_ptq.pymodelopt/torch/quantization/plugins/huggingface.pymodelopt_recipes/huggingface/vit/ptq/fp8.yamlmodelopt_recipes/huggingface/vit/ptq/nvfp4.yamltests/examples/torch_trt/test_torch_tensorrt_ptq.py
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Bug-fix portion has a regression risk on hosts that have the optional kitchen package installed: _init_kitchen_attn_fn runs before the new non-kitchen FP8 softmax path and raises NotImplementedError for any non-MXFP8 softmax quantizer. The new ViT FP8/NVFP4 recipes both enable plain E4M3 FP8 on *softmax_quantizer, so on a kitchen-equipped GPU they will fail before the new patched-softmax code path is reached. The new test uses importorskip("torch_tensorrt") but does not skip when kitchen is also installed, so this gap isn't caught in CI. Otherwise the design is reasonable (this is "new example + tuned recipes + 11-line plugin fix" rather than a new subsystem; the deterministic complexity gate fired only because of directory count) and the recipes/example look correct.
| key_states = self.k_bmm_quantizer(key_states) | ||
| value_states = self.v_bmm_quantizer(value_states) | ||
| if not self.use_kitchen: | ||
| if self.softmax_quantizer.is_enabled: |
There was a problem hiding this comment.
Bot comment.
This new path is unreachable when kitchen is installed and softmax_quantizer.is_enabled is True with FP8 (E4M3) — i.e. exactly the configuration the two new ViT recipes set up. A few lines above:
if kitchen is not None and self.kitchen_attn_fn is None:
self._init_kitchen_attn_fn()and _init_kitchen_attn_fn does:
if self.softmax_quantizer.is_mxfp(8):
...
else:
raise NotImplementedError(f"softmax_quantizer not supported: {self.softmax_quantizer}")So on a host that has kitchen installed, calibrating the new huggingface/vit/ptq/fp8.yaml (or nvfp4.yaml, which also enables FP8 on *softmax_quantizer) raises NotImplementedError before this branch ever runs. The FP8/NVFP4 recipes need the kitchen path to fall through to the new wrapper-softmax path when the quantizer isn't MXFP8 — e.g. drop the raise and only set use_kitchen=True for the MXFP8 case, or skip _init_kitchen_attn_fn entirely when the quantizer is not MXFP8. As written, the new path is gated on not self.use_kitchen which is always False whenever kitchen init succeeded.
| } | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("precision", _PRECISIONS) |
There was a problem hiding this comment.
Bot comment.
Only torch_tensorrt is importorskip'd, but the new fix has a hidden coupling to whether kitchen is importable (see comment on huggingface.py). On a CI host that has both torch_tensorrt and kitchen installed, this test will fail at mtq.quantize time with NotImplementedError: softmax_quantizer not supported: ... for both parametrized cases. Worth either fixing the underlying kitchen-vs-FP8 ordering bug or adding if kitchen is not None: pytest.skip(...) so the regression mode is at least exercised in isolation.
| arg_inputs=[example_input], | ||
| min_block_size=1, | ||
| truncate_double=True, | ||
| torch_executed_ops={torch.ops.aten.cat.default}, |
There was a problem hiding this comment.
Bot comment.
Minor: this import torch_tensorrt inside compile_with_torch_tensorrt is a local import without an explanatory comment. It's defensible here (the --skip_trt flag exists specifically so users without torch_tensorrt can still run quantize-only), but a one-line comment noting that — or moving it next to --skip_trt handling — would make the intent clearer.
| # From the NVIDIA TensorRT docker image (recommended): | ||
| docker run --gpus all -it --rm -v $(pwd):/workspace -w /workspace nvcr.io/nvidia/tensorrt:26.02-py3 bash | ||
|
|
||
| pip install -U "nvidia-modelopt[torch]" |
There was a problem hiding this comment.
| pip install -U "nvidia-modelopt[torch]" | |
| pip install -U "nvidia-modelopt" |
| @@ -0,0 +1,3 @@ | |||
| datasets>=2.14.4 | |||
| torch-tensorrt>=2.4.0 | |||
| transformers>=4.40 | |||
There was a problem hiding this comment.
| transformers>=4.40 | |
| transformers>=4.56 |
There was a problem hiding this comment.
directory missing in https://github.com/NVIDIA/Model-Optimizer/blob/main/.github/CODEOWNERS
There was a problem hiding this comment.
Needs to be added to https://github.com/NVIDIA/Model-Optimizer/blob/main/.github/workflows/example_tests.yml to enable in cicd
What does this PR do?
Type of change: new feature + bug fix
Adds a Torch-TensorRT deployment path for HuggingFace ViT and closes the
modelopt-side gap that prevented
*softmax_quantizerfrom being appliedon the standard attention forward path.
New ViT PTQ recipes under
modelopt_recipes/huggingface/vit/ptq/:fp8.yaml— W8A8 per-tensor FP8 E4M3 on encoder Linear weights/inputs;attention Q/K/V BMMs + softmax output at FP8; per-block LayerNorm output
at FP8 (one shared Q/DQ feeds Q/K/V + MLP); patch-embed
nn.Conv2d,classifier, and the finalvit.layernormleft FP16. Uses maxcalibration.
nvfp4.yaml— same skip list as FP8; encoder Linear weights/inputs runNVFP4 W4A4 (E2M1, block 16, FP8 E4M3 per-block scales). Attention BMMs,
softmax, and per-block LayerNorm outputs stay at FP8 (NVFP4 too
aggressive on those narrow distributions). Uses AWQ-lite calibration.
$importof shared snippets) anduse the "specific-enable" style: narrow
parent_class+ path scopingon every enable rule, so no
enable: falsecarve-outs are needed.New example under
examples/torch_trt/:torch_tensorrt_ptq.py— single-model pipeline (load HF model,calibrate from
zh-plus/tiny-imagenet,mtq.quantize,torch_tensorrt.compile, verify the compiled-model argmax matches thefake-quant argmax). Defaults to
google/vit-large-patch16-224; pass--model_idand--recipeto target any model + recipe combination.--no_pretrained+--model_kwargsshrink the model for fast tests.README.mddocumenting the flow, the shipped recipes, hardwarerequirements, and CLI usage.
requirements.txt.Bug fix in
modelopt/torch/quantization/plugins/huggingface.py— inside_QuantAttention._quantized_attention, the non-kitchen branch nowtemporarily replaces
torch.nn.functional.softmax(via the existingreplace_functioncontext manager) with a wrapper that pipes the softmaxoutput through
self.softmax_quantizer. Previously the slot was createdon every registered attention class but only consumed by the optional
Kitchen MXFP8 flash-attention path, so FP8 / NVFP4 recipes that enabled
*softmax_quantizersaw it stay uncalibrated (amax=None) and emittedno Q/DQ around the softmax output during ONNX / Torch-TRT export. With
this fix the
softmax_quantizeris calibrated alongside the rest ofthe model, and both the modelopt ONNX exporter and
torch_tensorrt.compilepick up the Q/DQ pair. The patch short-circuits to the unwrapped call
when the quantizer is disabled (zero-overhead) and has no effect on SDPA
paths that fuse softmax inside a C++ kernel.
New e2e integration test at
tests/examples/torch_trt/test_torch_tensorrt_ptq.py— mirrors thetorch_onnxtest pattern: invokes the example throughrun_example_command, parametrizes over the two precision modes (fp8,nvfp4), uses a 1-layer ViT config (
--no_pretrained+--model_kwargs)so each parametrized case completes in under a minute.
importorskipontorch_tensorrtso the test is automatically skipped on hosts withoutthe package.
Usage
Testing
Recipes load via
modelopt.recipe.load_recipe()and passQuantizeConfigschema validation.Run
pytest tests/examples/torch_trt/test_torch_tensorrt_ptq.py→2 parametrized cases pass on RTX 6000 Ada (fp8 / nvfp4).
End-to-end on
google/vit-base-patch16-224:mtq.quantizewith the newFP8 recipe followed by
torch_tensorrt.compile(ir="dynamo")produces aTRT engine whose argmax matches the FP16 baseline.
ONNX exported from the torch path now contains Q/DQ on 12 / 12
softmax outputs (was 0 / 12 before this PR's
_QuantAttentionfix),matching the ONNX-CLI output's quantization layout.
ImageNet-1k validation accuracy on the full 49920 / 50000 samples
(batch=128) for
google/vit-base-patch16-224:Both FP8 paths land within 0.13 pp Top-1 of the FP16 baseline; Top-5 is
within 0.02 pp across all three.
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: N/Atests/examples/torch_trt/.Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests