[None][feat] Wan 2.2: fuse NVFP4 quantization with preceding LayerNorm/AdaLN and GELU-tanh#14773
[None][feat] Wan 2.2: fuse NVFP4 quantization with preceding LayerNorm/AdaLN and GELU-tanh#14773anikaj-eng wants to merge 12 commits into
Conversation
📝 WalkthroughWalkthroughThis PR implements NVFP4 (NV floating-point 4) quantization fusion for layer normalization and activation functions in TensorRT-LLM's Wan visual generation model. It adds two fused CUDA kernels (LayerNorm with optional AdaLN and GELU-tanh) combined with FP4 quantization, exposes them as PyTorch operators, integrates them into model modules with environment-gated activation, and provides comprehensive test coverage. ChangesNVFP4 Fused Kernels and Integration
Sequence Diagram(s)sequenceDiagram
participant User
participant LayerNorm
participant FusedOp
participant WanBlock
User->>WanBlock: forward(x, scale_msa, shift_msa, seq_len_per_batch)
WanBlock->>LayerNorm: forward(x, scale_msa, shift_msa, seq_len_per_batch)
alt nvfp4_scale present and is_nvfp4=True and SM100+
LayerNorm->>LayerNorm: Reshape to 2D
LayerNorm->>FusedOp: torch.ops.trtllm.fused_layernorm_quantize(input, ln_weight, ln_bias, scale_msa, shift_msa, sf_scale, ...)
FusedOp->>LayerNorm: (output_fp4, output_sf)
LayerNorm->>LayerNorm: Wrap in Fp4QuantizedTensor
LayerNorm->>WanBlock: Fp4QuantizedTensor
else Unfused path
LayerNorm->>LayerNorm: Standard LN with optional FP32 modulation
LayerNorm->>WanBlock: torch.Tensor
end
WanBlock->>User: output
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related issues
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 8
🧹 Nitpick comments (1)
tests/unittest/_torch/modules/test_fused_activation_quant.py (1)
282-299: ⚡ Quick winCoverage is still insufficient for the scale-factor output.
This only checks
sf_fused.shape, but downstream GEMM also consumes the swizzled scale tensor. A badsf_fusedwould still pass here as long as the packed FP4 bytes happen to match, so please comparesf_fusedagainstsf_separatetoo.✅ Small coverage addition
assert fp4_fused.shape == (m, n // 2) assert sf_fused.shape == sf_separate.shape + assert torch.equal(sf_fused, sf_separate) match_rate = (fp4_fused == fp4_separate).float().mean().item()As per coding guidelines,
tests/**: Act as a QA engineer reviewing test changes and coverage for TensorRT-LLM, and say whether coverage is sufficient or insufficient with actionable follow-up.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/test_fused_activation_quant.py` around lines 282 - 299, The test currently only compares FP4 outputs but not the swizzled scale tensors; update the assertions after calling torch.ops.trtllm.fp4_quantize and torch.ops.trtllm.fused_gelu_tanh_quantize to also validate sf_fused against sf_separate (e.g., assert exact equality or use torch.allclose for float tensors) so the swizzled scale-factor output from fused_gelu_tanh_quantize is verified; reference the variables sf_fused and sf_separate (and the calls fused_gelu_tanh_quantize / fp4_quantize) and add a concise assertion checking their values/shape match.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cu`:
- Around line 360-370: The launcher validates shapes/flags but not required
tensor pointers, which can let specialized kernels (e.g., the
fusedLayerNormQuant paths selected by params.has_modulation or
params.has_ln_affine) dereference null and crash; update the dispatch
preconditions to CHECK that the input/output/scale/bias/modulation pointers
needed for each mode are non-null before launching: for params.has_modulation
ensure modulation pointer and any dependent input/output pointers are present
and for has_ln_affine ensure ln affine scale/bias pointers exist, keeping the
existing checks on params.N (kN_HARDCODED), params.seq_len_per_batch, and
params.M % params.seq_len_per_batch as-is.
In `@cpp/tensorrt_llm/thop/fusedActivationQuant.cpp`:
- Line 87: The code calls CHECK_INPUT(sf_scale, torch::kFloat32) but then reads
sf_scale[0] without verifying it has exactly one element; add a cardinality
check to ensure sf_scale.numel() == 1 and emit a clear error if not (e.g., using
CHECK or TORCH_CHECK) before any access. Update both places that consume
sf_scale[0] in fusedActivationQuant.cpp (the CHECK_INPUT(sf_scale,
torch::kFloat32) site and the other occurrence) to validate single-element
tensors and keep the existing dtype check.
In `@cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp`:
- Around line 61-62: The comment promises v1 supports only N == 5120 but the
runtime currently only checks N % 16 == 0; update the runtime validation to
enforce N == 5120 for the v1 path: locate the validation that uses TORCH_CHECK
(the divisibility-by-16 check) in the fusedLayerNormQuant v1 binding and change
it to a strict TORCH_CHECK(N == 5120, "...") or add an additional TORCH_CHECK
when version == 1 so unsupported hidden sizes cannot be routed into the fused
kernel; apply the same change for the second occurrence mentioned around the
other check.
- Line 127: The code caches device-dependent value in a function-static variable
`multiProcessorCount` by calling
`tensorrt_llm::common::getMultiProcessorCount()` once; this can be stale for
other CUDA devices—replace the static with a per-call lookup (i.e., call
`getMultiProcessorCount()` each time you need it) or implement a device-keyed
cache keyed by current device id; locate the `static int const
multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();`
declaration in fusedLayerNormQuant.cpp and change it to compute/retrieve the
value at runtime rather than as a function-static constant.
In `@tensorrt_llm/_torch/modules/layer_norm.py`:
- Around line 154-156: The fallback path must normalize and validate AdaLN
modulation tensors so unfused behavior matches the fused contract: before
calling _forward_unfused (and in the three sites that currently call it), check
that scale_msa and shift_msa are treated as a pair (either both None or both
provided); if one is missing treat as absent for both, and if provided
convert/expand 1D per-batch vectors into 2D [B, N] modulation tensors using
seq_len_per_batch so shapes match the fused expectation used by
_forward_nvfp4_fused; add a small helper or inline logic to validate
broadcastability and to reshape/repeat using seq_len_per_batch (so
scale_msa/shift_msa become explicitly shaped [B, N] or set to None) and then
pass those normalized tensors to _forward_unfused to ensure identical semantics
across fused and unfused paths.
In `@tensorrt_llm/_torch/modules/mlp.py`:
- Around line 140-148: The fused activation+quant helpers
(_fused_gelu_tanh_quant and _fused_relu2_quant) currently return a flattened 2D
packed FP4 tensor, causing downstream NVFP4LinearMethod.apply to lose
batch/sequence dims; update both helpers to reshape the packed output back to
the original prefix dims before returning by reshaping to (*x.shape[:-1],
x.shape[-1] // 2) (i.e., preserve all dims except the last which becomes half
size for packed FP4) so NVFP4LinearMethod.apply sees fp4_tensor.rank > 2 and
restores shape correctly.
In `@tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py`:
- Around line 108-111: The helper _has_sm100 currently swallows all exceptions
when calling get_sm_version() via "except Exception" — narrow this to the
specific errors get_sm_version can raise (e.g., RuntimeError or OSError) so
unrelated bugs aren't hidden; replace "except Exception" with a targeted except
(for example "except RuntimeError:" or "except (RuntimeError, OSError):") and
keep the existing return False behavior so only expected failures are handled.
- Around line 48-53: The two top-level os.environ.setdefault calls for
"TLLM_DISABLE_MPI" and "TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION" must be moved
into an autouse module-scoped pytest fixture (e.g., define a fixture like
"autouse_env_toggles" with scope='module' and autouse=True) so they don't leak
across tests; implement the fixture to save original values, set the env vars
(prefer using pytest's monkeypatch.setenv or os.environ assignment), yield, and
then restore the originals after yield (or rely on monkeypatch undo), and remove
the import-time os.environ.setdefault lines so only the fixture controls these
environment toggles for tests in this module.
---
Nitpick comments:
In `@tests/unittest/_torch/modules/test_fused_activation_quant.py`:
- Around line 282-299: The test currently only compares FP4 outputs but not the
swizzled scale tensors; update the assertions after calling
torch.ops.trtllm.fp4_quantize and torch.ops.trtllm.fused_gelu_tanh_quantize to
also validate sf_fused against sf_separate (e.g., assert exact equality or use
torch.allclose for float tensors) so the swizzled scale-factor output from
fused_gelu_tanh_quantize is verified; reference the variables sf_fused and
sf_separate (and the calls fused_gelu_tanh_quantize / fp4_quantize) and add a
concise assertion checking their values/shape match.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: fd443c38-6f92-47f2-a7b9-0f73c4c54bb8
📒 Files selected for processing (20)
cpp/tensorrt_llm/CMakeLists.txtcpp/tensorrt_llm/kernels/CMakeLists.txtcpp/tensorrt_llm/kernels/fusedActivationQuant.cucpp/tensorrt_llm/kernels/fusedActivationQuant.hcpp/tensorrt_llm/kernels/fusedLayerNormQuant/CMakeLists.txtcpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cucpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cuhcpp/tensorrt_llm/thop/CMakeLists.txtcpp/tensorrt_llm/thop/fusedActivationQuant.cppcpp/tensorrt_llm/thop/fusedLayerNormQuant.cpptensorrt_llm/_torch/custom_ops/cpp_custom_ops.pytensorrt_llm/_torch/modules/layer_norm.pytensorrt_llm/_torch/modules/linear.pytensorrt_llm/_torch/modules/mlp.pytensorrt_llm/_torch/utils.pytensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.pytensorrt_llm/_torch/visual_gen/modules/attention.pytests/unittest/_torch/modules/test_fused_activation_quant.pytests/unittest/_torch/modules/test_fused_layernorm_quant.pytests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py
luyiyun1021
left a comment
There was a problem hiding this comment.
Comments from a structured review pass (CUDA kernels + shared-module impact). All non-blocking — the implementation is careful and the fused path defaults OFF — but a couple are worth addressing before merge.
Inline: 2 Concerns — unconditional NVTX on the shared LayerNorm default path (affects all non-Wan models), and a temb.ndim==4 (per-token-timestep) hard-crash that's untested — plus several nits (dead multiProcessorCount, gelu-kernel DRY vs relu2, a stale FP8 rationale in the mlp gate comment, a wrong modeling_llama.py line citation, a personal scratch path in a test, and a fused-vs-unfused modulation-precision delta).
PR-level (no single line): the headline 1.176× speedup has no committed benchmark or repro script — only the description table and the nsys A/B recipe in the _ln_quant_type docstring. For a perf-motivated PR, consider adding a small reproducible benchmark (or stating it's a one-off local nsys run with the exact config) so reviewers can verify the number.
| shift_msa, seq_len_per_batch, | ||
| nvfp4_scale) | ||
|
|
||
| with nvtx_range("LN unfused", color="grey"): |
There was a problem hiding this comment.
[Concern · perf / shared hot-path] nvtx_range() resolves to nvtx.annotate(), which emits an NVTX push/pop on every call even when no profiler is attached (tensorrt_llm/_utils.py:953-972). This forward is the shared LayerNorm used by StarCoder2, Cohere2, Qwen3VL, Kimi-K2.5, DSA, Flux and Flux2 — none of which set quantize_type, so they all take this default branch and pay the cost (2× per layer × every decode step). Sibling modules (rms_norm.py, linear.py) carry no always-on nvtx_range, and there's a Dynamo graph-break risk if a caller ever torch.compiles through this forward.
Suggest gating it: use the env-gated nvtx_range_debug(...) (TLLM_NVTX_DEBUG / TLLM_LLMAPI_ENABLE_NVTX), or only wrap when self.is_nvfp4, so the default path stays byte-identical to baseline. Same for the "LN+NVFP4 fused" label.
| ln_b = None | ||
| # Validate the row-to-batch ratio so the kernel's | ||
| # batch_idx = row / seq_len_per_batch indexing is correct. | ||
| if hs_2d.shape[0] != s_2d.shape[0] * seq_len_per_batch: |
There was a problem hiding this comment.
[Concern · robustness / coverage] This guard hard-raises for the per-token-timestep case. When temb.ndim==4 (per-patch timestep, e.g. Wan 2.2 TI2V-5B), scale_msa is [B, S, N], so s_2d = scale_msa.reshape(-1, n) → [B*S, N] while seq_len_per_batch = S; then hs_2d.shape[0] (= B*S) vs s_2d.shape[0] * seq_len_per_batch (= B*S*S) raises ValueError for any S > 1. is_nvfp4 is set purely from quant_config + SM (not from N or timestep rank), so a static-NVFP4 model driving a 2-D timestep crashes instead of falling back. It's currently unreachable (the 2-D-timestep 5B model is hidden_size ≠ 5120 and ships no NVFP4 ckpt) and the kill switch defaults OFF, but it's fragile and untested — the tests only use a 1-D timestep=[500.].
Suggest falling back to _forward_unfused when scale_msa.ndim == 3 and scale_msa.shape[1] > 1, or asserting + documenting the "1-D timestep only" V1 limitation, and adding a temb.ndim==4 test.
| # scale_msa / shift_msa typically arrive as [B, 1, N]; reshape to [B, N]. | ||
| # Cast to input dtype (callers often keep modulation in FP32 for the | ||
| # unfused path's precision; the fused kernel reads bf16/fp16). | ||
| s_2d = scale_msa.reshape(-1, n).to(hs_2d.dtype).contiguous() |
There was a problem hiding this comment.
[Nit · numerics] The fused path casts scale_msa/shift_msa to the input bf16/fp16 dtype here, but the unfused fallback (_forward_unfused, ~line 216) applies modulation in fp32. So a model that falls back (non-SM100, or no nvfp4_scale) gets fp32-precision modulation while the fused path gets bf16-precision modulation feeding the FP4 quantizer — a small, undocumented fused-vs-unfused delta. The unit-test reference deliberately downcasts modulation to match the kernel, so the cos ≥ 0.995 check doesn't guard it. Either document that the fused path intentionally uses native-dtype modulation, or keep scale_msa/shift_msa in fp32 in the kernel signature to match the unfused contract.
| // process happens to be on, then subsequent calls from other devices | ||
| // see a stale value (different B200 variants in the same node, MIG | ||
| // partitions, etc.). | ||
| int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); |
There was a problem hiding this comment.
[Nit · dead code] multiProcessorCount is computed here and threaded into invokeFusedLayerNormQuant, but the launcher's parameter is int /*multiProcessorCount*/ (unused) — the grid is a fixed dim3 grid(params.M) with no occupancy tuning. So this value is computed and discarded, and the "re-query SM count every call / static caching is unsafe" comment defends a dead value. Suggest dropping the parameter + this call (and the .cuh extern-template arg), or leaving a one-line TODO if it's kept for future occupancy tuning.
| // back to native precision (bf16/fp16) so absmax / SF / FP4 packing math is | ||
| // byte-for-byte identical to the relu2 kernel and to cvt_warp_fp16_to_fp4. | ||
| template <typename T> | ||
| __global__ void fusedGeluTanhQuantizeKernel(T const* __restrict__ input, float const* __restrict__ sfScale, |
There was a problem hiding this comment.
[Nit · DRY] fusedGeluTanhQuantizeKernel (+ its launcher) is a ~140-line near-verbatim copy of fusedRelu2QuantizeKernel, differing only in the activation call (relu2_f32 → gelu_tanh_f32); the whole NVFP4 epilogue (absmax, XOR-1 shuffle, SF swizzle, packing, padded-column handling) is duplicated, so any future SF-layout/absmax fix must be applied twice. Suggest templating on an activation functor, e.g. template <typename T, typename Act> __global__ void fusedActQuantizeKernel(...) calling Act{}(x), and instantiating both relu2 and gelu_tanh from it.
| # Parameter (linear.py:1295), but a layer that opts into dynamic quant | ||
| # at load time will reset it to None (linear.py:821). The fused kernel | ||
| # needs a real scalar tensor, so guard against the None case explicitly. | ||
| has_scale = getattr(self.down_proj, 'input_scale', None) is not None |
There was a problem hiding this comment.
[Nit · stale comment] The comment above (line 106) justifies this input_scale is not None guard via "a layer that opts into dynamic quant at load time will reset it to None (linear.py:821)" — but linear.py:821 is in the FP8 LinearMethod path, which NVFP4 never executes. For NVFP4, input_scale is always an allocated Parameter (linear.py:1295), so for the layers this gate actually controls (NVFP4 relu2/gelu) the hasattr → getattr is not None change is behavior-preserving — good, no relu2/Nemotron-H regression. Only the cited rationale is wrong for NVFP4. Suggest dropping the linear.py:821 citation, or rephrasing to "for NVFP4 input_scale is always allocated; this is purely defensive against a hypothetical future None-reset.
| silently (no ``nvfp4_scale`` attached -> LayerNorm uses _forward_unfused). | ||
|
|
||
| A/B kill switch: ``TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION`` mirrors the | ||
| flag used by Llama (modeling_llama.py:673). Default is ``"1"`` (fusion |
There was a problem hiding this comment.
[Nit] The env-var read in modeling_llama.py is at ~line 704, not 673 (673 is unrelated config setup). The flag name and default ("1") do match. Hard-coded line citations rot quickly — suggest citing just modeling_llama.py.
| candidates = [ | ||
| os.environ.get("WAN22_T2V_NVFP4_PATH"), | ||
| "/models/Wan2.2-T2V-A14B-Diffusers-NVFP4", | ||
| "/home/scratch.anikaj_libs/trunk_08042025/models/Wan2.2-T2V-A14B-Diffusers-NVFP4", |
There was a problem hiding this comment.
[Nit] This hard-codes a personal developer scratch path (/home/scratch.anikaj_libs/...) as a fallback candidate — it leaks a username into the public repo and is dead for every other contributor / CI runner. The env var WAN22_T2V_NVFP4_PATH and the /models/... bind-mount already cover legitimate resolution (and the skip reason only references those two). Suggest removing this candidate. The repo convention for shared test data is the /home/scratch.trt_llm_data_ci/llm-models/ mount.
071ec6a to
fc3e87a
Compare
|
Thanks @luyiyun1021. I rechecked the current head (
Ready for re-review. |
Add fused GELU-tanh + NVFP4 quantization kernel mirroring the relu2+NVFP4 path from PR NVIDIA#11473, and wire it into the MLP module so Wan 2.2's FFN can run a single fused activation+quant kernel after the up_proj on Blackwell (SM100+). C++ side: - Add fusedGeluTanhQuantizeKernel + invokeFusedGeluTanhQuantize in kernels/fusedActivationQuant.cu/.h - Add torch op binding trtllm::fused_gelu_tanh_quantize in thop/fusedActivationQuant.cpp Python side: - Register fake stub for trtllm::fused_gelu_tanh_quantize - Add gelu_tanh sentinel function in _torch/utils.py for identity-based activation detection in MLP - Gate MLP's _use_fused_gelu_tanh_quant on NVFP4 + non-dynamic quant + GELU-tanh activation; add NVTX ranges around the fused path Tests: - test_fused_gelu_tanh_quantize: kernel-level correctness - test_mlp_uses_fused_gelu_tanh_quant_on_static_nvfp4: integration test ensuring the MLP picks the fused path on a calibrated checkpoint The fused path activates only with a static/calibrated input_scale, so behavior is unchanged for the default dynamic-quantization configs. Signed-off-by: anikaj <anikaj@nvidia.com>
…ation
Add a fused (Ada)LayerNorm + NVFP4 quantization kernel for the three
norm sites in WanBlock (norm1/norm2/norm3) so the per-row mean/var
reduction, modulation, and the FP4 packing run in a single pass.
C++ kernel (cpp/tensorrt_llm/kernels/fusedLayerNormQuant/):
- New static library with a templated fusedLayerNormQuantKernel
parameterized on HAS_LN_AFFINE (norm2 only) and HAS_MODULATION
(norm1/norm3 AdaLN: (1 + scale_msa) * x_hat + shift_msa)
- 2-phase reduction (sum / sum-of-squares) with warp shuffle and
cross-warp shared-memory aggregation; per-block (16-elem) FP4 SF
computed inline via cvt.rn.satfinite.e2m1x2.f32 and written in the
swizzled layout expected by CUTLASS NVFP4 GEMM
- Build wiring: CMakeLists.txt, kernels/CMakeLists.txt
add_subdirectory + EXCLUDE_REGEX, top-level CMake links the new lib
Torch binding:
- thop/fusedLayerNormQuant.cpp exposes trtllm::fused_layernorm_quantize
with optional ln_weight/ln_bias and optional scale_msa/shift_msa
- thop/CMakeLists.txt updated
- cpp_custom_ops.py adds the fake stub with FP4 + SF shape inference
Python:
- layer_norm.py grows a quantize_type='nvfp4' path that routes to the
fused op on SM100+, with scale_msa/shift_msa kwargs and an
NVTX-marked unfused fallback for A/B profiling
- visual_gen/modules/attention.py: relax the ndim==3 assertion so that
Fp4QuantizedTensor inputs (from the fused LN) propagate through
- visual_gen/models/wan/transformer_wan.py:
* _ln_quant_type() helper honors TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION
as a fusion kill switch
* WanBlock.__init__ passes quantize_type=ln_quant_type to all three
norms; forward() rewrites the three norm call sites to pass
scale_msa/shift_msa and seq_len_per_batch
* WanTransformer3DModel.post_load_weights gains _try_attach_nvfp4_scale
which copies the downstream Linear input_scale onto the upstream
LayerNorm so the fused kernel has the calibrated global scale
Tests:
- test_fused_layernorm_quant.py: parametric correctness over plain LN,
affine LN, and AdaLN modulation against an FP32 PyTorch reference,
asserting >=99% byte-for-byte FP4 match on N=5120 (Wan dim).
The fused path activates only when a calibrated input_scale is
available on the downstream Linear (e.g. the ModelOpt-quantized
nvidia/Wan2.2-T2V-A14B-Diffusers-NVFP4 checkpoint), so default
dynamic-quant configs remain unchanged.
Signed-off-by: anikaj <anikaj@nvidia.com>
… hardening * linear.py: NVFP4LinearMethod.apply now flattens an Fp4QuantizedTensor whose fp4_tensor is 3D (B, S, N/2) before the GEMM and unflattens the output. Previously the 3D shortcut only covered plain torch.Tensor inputs, so the Wan 2.2 fused LayerNorm + NVFP4 quant path tripped "mat2 must be a batch of matrices" at the downstream qkv_proj. Matches the contract the plain-tensor path already had. * mlp.py: tighten the fused gelu_tanh+NVFP4 gate from hasattr(down_proj, 'input_scale') to getattr(down_proj, 'input_scale', None) is not None. NVFP4 layers that opt into dynamic quant at load time (linear.py:821) reset input_scale to None; the previous hasattr check would still fire the fused path and read uninitialized memory. The new check also benefits the pre-existing relu2 gate (Nemotron-H) without behavior change for properly-calibrated checkpoints. * test_fused_activation_quant.py: drop the bogus "no calibrated input_scale" case from test_mlp_uses_fused_gelu_tanh_quant_on_static_nvfp4. NVFP4LinearMethod.create_weights (linear.py:1295) always allocates a placeholder input_scale Parameter, so the case was unreachable without monkey-patching and was producing False asserts. The relu2-activation case below already covers "fused gelu_tanh path stays OFF when conditions are not met". Updated comments/docstring to match the new gate semantics (no setattr has_static_input_scale). Signed-off-by: anikaj <anikaj@nvidia.com>
…elOpt ckpt
Loads the ModelOpt-quantized Wan 2.2 T2V 14B checkpoint
(huggingface.co/nvidia/Wan2.2-T2V-A14B-Diffusers-NVFP4) once via a
module-scoped fixture and runs 7 assertions in ~6s on a B200:
1. test_quant_config_resolves_static_nvfp4
DiffusionModelConfig.from_pretrained picks up quant_algo=NVFP4,
group_size=16, exclude_modules, and force_dynamic_quantization=False
from the checkpoint's embedded quantization_config.
2. test_calibrated_input_scales_populated
Every unignored NVFP4 Linear has a finite non-zero input_scale after
model.load_weights(...). Guards against ModelOpt safetensors key drift
(e.g., a rename would silently leave the kernel reading uninit memory).
3. test_layernorm_nvfp4_scale_attached_on_unignored_blocks
After post_load_weights(), all 102 expected norm{1,2,3} modules
(3 norms x 34 unignored blocks) carry nvfp4_scale, and the 18 norms
in the 6 ignored blocks (0..2 + 37..39) do NOT.
4. test_mlp_fused_gate_activates
Exactly 34 MLP blocks (the unignored ones) report
_use_fused_gelu_tanh_quant == True.
5. test_layernorm_fused_path_advertised
At least 102 LayerNorms advertise is_nvfp4=True + non-None nvfp4_scale.
6. test_fused_vs_unfused_output_matches
A/B numerical regression guard. Same model, same input, two forward
passes (fused paths forced OFF via attribute toggle vs ON), assert
cosine similarity >= 0.995. Per-kernel unit tests check FP4 byte
equality vs fp4_quantize, but only an end-to-end A/B catches things
like wrong modulation cast order, bias-add drift, or swizzle-layout
mismatch that compound across 34 blocks.
7. test_forward_smoke_runs_without_error
One full forward pass on a 480p latent (1560 tokens) + 77-token text
condition produces finite, correctly-shaped output. Smoke test for
CUDA errors / shape mismatches in the production input regime.
All seven gated by @skip_if_no_checkpoint / @skip_if_no_sm100, so the
file no-ops cleanly in CI without the NVFP4 ckpt and activates
automatically on Blackwell with it mounted at the standard path
(/models/Wan2.2-T2V-A14B-Diffusers-NVFP4) or via WAN22_T2V_NVFP4_PATH.
B200 validation on cosine_similarity = 0.99756 (>= 0.995 threshold).
Wall-time A/B (separate bench, not committed): 1.176x speedup,
14.94% latency drop on the 14B forward pass.
Signed-off-by: anikaj <anikaj@nvidia.com>
Auto-formatted output of the repo's pre-commit hook suite over the 4 preceding commits in this PR. No semantic changes; only formatting deltas from clang-format, cmake-format, yapf, ruff, and ruff-format. Files touched (11): cpp/tensorrt_llm/kernels/fusedLayerNormQuant/CMakeLists.txt cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cu cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cuh cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp tensorrt_llm/_torch/modules/layer_norm.py tensorrt_llm/_torch/modules/linear.py tensorrt_llm/_torch/modules/mlp.py tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py tests/unittest/_torch/modules/test_fused_activation_quant.py tests/unittest/_torch/modules/test_fused_layernorm_quant.py tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py Signed-off-by: anikaj <anikaj@nvidia.com>
…back Addresses CodeRabbit review feedback on PR NVIDIA#14773 (8 actionable + 1 nitpick). All fixes verified via pre-commit; the only behavioral fix (NVIDIA#6 below) is also covered by a new parametric unit test. Kernel / C++ bindings --------------------- * cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cu `invokeFusedLayerNormQuant` now TLLM_CHECK_WITH_INFO's the mode-specific tensor pointers before kernel launch (`x`, `y_fp4`, `sf_out`, `sf_scale` always; `ln_weight`+`ln_bias` when `has_ln_affine`; `scale_msa`+`shift_msa` when `has_modulation`). Replaces silent illegal-memory-access on mis-wired callers with a clear runtime error. * cpp/tensorrt_llm/thop/fusedActivationQuant.cpp Both `fused_relu2_quantize` and `fused_gelu_tanh_quantize` now TORCH_CHECK(sf_scale.numel() == 1). Previously multi-element inputs were silently truncated to the first value at the kernel's `sf_scale[0]` read. * cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp Two fixes: - Enforce N == 5120 at runtime (matches the v1 kernel's `kN_HARDCODED`); was only checking divisibility by 16. - Drop `static` qualifier on `multiProcessorCount`. The function- static cache latches whichever device the first call ran on and returns a stale SM count for subsequent calls from other CUDA devices in the same process. Python modules -------------- * tensorrt_llm/_torch/modules/layer_norm.py Add `LayerNorm._validate_adaln_pair` static helper and call it at the top of `forward` so the unfused fallback enforces the same "both None or both tensors" contract that the fused kernel binding already TORCH_CHECKs. Previously one-sided modulation inputs were silently dropped by the unfused path while the fused path errored, producing different semantics across SM versions. * tensorrt_llm/_torch/modules/mlp.py `_fused_gelu_tanh_quant` and `_fused_relu2_quant` now reshape the packed FP4 output back to `(*orig_shape[:-1], orig_shape[-1] // 2)` when called with rank>2 input. Previously they returned a 2D packed tensor; `NVFP4LinearMethod.apply` only unflattens when the `Fp4QuantizedTensor.fp4_tensor` is rank>2, so a `[B, S, H]` input came back from `down_proj` as `[B*S, H]`. The residual-add in `WanBlock.forward` only happens to broadcast correctly when `B == 1`, which is what the existing integration test exercises -- the bug would surface on `B > 1`. Tests ----- * tests/unittest/_torch/modules/test_fused_activation_quant.py - New parametric test `test_fused_helpers_preserve_3d_prefix_dims` (gelu_tanh + relu2 variants) that asserts the packed FP4 output has rank 3 and shape `[B, S, H/2]` for `B > 1` input. Pins the mlp.py fix above. - Nitpick: in `test_fused_gelu_tanh_quantize_matches_separate`, also `assert torch.equal(sf_fused, sf_separate)` -- previously only the FP4 byte tensor was compared, so a buggy SF write could have slipped through with a stale match-rate >= 0.99. * tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py - Replace module-import-time `os.environ.setdefault` calls for `TLLM_DISABLE_MPI` and `TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION` with a module-scoped autouse fixture `_wan22_fusion_test_env` that saves/restores the originals. Both env vars are read at function-call time (not import time) so the fixture is sufficient; previously the setdefaults leaked into any pytest session that even just collected this file. - Narrow the `except Exception` in `_has_sm100` to `except RuntimeError`, per the TRT-LLM coding guideline that except clauses should match the smallest expected error set (`get_sm_version` raises RuntimeError when the CUDA query fails). Validation ---------- - pre-commit (clang-format, cmake-format, yapf, ruff, ruff-format, codespell, ruff-legacy, autoflake, isort): all hooks green on every modified file. - The new B>1 shape-preservation test exercises both fused helpers end-to-end on a Blackwell GPU. Signed-off-by: anikaj <anikaj@nvidia.com>
Concerns:
- layer_norm.py: NVTX wrappers were always-on, paying push/pop on
every call for the many non-NVFP4 LayerNorm users (StarCoder2,
Cohere2, Qwen3VL, Kimi-K2.5, Flux, Flux2). Restrict NVTX ranges
to the NVFP4 path so the shared default branch is byte-identical
to baseline.
- layer_norm.py: per-token-timestep (scale_msa shape [B, S, N], S>1)
can't be represented by the fused kernel's batch_idx = row /
seq_len_per_batch indexing. Detect that case and fall back to
_forward_unfused instead of raising a shape ValueError. Add a
portable Python-only test exercising the fallback.
Nits:
- layer_norm.py: document the intentional fused-vs-unfused
modulation precision delta (fused reads bf16/fp16 scale_msa,
unfused applies modulation in fp32).
- fusedActivationQuant.cu: collapse the near-verbatim relu2 and
gelu_tanh kernels into a single template <typename T, typename Act>
fusedActQuantizeKernel + invokeFusedActQuantize launcher. The
two public invoke* entry points become 1-line dispatchers; the
NVFP4 epilogue (absmax / shfl / SF swizzle / FP4 packing /
padded columns) now lives in exactly one place.
- fusedLayerNormQuant{.cpp,.cu,.cuh}: drop the dead
multiProcessorCount parameter; v1 uses a fixed dim3(M) grid
with no occupancy tuning, so the value was computed and
discarded.
- mlp.py: rephrase the input_scale gate comment - the cited
linear.py:821 reset-to-None is in the FP8 path, not the NVFP4
path NVFP4LinearMethod always allocates input_scale as a
Parameter; the guard is purely defensive.
- transformer_wan.py: drop the hard-coded modeling_llama.py:673
line citation; cite the file by name only.
- test_wan22_nvfp4_fusion_integration.py: drop the personal
/home/scratch.anikaj_libs/... fallback path; rely on
WAN22_T2V_NVFP4_PATH env var, the /models bind-mount, and the
repo-conventional /home/scratch.trt_llm_data_ci/llm-models/
location.
Docstring coverage:
Add concise docstrings to the new symbols and the touched-but-
pre-existing functions/classes flagged by the PR docstring-
coverage check (mlp.MLP/forward/_fused_*, NVFP4LinearMethod.apply,
Attention.forward, LayerNorm.__init__, WanBlock,
cpp_custom_ops.py fake stubs, test helpers). Lifts PR-diff
docstring coverage from ~51% to ~87%.
Signed-off-by: anikaj <anikaj@nvidia.com>
… docstrings Address the pre-commit CI failure on PR NVIDIA#14773 (post-rebase). - Reflow the 91-char ValueError f-string in LayerNorm.forward to fit yapf's 80-char limit for legacy files. - Rewrite first-line docstring summaries to be imperative (D401) on newly added or modified functions: _fused_relu2_quant, _fused_gelu_tanh_quant, both register_fake stubs in cpp_custom_ops.py, _forward_unfused, _forward_nvfp4_fused, _ln_quant_type, _try_attach_nvfp4_scale. - Convert multi-line docstring summaries to single-line summary + blank line + body (D205/D212/D400) for all new tests in test_wan22_nvfp4_fusion_integration.py, test_fused_layernorm_quant.py, and the new test_fused_helpers_preserve_3d_prefix_dims test. No functional changes; comments and docstrings only. Signed-off-by: anikaj <anikaj@nvidia.com>
CI pre-commit on PR NVIDIA#14773 reported yapf and ruff-format would rewrite two files. Apply the exact diff the hooks produced: - layer_norm.py: yapf prefers leading-token line breaks; reflow the per_token_modulation guard, the _forward_unfused fallback call, and the AdaLN row-count ValueError accordingly. - test_fused_layernorm_quant.py: ruff-format collapses the ln.forward call to a single line (now under 100 chars) and puts the parenthesis of the assert message on its own line. Auto-generated formatting only; no behavior change. Signed-off-by: anikaj <anikaj@nvidia.com>
Add a single-process benchmark that loads the calibrated ModelOpt
Wan2.2 T2V 14B transformer once and runs the same forward pass twice,
once with all fused-NVFP4 paths off and once with them on, each wrapped
in a distinct NVTX range so nsys can separate them cleanly.
Why a single process: keeps the safetensors load, autotuner cache, and
page cache identical between A and B, isolating the kernel-level effect
under test.
Usage (in the build container):
# quick A/B wall-time delta
python3 tests/unittest/_torch/visual_gen/bench_wan22_nvfp4_fusion.py \
--warmup 5 --iters 20
# full nsys A/B trace
nsys profile -o wan22_ab.nsys-rep \
--trace=cuda,nvtx,osrt --cuda-event-trace=true \
python3 tests/unittest/_torch/visual_gen/bench_wan22_nvfp4_fusion.py \
--warmup 5 --iters 20
nsys stats --report nvtx_sum --report nvtx_kern_sum wan22_ab.nsys-rep
The script is also a usable reproduction recipe for the PR's perf
claim (~1.176x); reviewers can run it locally on an SM100+ host.
Checkpoint resolution mirrors the integration test: prefer
``WAN22_T2V_NVFP4_PATH``, fall back to ``/models/...`` build-container
bind-mount, then to the repo-conventional
``/home/scratch.trt_llm_data_ci/llm-models/`` CI mount.
Signed-off-by: anikaj <anikaj@nvidia.com>
32e2ab0 to
2a08165
Compare
Signed-off-by: anikaj <anikaj@nvidia.com>
Signed-off-by: anikaj <anikaj@nvidia.com>
|
/bot run |
|
PR_Github #52188 [ run ] triggered by Bot. Commit: |
|
PR_Github #52188 [ run ] completed with state
|
|
/bot run |
|
PR_Github #52275 [ run ] triggered by Bot. Commit: |
|
PR_Github #52275 [ run ] completed with state
|
|
/bot run |
|
PR_Github #52398 [ run ] triggered by Bot. Commit: |
|
PR_Github #52398 [ run ] completed with state
|
|
/bot run |
|
PR_Github #52527 [ run ] triggered by Bot. Commit: |
|
PR_Github #52527 [ run ] completed with state
|
|
/bot run |
|
PR_Github #52544 [ run ] triggered by Bot. Commit: |
|
PR_Github #52544 [ run ] completed with state
|
|
/bot run |
|
PR_Github #52603 [ run ] triggered by Bot. Commit: |
| candidates = [ | ||
| os.environ.get("WAN22_T2V_NVFP4_PATH"), | ||
| "/models/Wan2.2-T2V-A14B-Diffusers-NVFP4", | ||
| "/home/scratch.trt_llm_data_ci/llm-models/Wan2.2-T2V-A14B-Diffusers-NVFP4", | ||
| ] |
There was a problem hiding this comment.
Use utils to get the model dir
| logger.info( | ||
| "Wan NVFP4 fused-LayerNorm scale attach: norm1=%d/%d, norm2=%d/%d, norm3=%d/%d", | ||
| n_attached_n1, | ||
| total_blocks, | ||
| n_attached_n2, | ||
| total_blocks, | ||
| n_attached_n3, | ||
| total_blocks, | ||
| ) |
| # kill switch on an NVFP4-configured layer (the A/B case). Plain | ||
| # LayerNorm users skip this entirely so no NVTX push/pop overhead | ||
| # leaks into the shared hot path. | ||
| with nvtx_range("LN unfused", color="grey"): |
There was a problem hiding this comment.
For such fine-grain nvtx trace, we should leave to debug code locally.
| # Kimi-K2.5, Flux, Flux2, ...). The label is what lets | ||
| # `nsys stats --report nvtx_sum` confirm the fast path fired | ||
| # per block. | ||
| with nvtx_range("LN+NVFP4 fused", color="green"): |
There was a problem hiding this comment.
For such fine-grain nvtx trace, we should leave to debug code locally.
| """ | ||
| if quant_config is None or force_dynamic_quant: | ||
| return None | ||
| if os.environ.get("TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION", "1") == "1": |
There was a problem hiding this comment.
could you make fuse_nvfp4_ln a block-level local flag, like fuse_qk_norm_rope, not in config or using env variable.
There was a problem hiding this comment.
Bench files should not be in unittest folder. Do we have bench tests in visual gen? @zhenhuaw-me @chang-l
|
PR_Github #52603 [ run ] completed with state
|
There was a problem hiding this comment.
I'm wondering we may not modify the gerenal layer norm module and expose the quantize_type / is_nvfp4 / nvfp4_scale / _forward_nvfp4_fused / AdaLN scale_msa,shift_msa,seq_len_per_batch args to it. I suggest using class like DiTLayerNorm(LayerNorm) to do the extra work here.
luyiyun1021
left a comment
There was a problem hiding this comment.
Thanks for the work. Left some comments. Could you please trim your comments seems lots of comments added to the generl module like mlp or layernorm and some are agent dev comments and verbosed.
luyiyun1021
left a comment
There was a problem hiding this comment.
Another pass focused on the MLP fused path -- structure/DRY suggestions (non-blocking) plus one gate-consistency question. The fusion approach matches the existing GatedMLP / relu2 precedent; these are about tightening the implementation to that same single-parameterized-helper style.
| not_dynamic = not getattr(self.down_proj, "force_dynamic_quantization", | ||
| False) | ||
|
|
||
| self._use_fused_gelu_tanh_quant = (has_nvfp4 and has_kernel_gelu |
There was a problem hiding this comment.
[Nit · mergeable-branches + gate inconsistency]
The two parallel _use_fused_* booleans here + the matching if/elif in forward differ only by which helper they call and the NVTX label. Consider collapsing to a single resolved op handle, mirroring how GatedMLP does the same NVFP4 activation+quant fusion (gated_mlp.py: _can_fuse_gate_up_swiglu* -> one parameterized _fused_gate_up_swiglu(x, fp4_out=...)):
# create_weights
self._fused_act_quant_op = None
if has_nvfp4 and has_scale and not_dynamic:
if self.activation is relu2 and has_kernel:
self._fused_act_quant_op = torch.ops.trtllm.fused_relu2_quantize
elif self.activation is gelu_tanh and has_kernel_gelu:
self._fused_act_quant_op = torch.ops.trtllm.fused_gelu_tanh_quantize
# forward
if self._fused_act_quant_op is not None:
with nvtx_range("act+NVFP4 fused", color="green"):
x_act = self._fused_act_quant(x_up, self._fused_act_quant_op)
else:
x_act = self.activation(x_up)Separately, the two gates are inconsistent: this gelu gate requires not_dynamic (the comment explains it guards a stale module.alpha when an Fp4QuantizedTensor is fed under dynamic quant), but the relu2 gate just above (self._use_fused_relu2_quant = ...) does not include not_dynamic. relu2 also returns an Fp4QuantizedTensor and hits the same alpha path -- should relu2 also gate on not_dynamic, or is the gelu not_dynamic over-cautious? Unifying the gates would force this to be decided in one place.
| is_sf_swizzled=True, | ||
| ) | ||
|
|
||
| def _fused_gelu_tanh_quant(self, x: torch.Tensor) -> Fp4QuantizedTensor: |
There was a problem hiding this comment.
[Nit · DRY] _fused_gelu_tanh_quant is identical to _fused_relu2_quant (just above) except for the op symbol (fused_gelu_tanh_quantize vs fused_relu2_quantize). Worth collapsing into one op-parameterized helper:
def _fused_act_quant(self, x, op):
orig = x.shape
xf = x.view(-1, x.shape[-1]).contiguous()
if xf.dtype not in (torch.float16, torch.bfloat16):
xf = xf.to(torch.bfloat16)
fp4, sf = op(xf, self.down_proj.input_scale, 16)
if len(orig) > 2:
fp4 = fp4.view(*orig[:-1], orig[-1] // 2)
return Fp4QuantizedTensor(fp4, sf, is_sf_swizzled=True)This mirrors GatedMLP._fused_gate_up_swiglu(x, fp4_out=...) (one parameterized helper). Note the CUDA side was already unified into the templated fusedActQuantizeKernel<T, Act> -- these Python helpers are the remaining duplicated half.
luyiyun1021
left a comment
There was a problem hiding this comment.
Test-convention nits: prefer the shared skip_pre_blackwell / skip_blackwell_geforce decorators (tests/unittest/utils/util.py) over hand-rolled getSMVersion() < 100 / _has_sm100() gates, so the Blackwell-version policy stays in one place. The integration-test one also surfaces a question about the < 120 (SM120) exclusion in the production is_nvfp4 gate.
| return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fp4_quantize") | ||
|
|
||
|
|
||
| skip_unless_fused_layernorm_and_fp4 = pytest.mark.skipif( |
There was a problem hiding this comment.
[Nit · follow-existing-patterns] The getSMVersion() < 100 half here re-rolls the repo-standard skip_pre_blackwell (tests/unittest/utils/util.py:113), which is what sibling kernel tests use (test_fused_moe, test_awq_quantization, test_mla_helix, test_modeling_kimi_k25, the dsa tests, ...). The op-availability AND (not fused_layernorm_quantize_available()) is a legitimate extra condition, but the SM part should compose the shared decorator instead of duplicating the < 100 literal:
from tests.unittest.utils.util import skip_pre_blackwell
@skip_pre_blackwell
@pytest.mark.skipif(
not (fused_layernorm_quantize_available() and fp4_quantize_available()),
reason="requires trtllm fused_layernorm_quantize + fp4_quantize ops",
)
def test_...():
...Same applies to the gelu gates in test_fused_activation_quant.py (skip_unless_fused_gelu_tanh_*). Keeps the Blackwell-version policy in one place so a future SM-range change doesn't have to be chased across several hand-rolled getSMVersion() < 100 literals.
| # that specific failure; anything else (e.g., ImportError) should | ||
| # propagate so it isn't masked. | ||
| return False | ||
| return 100 <= sm < 120 |
There was a problem hiding this comment.
[Nit · follow-existing-patterns + a question for the code] _has_sm100() reimplements two decorators that already exist in tests/unittest/utils/util.py: 100 <= sm is skip_pre_blackwell (:113), and the < 120 upper bound is exactly skip_blackwell_geforce (getSMVersion() == 120, :119). Consider composing them instead of a hand-rolled helper (drop-in, since skip_if_no_sm100 is used as a plain decorator):
from tests.unittest.utils.util import skip_pre_blackwell, skip_blackwell_geforce
@skip_pre_blackwell
@skip_blackwell_geforce
@skip_if_no_checkpoint
def test_...():
...The < 120 bound mirrors the production gate in layer_norm.py (is_nvfp4 is cleared unless 100 <= sm < 120), so excluding SM120 here is consistent with the code under test. But it raises a real question about the code itself: why is SM120 excluded from the fused path at all? SM120 is Blackwell and supports cvt.rn.satfinite.e2m1x2.f32, so the kernel should run there. If the exclusion is intentional (e.g., the kernel was only validated on SM100/SM10x), a one-line comment at the is_nvfp4 gate would help; if not, both the gate and this test should allow sm >= 100.
Summary by CodeRabbit
New Features
Improvements
Description
Fuses NVFP4 quantization with the layers preceding the NVFP4 GEMMs in
the Wan 2.2 T2V DiT
(
tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py),mirroring the approach in #11473 (Mamba2 / Nemotron-H NVFP4 layernorm +
activation fusion).
Problem. Each LayerNorm / GELU-tanh feeding into an NVFP4 GEMM
dispatches three CUDA kernels (norm-or-activation, cast,
fp4_quantize) before the GEMM can launch. Across the 34 unignoredWan 2.2 blocks this is a measurable kernel-launch + memory-bandwidth
overhead on B200.
Solution. Two new CUDA kernels that emit NVFP4-packed activations
and swizzled per-16-element FP8 (e4m3) scale factors directly into the
layout the CUTLASS NVFP4 GEMM kernels in
Linearexpect:fusedLayerNormQuant— LayerNorm with optional affine weight/biasand optional AdaLN modulation
LN(x) * (1 + scale_msa) + shift_msa,fused with NVFP4 quant. Feeds
attn1.qkv_proj/attn2.qkv_proj/mlp.up_proj.fusedGeluTanhQuantize— GELU-tanh + NVFP4 quant. Feeds MLPdown_proj.Both kernels consume a per-tensor calibrated
input_scale. The fusedpaths only activate when the loaded checkpoint provides such scales
(static NVFP4) and the layer is on the unignored list from the
checkpoint's
quantization_config; otherwise they fall back to theexisting unfused path (norm/activation → cast → standalone
fp4_quantize→ GEMM). No correctness or behavior change fornon-calibrated checkpoints.
Kill switches.
TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION=1forces the unfused path forall LayerNorms (debug / A-B).
_use_fused_gelu_tanh_quant, which can becleared per-module for the same purpose.
Performance. End-to-end transformer forward pass on B200, A/B with
fusion forced ON vs forced OFF (n=20, after warmup, fp16 condition,
480p latent, 1560 tokens, 77-token text condition), validated against
the ModelOpt checkpoint
nvidia/Wan2.2-T2V-A14B-Diffusers-NVFP4:nsys NVTX breakdown attributes the wins to fewer kernel launches per
norm/activation site (3 → 1).
Nsys NVTX summary (480p latent / B=1 / 5 warmup + 20 timed)
Captured one nsys run that covers both A/B configurations in a single process
(loads the 14B Wan 2.2 NVFP4 checkpoint once, then runs baseline-first /
fused-second with the same input). Raw NVTX summary from
nsys stats --report nvtx_sum:Per-iteration latency (the headline number)
The two
(fusion=off/on)ranges wrap a 20-iter timed loop each (post-warmup).Dividing total time by iteration count:
Per-kernel attribution (median per call)
The inner NVTX ranges count individual kernel-site firings, so the median per
call shows how much faster each fusion site is:
| Kernel | Median per call | Speedup at this site |
Related issue:
Related PR: #11473 (reference design this adapts).
Test Coverage
Three new test files, all green on B200 with the calibrated checkpoint
mounted via
WAN22_T2V_NVFP4_PATH:Kernel unit tests (byte-equal vs
fp4_quantizereference):tests/unittest/_torch/modules/test_fused_activation_quant.py—GELU-tanh + NVFP4 plus MLP gate logic. 8 tests.
tests/unittest/_torch/modules/test_fused_layernorm_quant.py—plain LN, affine LN, and AdaLN forms. Parametric.
End-to-end integration tests on the calibrated checkpoint:
tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py— 7 tests, ~6 s total, module-scoped fixture loads the checkpoint
once:
(
force_dynamic_quantization=False,group_size=16,exclude_moduleshonored).Linearhas a finite non-zeroinput_scaleafter
load_weights(guards against ModelOpt safetensors keydrift).
post_load_weightsattachesnvfp4_scaleto the 102 expectednorm{1,2,3}modules on the 34 unignored blocks; the 18 norms inthe 6 ignored blocks do NOT carry it.
_use_fused_gelu_tanh_quant == True.is_nvfp4 == Trueand non-Nonenvfp4_scale.(fused paths forced OFF vs ON), cosine similarity 0.99756
(≥ 0.995 threshold).
input produces finite, correctly-shaped output.
All seven gated by
@skip_if_no_checkpointand@skip_if_no_sm100,so the file no-ops cleanly in CI environments that don't have the
NVFP4 ckpt or are on pre-Blackwell hardware, and activates
automatically when both are available.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
[ x] Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.