
[None][feat] Wan 2.2: fuse NVFP4 quantization with preceding LayerNorm/AdaLN and GELU-tanh #14773

Open
anikaj-eng wants to merge 12 commits into NVIDIA:main from anikaj-eng:feat/wan22-nvfp4-fusion

Conversation

@anikaj-eng
Collaborator

@anikaj-eng anikaj-eng commented May 30, 2026

Summary by CodeRabbit

  • New Features

    • Added fused GELU (tanh) + NVFP4 quantization operations for improved inference performance
    • Added fused LayerNorm + NVFP4 quantization with optional adaptive normalization support
    • Extended MLP and attention modules to leverage optimized quantization kernels
  • Improvements

    • Enhanced support for Blackwell (SM100) GPU architecture with optimized kernel paths

Description

Fuses NVFP4 quantization with the layers preceding the NVFP4 GEMMs in
the Wan 2.2 T2V DiT
(tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py),
mirroring the approach in #11473 (Mamba2 / Nemotron-H NVFP4 layernorm +
activation fusion).

Problem. Each LayerNorm / GELU-tanh feeding into an NVFP4 GEMM
dispatches three CUDA kernels (norm-or-activation, cast,
fp4_quantize) before the GEMM can launch. Across the 34 unignored
Wan 2.2 blocks this is a measurable kernel-launch + memory-bandwidth
overhead on B200.
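
To put a number on the launch overhead described above, here is a small worked count. It is only a sketch: it assumes every norm1/norm2/norm3 site and every MLP activation site in the 34 unignored blocks takes the fused path, and that each fused site replaces exactly three launches with one. The constant names are illustrative, not identifiers from the PR.

```python
# Counts taken from the PR description; the names are illustrative only.
UNIGNORED_BLOCKS = 34
NORMS_PER_BLOCK = 3        # norm1, norm2, norm3
GELU_SITES_PER_BLOCK = 1   # one GELU-tanh feeding down_proj per MLP
LAUNCHES_UNFUSED = 3       # norm-or-activation, cast, fp4_quantize
LAUNCHES_FUSED = 1         # single fused kernel


def launches_saved_per_forward() -> int:
    """Kernel launches removed per transformer forward pass (assumed model)."""
    sites = UNIGNORED_BLOCKS * (NORMS_PER_BLOCK + GELU_SITES_PER_BLOCK)
    return sites * (LAUNCHES_UNFUSED - LAUNCHES_FUSED)


norm_sites = UNIGNORED_BLOCKS * NORMS_PER_BLOCK
print(norm_sites, launches_saved_per_forward())
```

The norm-site count agrees with the 102 LayerNorm modules the integration tests expect to carry nvfp4_scale.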

Solution. Two new CUDA kernels that emit NVFP4-packed activations
and swizzled per-16-element FP8 (e4m3) scale factors directly into the
layout the CUTLASS NVFP4 GEMM kernels in Linear expect:

  • fusedLayerNormQuant — LayerNorm with optional affine weight/bias
    and optional AdaLN modulation LN(x) * (1 + scale_msa) + shift_msa,
    fused with NVFP4 quant. Feeds attn1.qkv_proj / attn2.qkv_proj /
    mlp.up_proj.
  • fusedGeluTanhQuantize — GELU-tanh + NVFP4 quant. Feeds MLP
    down_proj.
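
For readers who want the math behind the LayerNorm kernel, the following is a pure-Python reference of the unfused composition it replaces. It is a simplified sketch, not the kernel: it omits the affine weight/bias, the per-tensor input_scale, the e4m3 rounding of the block scale, and the swizzled scale-factor layout. The eps value and all function names are assumptions made for illustration.

```python
import math

# Magnitudes representable in FP4 E2M1 (sign is stored separately).
FP4_E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 16  # elements that share one scale factor


def layernorm_adaln(x, scale_msa=None, shift_msa=None, eps=1e-6):
    """LN(x) * (1 + scale_msa) + shift_msa for one row; modulation is optional."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)  # eps is an assumed value
    y = [(v - mean) * inv_std for v in x]
    if scale_msa is not None and shift_msa is not None:
        y = [v * (1.0 + s) + b for v, s, b in zip(y, scale_msa, shift_msa)]
    return y


def quantize_block(block):
    """Return (scale, values) with values snapped to the signed E2M1 grid."""
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0  # block max maps to the largest FP4 value
    snapped = []
    for v in block:
        target = abs(v) / scale
        # Nearest grid point; ties go to the smaller magnitude (a simplification).
        mag = min(FP4_E2M1_VALUES, key=lambda q: abs(q - target))
        snapped.append(math.copysign(mag, v) if mag else 0.0)
    return scale, snapped


def fused_reference(x, scale_msa=None, shift_msa=None):
    """Normalize, modulate, then quantize each 16-element block of the row."""
    y = layernorm_adaln(x, scale_msa, shift_msa)
    return [quantize_block(y[i:i + BLOCK]) for i in range(0, len(y), BLOCK)]
```

Multiplying each snapped value by its block scale gives the dequantized approximation of the normalized row.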

Both kernels consume a per-tensor calibrated input_scale. The fused
paths only activate when the loaded checkpoint provides such scales
(static NVFP4) and the layer is on the unignored list from the
checkpoint's quantization_config; otherwise they fall back to the
existing unfused path (norm/activation → cast → standalone
fp4_quantize → GEMM). No correctness or behavior change for
non-calibrated checkpoints.

Kill switches.

  • TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION=1 forces the unfused path for
    all LayerNorms (debug / A-B).
  • MLP fused path is gated by _use_fused_gelu_tanh_quant, which can be
    cleared per-module for the same purpose.
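
The gating rules and the LayerNorm kill switch can be summarized in a few lines. This is an illustrative sketch only: the function name, argument names, and the way the environment variable is parsed are assumptions, and the default of "1" follows the value quoted in the review thread rather than anything verified against the code.

```python
def use_fused_layernorm(input_scale, layer_ignored, sm_version, env,
                        default_disable="1"):
    """Decide whether a LayerNorm may take the fused NVFP4 path (sketch)."""
    # Kill switch: "1" forces the unfused path for every LayerNorm.
    if env.get("TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION", default_disable) == "1":
        return False
    # Static NVFP4 only: the checkpoint must supply a calibrated input_scale,
    # and the layer must not be on the quantization_config ignore list.
    if layer_ignored or input_scale is None:
        return False
    # The fused kernels carry SM100+ (Blackwell) guards.
    return sm_version >= 100


enabled_env = {"TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION": "0"}
print(use_fused_layernorm(0.5, False, 100, enabled_env))
```

Any condition that fails sends the layer down the existing norm, cast, fp4_quantize, GEMM sequence, which is why non-calibrated checkpoints see no behavior change.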

Performance. End-to-end transformer forward pass on B200, A/B with
fusion forced ON vs forced OFF (n=20, after warmup, fp16 condition,
480p latent, 1560 tokens, 77-token text condition), validated against
the ModelOpt checkpoint
nvidia/Wan2.2-T2V-A14B-Diffusers-NVFP4:

| Path | Latency (ms) |
| --- | --- |
| Unfused | 96.17 |
| Fused | 81.80 |
| Speedup | 1.176x (-14.94%) |

nsys NVTX breakdown attributes the wins to fewer kernel launches per
norm/activation site (3 → 1).

Nsys NVTX summary (480p latent / B=1 / 5 warmup + 20 timed)

Captured one nsys run that covers both A/B configurations in a single process
(loads the 14B Wan 2.2 NVFP4 checkpoint once, then runs baseline-first /
fused-second with the same input). Raw NVTX summary from nsys stats --report nvtx_sum:

 ** NVTX Range Summary (nvtx_sum):
 Time (%)  Total Time (ns)  Instances    Avg (ns)      Med (ns)     Min (ns)    Max (ns)   StdDev (ns)   Style                       Range
 --------  ---------------  ---------  ------------  ------------  ----------  ----------  -----------  -------  -----------------------------------------------------
     25.6       2634144160          1  2634144160.0  2634144160.0  2634144160  2634144160          0.0  PushPop  TensorRT-LLM:Wan22 NVFP4 baseline (fusion=off)
     25.0       2576735104          1  2576735104.0  2576735104.0  2576735104  2576735104          0.0  PushPop  TensorRT-LLM:Wan22 NVFP4 baseline (fusion=off) warmup
     22.6       2327886688          1  2327886688.0  2327886688.0  2327886688  2327886688          0.0  PushPop  TensorRT-LLM:Wan22 NVFP4 fused (fusion=on)
     17.6       1811006208       3500      517430.3      194912.0      176896   912812416   15625806.1  PushPop  TensorRT-LLM:LN unfused
      5.6        579063232          1   579063232.0   579063232.0   579063232   579063232          0.0  PushPop  TensorRT-LLM:Wan22 NVFP4 fused (fusion=on) warmup
      3.3        341910176       2550      134082.4      133008.0      104352     1045760      21148.4  PushPop  TensorRT-LLM:LN+NVFP4 fused
      0.4         38670816        850       45495.1       44512.0       37408      488992      15586.4  PushPop  TensorRT-LLM:gelu_tanh+NVFP4 fused

Per-iteration latency (the headline number)

The two (fusion=off/on) ranges wrap a 20-iter timed loop each (post-warmup).
Dividing total time by iteration count:

| Configuration | Latency / iter |
| --- | --- |
| Unfused (fusion=off) | 131.7 ms |
| Fused (fusion=on) | 116.4 ms |
| Speedup | 1.132x |
| Latency drop | −11.6% |
(Wall-time A/B without nsys instrumentation gives ~1.18x / −14.9%; nsys adds a
constant per-launch overhead that compresses the relative win. Both are in the
same band.)
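
The per-iteration figures can be re-derived from the two outer NVTX ranges in the summary. The snippet below is just that arithmetic, using the total times reported by nsys and the 20-iteration timed loop; variable names are illustrative.

```python
TIMED_ITERS = 20

# Total time (ns) of the two timed ranges in the nvtx_sum report.
baseline_total_ns = 2_634_144_160   # Wan22 NVFP4 baseline (fusion=off)
fused_total_ns = 2_327_886_688      # Wan22 NVFP4 fused (fusion=on)


def per_iter_ms(total_ns, iters=TIMED_ITERS):
    """Average latency per iteration in milliseconds."""
    return total_ns / iters / 1e6


unfused_ms = per_iter_ms(baseline_total_ns)
fused_ms = per_iter_ms(fused_total_ns)
speedup = unfused_ms / fused_ms
drop_pct = (1.0 - fused_ms / unfused_ms) * 100.0

print(f"{unfused_ms:.1f} ms, {fused_ms:.1f} ms, {speedup:.3f}x, -{drop_pct:.1f}%")
```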

Per-kernel attribution (median per call)

The inner NVTX ranges count individual kernel-site firings, so the median per
call shows how much faster each fusion site is:
| Kernel | Median per call | Speedup at this site |

Related issue:
Related PR: #11473 (reference design this adapts).

Test Coverage

Three new test files, all green on B200 with the calibrated checkpoint
mounted via WAN22_T2V_NVFP4_PATH:

Kernel unit tests (byte-equal vs fp4_quantize reference):

  • tests/unittest/_torch/modules/test_fused_activation_quant.py
    GELU-tanh + NVFP4 plus MLP gate logic. 8 tests.
  • tests/unittest/_torch/modules/test_fused_layernorm_quant.py
    plain LN, affine LN, and AdaLN forms. Parametric.

End-to-end integration tests on the calibrated checkpoint:

  • tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py
    — 7 tests, ~6 s total, module-scoped fixture loads the checkpoint
    once:
    1. Quant config resolves to static NVFP4
      (force_dynamic_quantization=False, group_size=16,
      exclude_modules honored).
    2. Every unignored NVFP4 Linear has a finite non-zero input_scale
      after load_weights (guards against ModelOpt safetensors key
      drift).
    3. post_load_weights attaches nvfp4_scale to the 102 expected
      norm{1,2,3} modules on the 34 unignored blocks; the 18 norms in
      the 6 ignored blocks do NOT carry it.
    4. Exactly 34 MLP blocks report _use_fused_gelu_tanh_quant == True.
    5. ≥ 102 LayerNorms advertise is_nvfp4 == True and non-None
      nvfp4_scale.
    6. Numerical A/B. Same model, same input, two forward passes
      (fused paths forced OFF vs ON), cosine similarity 0.99756
      (≥ 0.995 threshold).
    7. Smoke test. Full forward pass on a 480p/1560-token/77-condition
      input produces finite, correctly-shaped output.

All seven gated by @skip_if_no_checkpoint and @skip_if_no_sm100,
so the file no-ops cleanly in CI environments that don't have the
NVFP4 ckpt or are on pre-Blackwell hardware, and activates
automatically when both are available.
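
For reference, the numerical A/B check in test 6 boils down to a cosine-similarity comparison between the fused and unfused outputs. A minimal pure-Python sketch follows; the helper names are hypothetical, and the real test presumably operates on flattened tensors rather than Python lists.

```python
import math

THRESHOLD = 0.995  # acceptance threshold quoted in the test description


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def outputs_match(fused, unfused, threshold=THRESHOLD):
    """True when the fused output stays within the similarity threshold."""
    return cosine_similarity(fused, unfused) >= threshold
```

Cosine similarity ignores a uniform rescaling of the output, so a check like this is best paired with the finiteness and shape assertions from the smoke test.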

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • [x] Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@anikaj-eng anikaj-eng requested review from a team as code owners May 30, 2026 02:42
@anikaj-eng anikaj-eng requested review from hyukn and yuxianq May 30, 2026 02:42
@coderabbitai
Contributor

coderabbitai Bot commented May 30, 2026

Walkthrough

This PR implements NVFP4 (NV floating-point 4) quantization fusion for layer normalization and activation functions in TensorRT-LLM's Wan visual generation model. It adds two fused CUDA kernels (LayerNorm with optional AdaLN and GELU-tanh) combined with FP4 quantization, exposes them as PyTorch operators, integrates them into model modules with environment-gated activation, and provides comprehensive test coverage.

Changes

NVFP4 Fused Kernels and Integration

| Layer / File(s) | Summary |
| --- | --- |
| CUDA Kernel Implementations and Headers: cpp/tensorrt_llm/kernels/fusedActivationQuant.cu, cpp/tensorrt_llm/kernels/fusedActivationQuant.h, cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cu, cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cuh, cpp/tensorrt_llm/kernels/fusedLayerNormQuant/CMakeLists.txt, cpp/tensorrt_llm/kernels/CMakeLists.txt, cpp/tensorrt_llm/thop/CMakeLists.txt, cpp/tensorrt_llm/CMakeLists.txt | Fused GELU-tanh + FP4 kernel mirrors existing relu2 quantization with tanh-approx activation. Fused LayerNorm + optional affine/AdaLN modulation + FP4 kernel with per-row reduction, warp shuffle SF packing, and SM100+ guards. CMake configures Blackwell-specific optimizations and builds static library linked into shared target. |
| PyTorch C++ Operator Bindings: cpp/tensorrt_llm/thop/fusedActivationQuant.cpp, cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp | Torch extension implementations register fused_gelu_tanh_quantize and fused_layernorm_quantize operators with input validation (CUDA, contiguity, rank/shape, dtype), fixed sf_vec_size=16, output allocation via computeSwizzledLayoutSFSize, dtype-dispatched kernel launch for fp16/bf16, and PyTorch dispatcher binding for CUDA backend. |
| Python Modules and High-Level Integration: tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py, tensorrt_llm/_torch/utils.py, tensorrt_llm/_torch/modules/layer_norm.py, tensorrt_llm/_torch/modules/mlp.py, tensorrt_llm/_torch/modules/linear.py, tensorrt_llm/_torch/visual_gen/modules/attention.py | Fake tensor shape inference for both fused ops. LayerNorm gains quantize_type parameter, NVFP4 fused path gated on nvfp4_scale and SM version, AdaLN modulation support via scale_msa/shift_msa kwargs, _forward_nvfp4_fused dispatches to fused kernel. MLP adds _use_fused_gelu_tanh_quant flag and _fused_gelu_tanh_quant helper with NVTX ranges. NVFP4LinearMethod flattens Fp4QuantizedTensor inputs for GEMM. Attention accepts Fp4QuantizedTensor hidden_states. |
| Wan Model NVFP4 Wiring: tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py | Environment-gated quantization type resolution via _ln_quant_type helper and TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION env var. WanBlock propagates quantize_type to norm1/norm2/norm3, adopts shared gelu_tanh activation, updates forward to pass AdaLN modulation params. post_load_weights attaches calibrated input_scale from downstream Linear modules to upstream LayerNorm nvfp4_scale via _try_attach_nvfp4_scale for norm1/norm2/norm3. |
| Unit and Integration Tests: tests/unittest/_torch/modules/test_fused_activation_quant.py, tests/unittest/_torch/modules/test_fused_layernorm_quant.py, tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py | Fused gelu_tanh kernel tested against separate gelu_tanh + fp4_quantize with ≥99% byte match. MLP heuristic tests fused flag gating on static NVFP4 vs dynamic quantization. LayerNorm tests three variants (plain, learned affine, AdaLN modulation) with reference composition. Wan 2.2 integration suite validates config resolution, input_scale population, nvfp4_scale attachment per block, MLP fused flag activation, and A/B numerical regression with ≥0.995 cosine similarity plus smoke test. |

Sequence Diagram(s)

sequenceDiagram
  participant User
  participant LayerNorm
  participant FusedOp
  participant WanBlock
  User->>WanBlock: forward(x, scale_msa, shift_msa, seq_len_per_batch)
  WanBlock->>LayerNorm: forward(x, scale_msa, shift_msa, seq_len_per_batch)
  alt nvfp4_scale present and is_nvfp4=True and SM100+
    LayerNorm->>LayerNorm: Reshape to 2D
    LayerNorm->>FusedOp: torch.ops.trtllm.fused_layernorm_quantize(input, ln_weight, ln_bias, scale_msa, shift_msa, sf_scale, ...)
    FusedOp->>LayerNorm: (output_fp4, output_sf)
    LayerNorm->>LayerNorm: Wrap in Fp4QuantizedTensor
    LayerNorm->>WanBlock: Fp4QuantizedTensor
  else Unfused path
    LayerNorm->>LayerNorm: Standard LN with optional FP32 modulation
    LayerNorm->>WanBlock: torch.Tensor
  end
  WanBlock->>User: output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

  • #14772: Main NVFP4 fusion implementation matches the proposed changes—adds fusedLayerNormQuant kernel, extends fusedActivationQuant with gelu_tanh, and integrates into Wan model with environment gating and scale attachment.

Suggested labels

VisualGen

Suggested reviewers

  • chang-l
  • liji-nv
  • tcherckez-nvidia
  • brb-nv
  • karljang
  • yuxianq
  • yibinl-nvidia
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.68% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The PR title '[None][feat] Wan 2.2: fuse NVFP4 quantization with preceding LayerNorm/AdaLN and GELU-tanh' clearly summarizes the main feature: fusing NVFP4 quantization with LayerNorm and GELU-tanh for the Wan 2.2 model, which aligns directly with the changeset's primary objective. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The PR description comprehensively covers the problem, solution, performance validation, test coverage, and checklist items. |


Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 8

🧹 Nitpick comments (1)
tests/unittest/_torch/modules/test_fused_activation_quant.py (1)

282-299: ⚡ Quick win

Coverage is still insufficient for the scale-factor output.

This only checks sf_fused.shape, but downstream GEMM also consumes the swizzled scale tensor. A bad sf_fused would still pass here as long as the packed FP4 bytes happen to match, so please compare sf_fused against sf_separate too.

✅ Small coverage addition
     assert fp4_fused.shape == (m, n // 2)
     assert sf_fused.shape == sf_separate.shape
+    assert torch.equal(sf_fused, sf_separate)
 
     match_rate = (fp4_fused == fp4_separate).float().mean().item()

As per coding guidelines, tests/**: Act as a QA engineer reviewing test changes and coverage for TensorRT-LLM, and say whether coverage is sufficient or insufficient with actionable follow-up.
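
To make the suggested coverage concrete, here is a small sketch of the two comparisons involved: a byte-level match rate on the packed FP4 output and an exact comparison of the scale-factor bytes. It works on plain bytes objects so it runs without a GPU; the helper names are hypothetical, and the 0.99 threshold is the byte-match bar quoted in the walkthrough.

```python
def byte_match_rate(a: bytes, b: bytes) -> float:
    """Fraction of positions at which two equal-length byte strings agree."""
    if len(a) != len(b):
        raise ValueError("buffers must have the same length")
    if not a:
        return 1.0
    return sum(x == y for x, y in zip(a, b)) / len(a)


def check_fused_outputs(fp4_fused, fp4_separate, sf_fused, sf_separate,
                        min_match=0.99):
    """Pass only if packed FP4 bytes mostly agree AND scale factors are equal."""
    fp4_ok = byte_match_rate(fp4_fused, fp4_separate) >= min_match
    sf_ok = sf_fused == sf_separate  # exact, like torch.equal on the SF tensor
    return fp4_ok and sf_ok
```

A shape-only assertion on the scale factors would accept the second case in the usage below (identical FP4 bytes, different scale bytes), which is the gap the comment points out.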

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/modules/test_fused_activation_quant.py` around lines
282 - 299, The test currently only compares FP4 outputs but not the swizzled
scale tensors; update the assertions after calling torch.ops.trtllm.fp4_quantize
and torch.ops.trtllm.fused_gelu_tanh_quantize to also validate sf_fused against
sf_separate (e.g., assert exact equality or use torch.allclose for float
tensors) so the swizzled scale-factor output from fused_gelu_tanh_quantize is
verified; reference the variables sf_fused and sf_separate (and the calls
fused_gelu_tanh_quantize / fp4_quantize) and add a concise assertion checking
their values/shape match.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cu`:
- Around line 360-370: The launcher validates shapes/flags but not required
tensor pointers, which can let specialized kernels (e.g., the
fusedLayerNormQuant paths selected by params.has_modulation or
params.has_ln_affine) dereference null and crash; update the dispatch
preconditions to CHECK that the input/output/scale/bias/modulation pointers
needed for each mode are non-null before launching: for params.has_modulation
ensure modulation pointer and any dependent input/output pointers are present
and for has_ln_affine ensure ln affine scale/bias pointers exist, keeping the
existing checks on params.N (kN_HARDCODED), params.seq_len_per_batch, and
params.M % params.seq_len_per_batch as-is.

In `@cpp/tensorrt_llm/thop/fusedActivationQuant.cpp`:
- Line 87: The code calls CHECK_INPUT(sf_scale, torch::kFloat32) but then reads
sf_scale[0] without verifying it has exactly one element; add a cardinality
check to ensure sf_scale.numel() == 1 and emit a clear error if not (e.g., using
CHECK or TORCH_CHECK) before any access. Update both places that consume
sf_scale[0] in fusedActivationQuant.cpp (the CHECK_INPUT(sf_scale,
torch::kFloat32) site and the other occurrence) to validate single-element
tensors and keep the existing dtype check.

In `@cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp`:
- Around line 61-62: The comment promises v1 supports only N == 5120 but the
runtime currently only checks N % 16 == 0; update the runtime validation to
enforce N == 5120 for the v1 path: locate the validation that uses TORCH_CHECK
(the divisibility-by-16 check) in the fusedLayerNormQuant v1 binding and change
it to a strict TORCH_CHECK(N == 5120, "...") or add an additional TORCH_CHECK
when version == 1 so unsupported hidden sizes cannot be routed into the fused
kernel; apply the same change for the second occurrence mentioned around the
other check.
- Line 127: The code caches device-dependent value in a function-static variable
`multiProcessorCount` by calling
`tensorrt_llm::common::getMultiProcessorCount()` once; this can be stale for
other CUDA devices—replace the static with a per-call lookup (i.e., call
`getMultiProcessorCount()` each time you need it) or implement a device-keyed
cache keyed by current device id; locate the `static int const
multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();`
declaration in fusedLayerNormQuant.cpp and change it to compute/retrieve the
value at runtime rather than as a function-static constant.

In `@tensorrt_llm/_torch/modules/layer_norm.py`:
- Around line 154-156: The fallback path must normalize and validate AdaLN
modulation tensors so unfused behavior matches the fused contract: before
calling _forward_unfused (and in the three sites that currently call it), check
that scale_msa and shift_msa are treated as a pair (either both None or both
provided); if one is missing treat as absent for both, and if provided
convert/expand 1D per-batch vectors into 2D [B, N] modulation tensors using
seq_len_per_batch so shapes match the fused expectation used by
_forward_nvfp4_fused; add a small helper or inline logic to validate
broadcastability and to reshape/repeat using seq_len_per_batch (so
scale_msa/shift_msa become explicitly shaped [B, N] or set to None) and then
pass those normalized tensors to _forward_unfused to ensure identical semantics
across fused and unfused paths.

In `@tensorrt_llm/_torch/modules/mlp.py`:
- Around line 140-148: The fused activation+quant helpers
(_fused_gelu_tanh_quant and _fused_relu2_quant) currently return a flattened 2D
packed FP4 tensor, causing downstream NVFP4LinearMethod.apply to lose
batch/sequence dims; update both helpers to reshape the packed output back to
the original prefix dims before returning by reshaping to (*x.shape[:-1],
x.shape[-1] // 2) (i.e., preserve all dims except the last which becomes half
size for packed FP4) so NVFP4LinearMethod.apply sees fp4_tensor.rank > 2 and
restores shape correctly.

In `@tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py`:
- Around line 108-111: The helper _has_sm100 currently swallows all exceptions
when calling get_sm_version() via "except Exception" — narrow this to the
specific errors get_sm_version can raise (e.g., RuntimeError or OSError) so
unrelated bugs aren't hidden; replace "except Exception" with a targeted except
(for example "except RuntimeError:" or "except (RuntimeError, OSError):") and
keep the existing return False behavior so only expected failures are handled.
- Around line 48-53: The two top-level os.environ.setdefault calls for
"TLLM_DISABLE_MPI" and "TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION" must be moved
into an autouse module-scoped pytest fixture (e.g., define a fixture like
"autouse_env_toggles" with scope='module' and autouse=True) so they don't leak
across tests; implement the fixture to save original values, set the env vars
(prefer using pytest's monkeypatch.setenv or os.environ assignment), yield, and
then restore the originals after yield (or rely on monkeypatch undo), and remove
the import-time os.environ.setdefault lines so only the fixture controls these
environment toggles for tests in this module.

---

Nitpick comments:
In `@tests/unittest/_torch/modules/test_fused_activation_quant.py`:
- Around line 282-299: The test currently only compares FP4 outputs but not the
swizzled scale tensors; update the assertions after calling
torch.ops.trtllm.fp4_quantize and torch.ops.trtllm.fused_gelu_tanh_quantize to
also validate sf_fused against sf_separate (e.g., assert exact equality or use
torch.allclose for float tensors) so the swizzled scale-factor output from
fused_gelu_tanh_quantize is verified; reference the variables sf_fused and
sf_separate (and the calls fused_gelu_tanh_quantize / fp4_quantize) and add a
concise assertion checking their values/shape match.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: fd443c38-6f92-47f2-a7b9-0f73c4c54bb8

📥 Commits

Reviewing files that changed from the base of the PR and between 74d7c3a and fd1369f.

📒 Files selected for processing (20)
  • cpp/tensorrt_llm/CMakeLists.txt
  • cpp/tensorrt_llm/kernels/CMakeLists.txt
  • cpp/tensorrt_llm/kernels/fusedActivationQuant.cu
  • cpp/tensorrt_llm/kernels/fusedActivationQuant.h
  • cpp/tensorrt_llm/kernels/fusedLayerNormQuant/CMakeLists.txt
  • cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cu
  • cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cuh
  • cpp/tensorrt_llm/thop/CMakeLists.txt
  • cpp/tensorrt_llm/thop/fusedActivationQuant.cpp
  • cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • tensorrt_llm/_torch/modules/layer_norm.py
  • tensorrt_llm/_torch/modules/linear.py
  • tensorrt_llm/_torch/modules/mlp.py
  • tensorrt_llm/_torch/utils.py
  • tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
  • tensorrt_llm/_torch/visual_gen/modules/attention.py
  • tests/unittest/_torch/modules/test_fused_activation_quant.py
  • tests/unittest/_torch/modules/test_fused_layernorm_quant.py
  • tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py

Comment thread cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cu
Comment thread cpp/tensorrt_llm/thop/fusedActivationQuant.cpp
Comment thread cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp
Comment thread cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp Outdated
Comment thread tensorrt_llm/_torch/modules/layer_norm.py Outdated
Comment thread tensorrt_llm/_torch/modules/mlp.py
Comment thread tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py Outdated
Collaborator

@luyiyun1021 luyiyun1021 left a comment

Comments from a structured review pass (CUDA kernels + shared-module impact). All non-blocking — the implementation is careful and the fused path defaults OFF — but a couple are worth addressing before merge.

Inline: 2 Concerns — unconditional NVTX on the shared LayerNorm default path (affects all non-Wan models), and a temb.ndim==4 (per-token-timestep) hard-crash that's untested — plus several nits (dead multiProcessorCount, gelu-kernel DRY vs relu2, a stale FP8 rationale in the mlp gate comment, a wrong modeling_llama.py line citation, a personal scratch path in a test, and a fused-vs-unfused modulation-precision delta).

PR-level (no single line): the headline 1.176× speedup has no committed benchmark or repro script — only the description table and the nsys A/B recipe in the _ln_quant_type docstring. For a perf-motivated PR, consider adding a small reproducible benchmark (or stating it's a one-off local nsys run with the exact config) so reviewers can verify the number.

shift_msa, seq_len_per_batch,
nvfp4_scale)

with nvtx_range("LN unfused", color="grey"):
Collaborator

[Concern · perf / shared hot-path] nvtx_range() resolves to nvtx.annotate(), which emits an NVTX push/pop on every call even when no profiler is attached (tensorrt_llm/_utils.py:953-972). This forward is the shared LayerNorm used by StarCoder2, Cohere2, Qwen3VL, Kimi-K2.5, DSA, Flux and Flux2 — none of which set quantize_type, so they all take this default branch and pay the cost (2× per layer × every decode step). Sibling modules (rms_norm.py, linear.py) carry no always-on nvtx_range, and there's a Dynamo graph-break risk if a caller ever torch.compiles through this forward.

Suggest gating it: use the env-gated nvtx_range_debug(...) (TLLM_NVTX_DEBUG / TLLM_LLMAPI_ENABLE_NVTX), or only wrap when self.is_nvfp4, so the default path stays byte-identical to baseline. Same for the "LN+NVFP4 fused" label.
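
One way to picture the suggested gating is a helper that returns a real range only when profiling is requested and a no-op context otherwise. The sketch below uses a recording stand-in instead of a real NVTX range so it runs anywhere; maybe_range, profiling_enabled, and recording_range are hypothetical names, and treating TLLM_NVTX_DEBUG as enabled only when set to "1" is an assumption.

```python
import contextlib

events = []  # records push/pop calls made by the stand-in range


@contextlib.contextmanager
def recording_range(label):
    """Stand-in for a real NVTX range so the sketch runs without a GPU."""
    events.append(("push", label))
    try:
        yield
    finally:
        events.append(("pop", label))


def profiling_enabled(env):
    # Assumption: the flag counts as on only when it is set to "1".
    return env.get("TLLM_NVTX_DEBUG", "0") == "1"


def maybe_range(label, range_factory, enabled):
    """Return a real range only when enabled, otherwise a no-op context."""
    return range_factory(label) if enabled else contextlib.nullcontext()


# Default path: no env flag, so no push/pop is emitted.
with maybe_range("LN unfused", recording_range, profiling_enabled({})):
    pass

# Debug path: the flag is set, so the range is recorded.
with maybe_range("LN+NVFP4 fused", recording_range,
                 profiling_enabled({"TLLM_NVTX_DEBUG": "1"})):
    pass
```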

ln_b = None
# Validate the row-to-batch ratio so the kernel's
# batch_idx = row / seq_len_per_batch indexing is correct.
if hs_2d.shape[0] != s_2d.shape[0] * seq_len_per_batch:
Collaborator

[Concern · robustness / coverage] This guard hard-raises for the per-token-timestep case. When temb.ndim==4 (per-patch timestep, e.g. Wan 2.2 TI2V-5B), scale_msa is [B, S, N], so s_2d = scale_msa.reshape(-1, n) → [B*S, N] while seq_len_per_batch = S; then hs_2d.shape[0] (= B*S) vs s_2d.shape[0] * seq_len_per_batch (= B*S*S) raises ValueError for any S > 1. is_nvfp4 is set purely from quant_config + SM (not from N or timestep rank), so a static-NVFP4 model driving a 2-D timestep crashes instead of falling back. It's currently unreachable (the 2-D-timestep 5B model is hidden_size ≠ 5120 and ships no NVFP4 ckpt) and the kill switch defaults OFF, but it's fragile and untested — the tests only use a 1-D timestep=[500.].

Suggest falling back to _forward_unfused when scale_msa.ndim == 3 and scale_msa.shape[1] > 1, or asserting + documenting the "1-D timestep only" V1 limitation, and adding a temb.ndim==4 test.
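
The shape arithmetic behind this concern can be checked without tensors. The sketch below mirrors the guard on plain shape tuples and includes the suggested fallback condition; the function names are hypothetical, and the example sizes (B=2, S=1560, N=5120) are taken from figures mentioned elsewhere in this PR.

```python
def modulation_rows(scale_shape):
    """Rows of scale_msa.reshape(-1, N), where N is the last dimension."""
    rows = 1
    for dim in scale_shape[:-1]:
        rows *= dim
    return rows


def guard_passes(batch, seq_len, scale_shape):
    """Mirror of: hs_2d.shape[0] == s_2d.shape[0] * seq_len_per_batch."""
    hs_rows = batch * seq_len
    return hs_rows == modulation_rows(scale_shape) * seq_len


def should_fall_back(scale_shape):
    """Suggested fallback: per-token modulation with more than one token."""
    return len(scale_shape) == 3 and scale_shape[1] > 1


# Per-batch modulation [B, 1, N]: the guard holds.
print(guard_passes(2, 1560, (2, 1, 5120)))
# Per-token modulation [B, S, N]: B*S rows vs B*S*S, so the guard fails.
print(guard_passes(2, 1560, (2, 1560, 5120)))
```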

# scale_msa / shift_msa typically arrive as [B, 1, N]; reshape to [B, N].
# Cast to input dtype (callers often keep modulation in FP32 for the
# unfused path's precision; the fused kernel reads bf16/fp16).
s_2d = scale_msa.reshape(-1, n).to(hs_2d.dtype).contiguous()
Collaborator

[Nit · numerics] The fused path casts scale_msa/shift_msa to the input bf16/fp16 dtype here, but the unfused fallback (_forward_unfused, ~line 216) applies modulation in fp32. So a model that falls back (non-SM100, or no nvfp4_scale) gets fp32-precision modulation while the fused path gets bf16-precision modulation feeding the FP4 quantizer — a small, undocumented fused-vs-unfused delta. The unit-test reference deliberately downcasts modulation to match the kernel, so the cos ≥ 0.995 check doesn't guard it. Either document that the fused path intentionally uses native-dtype modulation, or keep scale_msa/shift_msa in fp32 in the kernel signature to match the unfused contract.
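
To give a feel for the size of this delta, the sketch below emulates bf16 rounding in pure Python and applies the same modulation with and without the downcast. It is illustrative only: Python floats stand in for the higher-precision path, the rounding helper handles ordinary finite values only (no NaN/inf handling), and the input values are arbitrary.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite value to bfloat16 precision (nearest, ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for ties-to-even
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def modulate(y, scale, shift, cast=lambda v: v):
    """y * (1 + scale) + shift, optionally casting the modulation terms."""
    return y * (1.0 + cast(scale)) + cast(shift)


y, scale, shift = 1.5, 0.1234567, 0.2
full = modulate(y, scale, shift)              # modulation kept at full precision
reduced = modulate(y, scale, shift, to_bf16)  # modulation rounded to bf16 first
print(abs(full - reduced))
```

The difference is small per element, and whether it matters after FP4 quantization is a separate question that this sketch does not answer.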

// process happens to be on, then subsequent calls from other devices
// see a stale value (different B200 variants in the same node, MIG
// partitions, etc.).
int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
Collaborator

[Nit · dead code] multiProcessorCount is computed here and threaded into invokeFusedLayerNormQuant, but the launcher's parameter is int /*multiProcessorCount*/ (unused) — the grid is a fixed dim3 grid(params.M) with no occupancy tuning. So this value is computed and discarded, and the "re-query SM count every call / static caching is unsafe" comment defends a dead value. Suggest dropping the parameter + this call (and the .cuh extern-template arg), or leaving a one-line TODO if it's kept for future occupancy tuning.

// back to native precision (bf16/fp16) so absmax / SF / FP4 packing math is
// byte-for-byte identical to the relu2 kernel and to cvt_warp_fp16_to_fp4.
template <typename T>
__global__ void fusedGeluTanhQuantizeKernel(T const* __restrict__ input, float const* __restrict__ sfScale,
Collaborator

[Nit · DRY] fusedGeluTanhQuantizeKernel (+ its launcher) is a ~140-line near-verbatim copy of fusedRelu2QuantizeKernel, differing only in the activation call (relu2_f32 → gelu_tanh_f32); the whole NVFP4 epilogue (absmax, XOR-1 shuffle, SF swizzle, packing, padded-column handling) is duplicated, so any future SF-layout/absmax fix must be applied twice. Suggest templating on an activation functor, e.g. template <typename T, typename Act> __global__ void fusedActQuantizeKernel(...) calling Act{}(x), and instantiating both relu2 and gelu_tanh from it.
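
The refactor can be illustrated in Python with a shared epilogue parameterized by the activation. This is an analogy for the suggested C++ template, not a port of the kernel: the epilogue here only computes the per-16-element absmax, and make_act_block_absmax is a hypothetical name. The two activations use the standard relu2 and tanh-approximation GELU formulas.

```python
import math


def relu2(x: float) -> float:
    """Squared ReLU: max(x, 0) ** 2."""
    r = max(x, 0.0)
    return r * r


def gelu_tanh(x: float) -> float:
    """GELU with the tanh approximation."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def make_act_block_absmax(act, block=16):
    """Build one activation + per-block absmax routine from any activation."""
    def run(row):
        y = [act(v) for v in row]
        amax = [max(abs(v) for v in y[i:i + block])
                for i in range(0, len(y), block)]
        return y, amax
    return run


# Both variants share the same epilogue code; only the activation differs.
relu2_epilogue = make_act_block_absmax(relu2)
gelu_epilogue = make_act_block_absmax(gelu_tanh)
print(relu2_epilogue([-1.0, 2.0] * 8)[1])
```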

# Parameter (linear.py:1295), but a layer that opts into dynamic quant
# at load time will reset it to None (linear.py:821). The fused kernel
# needs a real scalar tensor, so guard against the None case explicitly.
has_scale = getattr(self.down_proj, 'input_scale', None) is not None
Collaborator

[Nit · stale comment] The comment above (line 106) justifies this input_scale is not None guard via "a layer that opts into dynamic quant at load time will reset it to None (linear.py:821)" — but linear.py:821 is in the FP8 LinearMethod path, which NVFP4 never executes. For NVFP4, input_scale is always an allocated Parameter (linear.py:1295), so for the layers this gate actually controls (NVFP4 relu2/gelu) the hasattr → getattr is not None change is behavior-preserving — good, no relu2/Nemotron-H regression. Only the cited rationale is wrong for NVFP4. Suggest dropping the linear.py:821 citation, or rephrasing to "for NVFP4 input_scale is always allocated; this is purely defensive against a hypothetical future None-reset."

silently (no ``nvfp4_scale`` attached -> LayerNorm uses _forward_unfused).

A/B kill switch: ``TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION`` mirrors the
flag used by Llama (modeling_llama.py:673). Default is ``"1"`` (fusion

[Nit] The env-var read in modeling_llama.py is at ~line 704, not 673 (673 is unrelated config setup). The flag name and default ("1") do match. Hard-coded line citations rot quickly — suggest citing just modeling_llama.py.

candidates = [
os.environ.get("WAN22_T2V_NVFP4_PATH"),
"/models/Wan2.2-T2V-A14B-Diffusers-NVFP4",
"/home/scratch.anikaj_libs/trunk_08042025/models/Wan2.2-T2V-A14B-Diffusers-NVFP4",

[Nit] This hard-codes a personal developer scratch path (/home/scratch.anikaj_libs/...) as a fallback candidate — it leaks a username into the public repo and is dead for every other contributor / CI runner. The env var WAN22_T2V_NVFP4_PATH and the /models/... bind-mount already cover legitimate resolution (and the skip reason only references those two). Suggest removing this candidate. The repo convention for shared test data is the /home/scratch.trt_llm_data_ci/llm-models/ mount.
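
A candidate list like this is usually resolved by taking the first entry that exists on disk. A minimal sketch, assuming a hypothetical helper name `resolve_checkpoint` (the test's actual resolution code is not shown in this thread):

```python
import os
from typing import Iterable, Optional


def resolve_checkpoint(candidates: Iterable[Optional[str]]) -> Optional[str]:
    """Return the first candidate that is an existing directory, else None.

    None and empty-string entries (e.g. an unset env var) are skipped.
    """
    for path in candidates:
        if path and os.path.isdir(path):
            return path
    return None


# Usage mirroring the two candidates the review comment keeps.
checkpoint = resolve_checkpoint([
    os.environ.get("WAN22_T2V_NVFP4_PATH"),
    "/models/Wan2.2-T2V-A14B-Diffusers-NVFP4",
])
```

When `checkpoint` is `None`, the test can skip with a reason that names exactly the locations that were tried.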

@anikaj-eng anikaj-eng force-pushed the feat/wan22-nvfp4-fusion branch from 071ec6a to fc3e87a Compare June 3, 2026 00:37
@anikaj-eng anikaj-eng requested a review from a team as a code owner June 3, 2026 18:22

Thanks @luyiyun1021. I rechecked the current head (32e2ab08) against the review feedback; the comments are addressed:

  • CodeRabbit batch: added mode-specific fused LayerNorm pointer checks, sf_scale.numel() == 1 validation for both fused activation ops, strict N == 5120 validation, AdaLN pair validation, rank-preserving FP4 helper outputs, scoped env-var fixture setup, narrowed _has_sm100() exception handling, and sf_fused vs sf_separate coverage.
  • Shared LayerNorm hot path: NVTX ranges are now only emitted for NVFP4-configured layers, so default LayerNorm callers no longer pay unconditional push/pop overhead.
  • Per-token timestep modulation: [B, S, N] modulation with S > 1 now falls back to _forward_unfused instead of hitting the fused kernel shape guard, with a Python-only regression test.
  • Structured-review nits: modulation precision delta is documented, dead multiProcessorCount plumbing is removed, relu2/GELU-tanh share the templated fusedActQuantizeKernel, the MLP input_scale comment is corrected, the hard-coded modeling_llama.py line citation is removed, and the personal scratch path fallback was replaced with the repo-conventional CI model path.
  • PR-level perf/docstring notes: the PR description now calls out the B200 A/B/nsys run as the exact measured config, and docstrings were added for the PR diff coverage warning.

Ready for re-review.

anikaj-eng added 10 commits June 4, 2026 23:20
Add fused GELU-tanh + NVFP4 quantization kernel mirroring the
relu2+NVFP4 path from PR NVIDIA#11473, and wire it into the MLP module so
Wan 2.2's FFN can run a single fused activation+quant kernel after
the up_proj on Blackwell (SM100+).

C++ side:
- Add fusedGeluTanhQuantizeKernel + invokeFusedGeluTanhQuantize in
  kernels/fusedActivationQuant.cu/.h
- Add torch op binding trtllm::fused_gelu_tanh_quantize in
  thop/fusedActivationQuant.cpp

Python side:
- Register fake stub for trtllm::fused_gelu_tanh_quantize
- Add gelu_tanh sentinel function in _torch/utils.py for identity-based
  activation detection in MLP
- Gate MLP's _use_fused_gelu_tanh_quant on NVFP4 + non-dynamic quant +
  GELU-tanh activation; add NVTX ranges around the fused path

Tests:
- test_fused_gelu_tanh_quantize: kernel-level correctness
- test_mlp_uses_fused_gelu_tanh_quant_on_static_nvfp4: integration test
  ensuring the MLP picks the fused path on a calibrated checkpoint

The fused path activates only with a static/calibrated input_scale, so
behavior is unchanged for the default dynamic-quantization configs.

Signed-off-by: anikaj <anikaj@nvidia.com>
…ation

Add a fused (Ada)LayerNorm + NVFP4 quantization kernel for the three
norm sites in WanBlock (norm1/norm2/norm3) so the per-row mean/var
reduction, modulation, and the FP4 packing run in a single pass.

C++ kernel (cpp/tensorrt_llm/kernels/fusedLayerNormQuant/):
- New static library with a templated fusedLayerNormQuantKernel
  parameterized on HAS_LN_AFFINE (norm2 only) and HAS_MODULATION
  (norm1/norm3 AdaLN: (1 + scale_msa) * x_hat + shift_msa)
- 2-phase reduction (sum / sum-of-squares) with warp shuffle and
  cross-warp shared-memory aggregation; per-block (16-elem) FP4 SF
  computed inline via cvt.rn.satfinite.e2m1x2.f32 and written in the
  swizzled layout expected by CUTLASS NVFP4 GEMM
- Build wiring: CMakeLists.txt, kernels/CMakeLists.txt
  add_subdirectory + EXCLUDE_REGEX, top-level CMake links the new lib

Torch binding:
- thop/fusedLayerNormQuant.cpp exposes trtllm::fused_layernorm_quantize
  with optional ln_weight/ln_bias and optional scale_msa/shift_msa
- thop/CMakeLists.txt updated
- cpp_custom_ops.py adds the fake stub with FP4 + SF shape inference

Python:
- layer_norm.py grows a quantize_type='nvfp4' path that routes to the
  fused op on SM100+, with scale_msa/shift_msa kwargs and an
  NVTX-marked unfused fallback for A/B profiling
- visual_gen/modules/attention.py: relax the ndim==3 assertion so that
  Fp4QuantizedTensor inputs (from the fused LN) propagate through
- visual_gen/models/wan/transformer_wan.py:
    * _ln_quant_type() helper honors TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION
      as a fusion kill switch
    * WanBlock.__init__ passes quantize_type=ln_quant_type to all three
      norms; forward() rewrites the three norm call sites to pass
      scale_msa/shift_msa and seq_len_per_batch
    * WanTransformer3DModel.post_load_weights gains _try_attach_nvfp4_scale
      which copies the downstream Linear input_scale onto the upstream
      LayerNorm so the fused kernel has the calibrated global scale

Tests:
- test_fused_layernorm_quant.py: parametric correctness over plain LN,
  affine LN, and AdaLN modulation against an FP32 PyTorch reference,
  asserting >=99% byte-for-byte FP4 match on N=5120 (Wan dim).

The fused path activates only when a calibrated input_scale is
available on the downstream Linear (e.g. the ModelOpt-quantized
nvidia/Wan2.2-T2V-A14B-Diffusers-NVFP4 checkpoint), so default
dynamic-quant configs remain unchanged.
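
The normalization and modulation math described in this commit can be sketched on the host as a plain-Python reference for one row. This is an illustration under assumptions, not the kernel: `adaln_reference` is a hypothetical name, the `eps` default is assumed, the affine `ln_weight`/`ln_bias` mode is omitted, and the FP4 packing step is left out entirely.

```python
import math


def adaln_reference(row, scale_msa=None, shift_msa=None, eps=1e-6):
    """LayerNorm one row via sum / sum-of-squares, then optional AdaLN.

    Computes x_hat = (x - mean) / sqrt(var + eps) and, when modulation is
    given, (1 + scale_msa) * x_hat + shift_msa elementwise.
    """
    if (scale_msa is None) != (shift_msa is None):
        raise ValueError("scale_msa and shift_msa must both be None or both be given")

    n = len(row)
    total = sum(row)                     # first reduction: sum
    total_sq = sum(v * v for v in row)   # second reduction: sum of squares
    mean = total / n
    # E[x^2] - E[x]^2 can dip slightly below zero in floating point.
    var = max(total_sq / n - mean * mean, 0.0)
    inv_std = 1.0 / math.sqrt(var + eps)

    x_hat = [(v - mean) * inv_std for v in row]
    if scale_msa is None:
        return x_hat
    return [(1.0 + s) * x + b for x, s, b in zip(x_hat, scale_msa, shift_msa)]
```

A reference of this kind is what a correctness test would compare against before quantization; the fused kernel additionally converts the result to NVFP4 in the same pass.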

Signed-off-by: anikaj <anikaj@nvidia.com>
… hardening

* linear.py: NVFP4LinearMethod.apply now flattens an Fp4QuantizedTensor
  whose fp4_tensor is 3D (B, S, N/2) before the GEMM and unflattens the
  output. Previously the 3D shortcut only covered plain torch.Tensor
  inputs, so the Wan 2.2 fused LayerNorm + NVFP4 quant path tripped
  "mat2 must be a batch of matrices" at the downstream qkv_proj. Matches
  the contract the plain-tensor path already had.

* mlp.py: tighten the fused gelu_tanh+NVFP4 gate from
  hasattr(down_proj, 'input_scale') to
  getattr(down_proj, 'input_scale', None) is not None. NVFP4 layers that
  opt into dynamic quant at load time (linear.py:821) reset input_scale
  to None; the previous hasattr check would still fire the fused path
  and read uninitialized memory. The new check also benefits the
  pre-existing relu2 gate (Nemotron-H) without behavior change for
  properly-calibrated checkpoints.

* test_fused_activation_quant.py: drop the bogus "no calibrated
  input_scale" case from test_mlp_uses_fused_gelu_tanh_quant_on_static_nvfp4.
  NVFP4LinearMethod.create_weights (linear.py:1295) always allocates a
  placeholder input_scale Parameter, so the case was unreachable without
  monkey-patching and was producing False asserts. The relu2-activation
  case below already covers "fused gelu_tanh path stays OFF when
  conditions are not met". Updated comments/docstring to match the new
  gate semantics (no setattr has_static_input_scale).

Signed-off-by: anikaj <anikaj@nvidia.com>
…elOpt ckpt

Loads the ModelOpt-quantized Wan 2.2 T2V 14B checkpoint
(huggingface.co/nvidia/Wan2.2-T2V-A14B-Diffusers-NVFP4) once via a
module-scoped fixture and runs 7 assertions in ~6s on a B200:

1. test_quant_config_resolves_static_nvfp4
   DiffusionModelConfig.from_pretrained picks up quant_algo=NVFP4,
   group_size=16, exclude_modules, and force_dynamic_quantization=False
   from the checkpoint's embedded quantization_config.

2. test_calibrated_input_scales_populated
   Every unignored NVFP4 Linear has a finite non-zero input_scale after
   model.load_weights(...). Guards against ModelOpt safetensors key drift
   (e.g., a rename would silently leave the kernel reading uninit memory).

3. test_layernorm_nvfp4_scale_attached_on_unignored_blocks
   After post_load_weights(), all 102 expected norm{1,2,3} modules
   (3 norms x 34 unignored blocks) carry nvfp4_scale, and the 18 norms
   in the 6 ignored blocks (0..2 + 37..39) do NOT.

4. test_mlp_fused_gate_activates
   Exactly 34 MLP blocks (the unignored ones) report
   _use_fused_gelu_tanh_quant == True.

5. test_layernorm_fused_path_advertised
   At least 102 LayerNorms advertise is_nvfp4=True + non-None nvfp4_scale.

6. test_fused_vs_unfused_output_matches
   A/B numerical regression guard. Same model, same input, two forward
   passes (fused paths forced OFF via attribute toggle vs ON), assert
   cosine similarity >= 0.995. Per-kernel unit tests check FP4 byte
   equality vs fp4_quantize, but only an end-to-end A/B catches things
   like wrong modulation cast order, bias-add drift, or swizzle-layout
   mismatch that compound across 34 blocks.

7. test_forward_smoke_runs_without_error
   One full forward pass on a 480p latent (1560 tokens) + 77-token text
   condition produces finite, correctly-shaped output. Smoke test for
   CUDA errors / shape mismatches in the production input regime.

All seven gated by @skip_if_no_checkpoint / @skip_if_no_sm100, so the
file no-ops cleanly in CI without the NVFP4 ckpt and activates
automatically on Blackwell with it mounted at the standard path
(/models/Wan2.2-T2V-A14B-Diffusers-NVFP4) or via WAN22_T2V_NVFP4_PATH.

B200 validation: cosine_similarity = 0.99756 (>= 0.995 threshold).
Wall-time A/B (separate bench, not committed): 1.176x speedup,
14.94% latency drop on the 14B forward pass.
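
The A/B check in item 6 reduces to a cosine similarity between the flattened fused and unfused outputs. A minimal sketch on plain Python lists, assuming a hypothetical helper name `cosine_similarity` (the test itself presumably operates on tensors):

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def passes_ab_guard(fused, unfused, threshold=0.995):
    """Apply the >= 0.995 threshold quoted in the commit message."""
    return cosine_similarity(fused, unfused) >= threshold
```

Cosine similarity ignores a uniform scale difference between the two outputs, so it is best read as a guard against structural errors such as layout or ordering mistakes rather than as a bit-exactness check.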

Signed-off-by: anikaj <anikaj@nvidia.com>
Auto-formatted output of the repo's pre-commit hook suite over the 4
preceding commits in this PR. No semantic changes; only formatting
deltas from clang-format, cmake-format, yapf, ruff, and ruff-format.

Files touched (11):
  cpp/tensorrt_llm/kernels/fusedLayerNormQuant/CMakeLists.txt
  cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cu
  cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cuh
  cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp
  tensorrt_llm/_torch/modules/layer_norm.py
  tensorrt_llm/_torch/modules/linear.py
  tensorrt_llm/_torch/modules/mlp.py
  tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
  tests/unittest/_torch/modules/test_fused_activation_quant.py
  tests/unittest/_torch/modules/test_fused_layernorm_quant.py
  tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py

Signed-off-by: anikaj <anikaj@nvidia.com>
…back

Addresses CodeRabbit review feedback on PR NVIDIA#14773 (8 actionable + 1
nitpick). All fixes verified via pre-commit; the only behavioral fix
(the mlp.py rank-preservation fix below) is also covered by a new
parametric unit test.

Kernel / C++ bindings
---------------------
* cpp/tensorrt_llm/kernels/fusedLayerNormQuant/fusedLayerNormQuant.cu
  `invokeFusedLayerNormQuant` now TLLM_CHECK_WITH_INFO's the
  mode-specific tensor pointers before kernel launch (`x`, `y_fp4`,
  `sf_out`, `sf_scale` always; `ln_weight`+`ln_bias` when
  `has_ln_affine`; `scale_msa`+`shift_msa` when `has_modulation`).
  Replaces silent illegal-memory-access on mis-wired callers with a
  clear runtime error.

* cpp/tensorrt_llm/thop/fusedActivationQuant.cpp
  Both `fused_relu2_quantize` and `fused_gelu_tanh_quantize` now
  TORCH_CHECK(sf_scale.numel() == 1). Previously multi-element inputs
  were silently truncated to the first value at the kernel's
  `sf_scale[0]` read.

* cpp/tensorrt_llm/thop/fusedLayerNormQuant.cpp
  Two fixes:
  - Enforce N == 5120 at runtime (matches the v1 kernel's
    `kN_HARDCODED`); was only checking divisibility by 16.
  - Drop `static` qualifier on `multiProcessorCount`. The function-
    static cache latches whichever device the first call ran on and
    returns a stale SM count for subsequent calls from other CUDA
    devices in the same process.

Python modules
--------------
* tensorrt_llm/_torch/modules/layer_norm.py
  Add `LayerNorm._validate_adaln_pair` static helper and call it at
  the top of `forward` so the unfused fallback enforces the same
  "both None or both tensors" contract that the fused kernel binding
  already TORCH_CHECKs. Previously one-sided modulation inputs were
  silently dropped by the unfused path while the fused path errored,
  producing different semantics across SM versions.

* tensorrt_llm/_torch/modules/mlp.py
  `_fused_gelu_tanh_quant` and `_fused_relu2_quant` now reshape the
  packed FP4 output back to `(*orig_shape[:-1], orig_shape[-1] // 2)`
  when called with rank>2 input. Previously they returned a 2D packed
  tensor; `NVFP4LinearMethod.apply` only unflattens when the
  `Fp4QuantizedTensor.fp4_tensor` is rank>2, so a `[B, S, H]` input
  came back from `down_proj` as `[B*S, H]`. The residual-add in
  `WanBlock.forward` only happens to broadcast correctly when `B == 1`,
  which is what the existing integration test exercises -- the bug
  would surface on `B > 1`.
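
The shape reasoning in the mlp.py fix above can be checked with a few lines of plain Python. This is a sketch on shape tuples only: `packed_fp4_shape` and `can_broadcast` are hypothetical helper names, and the concrete sizes in the usage lines are illustrative.

```python
def packed_fp4_shape(orig_shape):
    """Shape of the packed FP4 tensor: two 4-bit values per byte on the last dim."""
    *prefix, hidden = orig_shape
    if hidden % 2 != 0:
        raise ValueError("last dimension must be even to pack two FP4 values per byte")
    return (*prefix, hidden // 2)


def can_broadcast(shape_a, shape_b):
    """Right-aligned broadcasting check: dims must match or one must be 1."""
    for x, y in zip(reversed(shape_a), reversed(shape_b)):
        if x != y and x != 1 and y != 1:
            return False
    return True


# Rank is preserved by the fixed helpers: [B, S, H] -> [B, S, H // 2].
fixed = packed_fp4_shape((2, 7, 16))

# A flattened [B * S, H] output added to a [B, S, H] residual:
ok_when_b_is_1 = can_broadcast((1 * 7, 16), (1, 7, 16))
fails_when_b_is_2 = can_broadcast((2 * 7, 16), (2, 7, 16))
```

With `B == 1` the `[S, H]` result lines up against `[1, S, H]`, which is why the existing integration test did not expose the problem.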

Tests
-----
* tests/unittest/_torch/modules/test_fused_activation_quant.py
  - New parametric test `test_fused_helpers_preserve_3d_prefix_dims`
    (gelu_tanh + relu2 variants) that asserts the packed FP4 output
    has rank 3 and shape `[B, S, H/2]` for `B > 1` input. Pins the
    mlp.py fix above.
  - Nitpick: in `test_fused_gelu_tanh_quantize_matches_separate`,
    also `assert torch.equal(sf_fused, sf_separate)` -- previously
    only the FP4 byte tensor was compared, so a buggy SF write could
    have slipped through with a stale match-rate >= 0.99.

* tests/unittest/_torch/visual_gen/test_wan22_nvfp4_fusion_integration.py
  - Replace module-import-time `os.environ.setdefault` calls for
    `TLLM_DISABLE_MPI` and `TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION`
    with a module-scoped autouse fixture
    `_wan22_fusion_test_env` that saves/restores the originals.
    Both env vars are read at function-call time (not import time)
    so the fixture is sufficient; previously the setdefaults leaked
    into any pytest session that even just collected this file.
  - Narrow the `except Exception` in `_has_sm100` to `except
    RuntimeError`, per the TRT-LLM coding guideline that except
    clauses should match the smallest expected error set
    (`get_sm_version` raises RuntimeError when the CUDA query fails).
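
The save/restore behavior of the env-var fixture described above can be sketched without pytest as a context manager. `scoped_env_defaults` is a hypothetical name and the variable names in the usage note are placeholders; in the test file the same body would sit inside a module-scoped autouse fixture around its `yield`.

```python
import os
from contextlib import contextmanager


@contextmanager
def scoped_env_defaults(defaults):
    """Apply setdefault-style env values and restore the originals on exit."""
    saved = {key: os.environ.get(key) for key in defaults}
    try:
        for key, value in defaults.items():
            # setdefault keeps a value the user already exported.
            os.environ.setdefault(key, value)
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)   # was unset before: unset again
            else:
                os.environ[key] = old       # was set before: put it back
```

Used as `with scoped_env_defaults({"EXAMPLE_FLAG": "1"}): ...`, nothing leaks into the rest of the session once the block exits, which is the property the import-time `setdefault` calls lacked.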

Validation
----------
- pre-commit (clang-format, cmake-format, yapf, ruff, ruff-format,
  codespell, ruff-legacy, autoflake, isort): all hooks green on
  every modified file.
- The new B>1 shape-preservation test exercises both fused helpers
  end-to-end on a Blackwell GPU.

Signed-off-by: anikaj <anikaj@nvidia.com>
Concerns:
  - layer_norm.py: NVTX wrappers were always-on, paying push/pop on
    every call for the many non-NVFP4 LayerNorm users (StarCoder2,
    Cohere2, Qwen3VL, Kimi-K2.5, Flux, Flux2). Restrict NVTX ranges
    to the NVFP4 path so the shared default branch is byte-identical
    to baseline.
  - layer_norm.py: per-token-timestep (scale_msa shape [B, S, N], S>1)
    can't be represented by the fused kernel's batch_idx = row /
    seq_len_per_batch indexing. Detect that case and fall back to
    _forward_unfused instead of raising a shape ValueError. Add a
    portable Python-only test exercising the fallback.
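
The fallback decision in the second concern can be sketched on shapes alone. The helper names below are hypothetical, and the assumption that per-batch modulation arrives as `[B, 1, N]` (or a lower-rank tensor) is mine rather than something this commit message states.

```python
from typing import Sequence


def is_per_token_modulation(scale_shape: Sequence[int]) -> bool:
    """Return True for [B, S, N] modulation with S > 1 (one scale per token)."""
    return len(scale_shape) == 3 and scale_shape[1] > 1


def batch_index(row: int, seq_len_per_batch: int) -> int:
    """Map a flattened row in [0, B * S) to its batch, as row / seq_len_per_batch."""
    return row // seq_len_per_batch


def choose_path(scale_shape: Sequence[int]) -> str:
    """Pick the unfused path when each token carries its own modulation."""
    return "unfused" if is_per_token_modulation(scale_shape) else "fused"
```

Because `batch_index` yields one modulation row per batch, every token in a batch shares the same `scale_msa`/`shift_msa`; a `[B, S, N]` tensor with `S > 1` needs a distinct row per token, which that indexing cannot express.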

Nits:
  - layer_norm.py: document the intentional fused-vs-unfused
    modulation precision delta (fused reads bf16/fp16 scale_msa,
    unfused applies modulation in fp32).
  - fusedActivationQuant.cu: collapse the near-verbatim relu2 and
    gelu_tanh kernels into a single template <typename T, typename Act>
    fusedActQuantizeKernel + invokeFusedActQuantize launcher. The
    two public invoke* entry points become 1-line dispatchers; the
    NVFP4 epilogue (absmax / shfl / SF swizzle / FP4 packing /
    padded columns) now lives in exactly one place.
  - fusedLayerNormQuant{.cpp,.cu,.cuh}: drop the dead
    multiProcessorCount parameter; v1 uses a fixed dim3(M) grid
    with no occupancy tuning, so the value was computed and
    discarded.
  - mlp.py: rephrase the input_scale gate comment - the cited
    linear.py:821 reset-to-None is in the FP8 path, not the NVFP4
    path; NVFP4LinearMethod always allocates input_scale as a
    Parameter, so the guard is purely defensive.
  - transformer_wan.py: drop the hard-coded modeling_llama.py:673
    line citation; cite the file by name only.
  - test_wan22_nvfp4_fusion_integration.py: drop the personal
    /home/scratch.anikaj_libs/... fallback path; rely on
    WAN22_T2V_NVFP4_PATH env var, the /models bind-mount, and the
    repo-conventional /home/scratch.trt_llm_data_ci/llm-models/
    location.

Docstring coverage:
  Add concise docstrings to the new symbols and the touched-but-
  pre-existing functions/classes flagged by the PR docstring-
  coverage check (mlp.MLP/forward/_fused_*, NVFP4LinearMethod.apply,
  Attention.forward, LayerNorm.__init__, WanBlock,
  cpp_custom_ops.py fake stubs, test helpers). Lifts PR-diff
  docstring coverage from ~51% to ~87%.

Signed-off-by: anikaj <anikaj@nvidia.com>
… docstrings

Address the pre-commit CI failure on PR NVIDIA#14773 (post-rebase).

- Reflow the 91-char ValueError f-string in LayerNorm.forward to fit
  yapf's 80-char limit for legacy files.
- Rewrite first-line docstring summaries to be imperative (D401) on
  newly added or modified functions: _fused_relu2_quant,
  _fused_gelu_tanh_quant, both register_fake stubs in cpp_custom_ops.py,
  _forward_unfused, _forward_nvfp4_fused, _ln_quant_type,
  _try_attach_nvfp4_scale.
- Convert multi-line docstring summaries to single-line summary +
  blank line + body (D205/D212/D400) for all new tests in
  test_wan22_nvfp4_fusion_integration.py, test_fused_layernorm_quant.py,
  and the new test_fused_helpers_preserve_3d_prefix_dims test.

No functional changes; comments and docstrings only.

Signed-off-by: anikaj <anikaj@nvidia.com>
CI pre-commit on PR NVIDIA#14773 reported yapf and ruff-format would rewrite
two files. Apply the exact diff the hooks produced:

- layer_norm.py: yapf prefers leading-token line breaks; reflow the
  per_token_modulation guard, the _forward_unfused fallback call, and
  the AdaLN row-count ValueError accordingly.
- test_fused_layernorm_quant.py: ruff-format collapses the ln.forward
  call to a single line (now under 100 chars) and puts the parenthesis
  of the assert message on its own line.

Auto-generated formatting only; no behavior change.

Signed-off-by: anikaj <anikaj@nvidia.com>
Add a single-process benchmark that loads the calibrated ModelOpt
Wan2.2 T2V 14B transformer once and runs the same forward pass twice,
once with all fused-NVFP4 paths off and once with them on, each wrapped
in a distinct NVTX range so nsys can separate them cleanly.

Why a single process: keeps the safetensors load, autotuner cache, and
page cache identical between A and B, isolating the kernel-level effect
under test.

Usage (in the build container):

    # quick A/B wall-time delta
    python3 tests/unittest/_torch/visual_gen/bench_wan22_nvfp4_fusion.py \
        --warmup 5 --iters 20

    # full nsys A/B trace
    nsys profile -o wan22_ab.nsys-rep \
        --trace=cuda,nvtx,osrt --cuda-event-trace=true \
        python3 tests/unittest/_torch/visual_gen/bench_wan22_nvfp4_fusion.py \
            --warmup 5 --iters 20
    nsys stats --report nvtx_sum --report nvtx_kern_sum wan22_ab.nsys-rep

The script is also a usable reproduction recipe for the PR's perf
claim (~1.176x); reviewers can run it locally on an SM100+ host.
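
When comparing a local run against the quoted figure, note that a speedup factor and a latency reduction are two views of the same ratio. A small worked conversion, with hypothetical helper names, relating the ~1.176x here to the 14.94% latency drop quoted in the integration-test commit:

```python
def speedup_from_latency_drop(drop_fraction: float) -> float:
    """Speedup = old_latency / new_latency, where new = old * (1 - drop)."""
    if not 0.0 <= drop_fraction < 1.0:
        raise ValueError("drop_fraction must be in [0, 1)")
    return 1.0 / (1.0 - drop_fraction)


def latency_drop_from_speedup(speedup: float) -> float:
    """Inverse conversion: fraction of latency removed for a given speedup."""
    if speedup <= 0.0:
        raise ValueError("speedup must be positive")
    return 1.0 - 1.0 / speedup


reported = speedup_from_latency_drop(0.1494)
```

Rounding `reported` to three decimals reproduces the 1.176x figure, so the two numbers are consistent with each other.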

Checkpoint resolution mirrors the integration test: prefer
``WAN22_T2V_NVFP4_PATH``, fall back to ``/models/...`` build-container
bind-mount, then to the repo-conventional
``/home/scratch.trt_llm_data_ci/llm-models/`` CI mount.

Signed-off-by: anikaj <anikaj@nvidia.com>
@anikaj-eng anikaj-eng force-pushed the feat/wan22-nvfp4-fusion branch from 32e2ab0 to 2a08165 Compare June 4, 2026 23:30
Signed-off-by: anikaj <anikaj@nvidia.com>
Signed-off-by: anikaj <anikaj@nvidia.com>

/bot run

@tensorrt-cicd

PR_Github #52188 [ run ] triggered by Bot. Commit: 17e8d6b Link to invocation

@tensorrt-cicd

PR_Github #52188 [ run ] completed with state SUCCESS. Commit: 17e8d6b
/LLM/main/L0_MergeRequest_PR pipeline #41508 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation


/bot run

@tensorrt-cicd

PR_Github #52275 [ run ] triggered by Bot. Commit: 17e8d6b Link to invocation

@tensorrt-cicd

PR_Github #52275 [ run ] completed with state SUCCESS. Commit: 17e8d6b
/LLM/main/L0_MergeRequest_PR pipeline #41586 completed with status: 'FAILURE'



/bot run

@tensorrt-cicd

PR_Github #52398 [ run ] triggered by Bot. Commit: 17e8d6b Link to invocation

@tensorrt-cicd

PR_Github #52398 [ run ] completed with state SUCCESS. Commit: 17e8d6b
/LLM/main/L0_MergeRequest_PR pipeline #41691 completed with status: 'FAILURE'



/bot run

@tensorrt-cicd

PR_Github #52527 [ run ] triggered by Bot. Commit: 17e8d6b Link to invocation

@tensorrt-cicd

PR_Github #52527 [ run ] completed with state FAILURE. Commit: 17e8d6b
/LLM/main/L0_MergeRequest_PR pipeline #41814 completed with status: 'FAILURE'



/bot run

@tensorrt-cicd

PR_Github #52544 [ run ] triggered by Bot. Commit: 17e8d6b Link to invocation

@tensorrt-cicd

PR_Github #52544 [ run ] completed with state SUCCESS. Commit: 17e8d6b
/LLM/main/L0_MergeRequest_PR pipeline #41829 completed with status: 'FAILURE'



/bot run

@tensorrt-cicd

PR_Github #52603 [ run ] triggered by Bot. Commit: 17e8d6b Link to invocation

Comment on lines +115 to +119
candidates = [
os.environ.get("WAN22_T2V_NVFP4_PATH"),
"/models/Wan2.2-T2V-A14B-Diffusers-NVFP4",
"/home/scratch.trt_llm_data_ci/llm-models/Wan2.2-T2V-A14B-Diffusers-NVFP4",
]

Use utils to get the model dir

Comment on lines +953 to +961
logger.info(
"Wan NVFP4 fused-LayerNorm scale attach: norm1=%d/%d, norm2=%d/%d, norm3=%d/%d",
n_attached_n1,
total_blocks,
n_attached_n2,
total_blocks,
n_attached_n3,
total_blocks,
)

Do we need this log?

# kill switch on an NVFP4-configured layer (the A/B case). Plain
# LayerNorm users skip this entirely so no NVTX push/pop overhead
# leaks into the shared hot path.
with nvtx_range("LN unfused", color="grey"):

NVTX ranges this fine-grained should be left to local debugging code rather than committed.

# Kimi-K2.5, Flux, Flux2, ...). The label is what lets
# `nsys stats --report nvtx_sum` confirm the fast path fired
# per block.
with nvtx_range("LN+NVFP4 fused", color="green"):

NVTX ranges this fine-grained should be left to local debugging code rather than committed.

"""
if quant_config is None or force_dynamic_quant:
return None
if os.environ.get("TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION", "1") == "1":

Could you make fuse_nvfp4_ln a block-level local flag, like fuse_qk_norm_rope, rather than putting it in the config or reading it from an env variable?


Bench files should not be in the unittest folder. Do we have bench tests in visual gen? @zhenhuaw-me @chang-l

@tensorrt-cicd

PR_Github #52603 [ run ] completed with state FAILURE. Commit: 17e8d6b
/LLM/main/L0_MergeRequest_PR pipeline #41885 completed with status: 'FAILURE'



I'm wondering whether we should avoid modifying the general LayerNorm module and exposing quantize_type / is_nvfp4 / nvfp4_scale / _forward_nvfp4_fused / the AdaLN scale_msa, shift_msa, seq_len_per_batch args on it. I suggest using a subclass such as DiTLayerNorm(LayerNorm) to do the extra work here.

@luyiyun1021 luyiyun1021 left a comment

Thanks for the work. Left some comments. Could you please trim your comments? Lots of comments seem to have been added to general modules like mlp or layernorm, and some are verbose agent dev comments.

@luyiyun1021 luyiyun1021 left a comment

Another pass focused on the MLP fused path -- structure/DRY suggestions (non-blocking) plus one gate-consistency question. The fusion approach matches the existing GatedMLP / relu2 precedent; these are about tightening the implementation to that same single-parameterized-helper style.

not_dynamic = not getattr(self.down_proj, "force_dynamic_quantization",
False)

self._use_fused_gelu_tanh_quant = (has_nvfp4 and has_kernel_gelu

[Nit · mergeable-branches + gate inconsistency]

The two parallel _use_fused_* booleans here + the matching if/elif in forward differ only by which helper they call and the NVTX label. Consider collapsing to a single resolved op handle, mirroring how GatedMLP does the same NVFP4 activation+quant fusion (gated_mlp.py: _can_fuse_gate_up_swiglu* -> one parameterized _fused_gate_up_swiglu(x, fp4_out=...)):

# create_weights
self._fused_act_quant_op = None
if has_nvfp4 and has_scale and not_dynamic:
    if self.activation is relu2 and has_kernel:
        self._fused_act_quant_op = torch.ops.trtllm.fused_relu2_quantize
    elif self.activation is gelu_tanh and has_kernel_gelu:
        self._fused_act_quant_op = torch.ops.trtllm.fused_gelu_tanh_quantize

# forward
if self._fused_act_quant_op is not None:
    with nvtx_range("act+NVFP4 fused", color="green"):
        x_act = self._fused_act_quant(x_up, self._fused_act_quant_op)
else:
    x_act = self.activation(x_up)

Separately, the two gates are inconsistent: this gelu gate requires not_dynamic (the comment explains it guards a stale module.alpha when an Fp4QuantizedTensor is fed under dynamic quant), but the relu2 gate just above (self._use_fused_relu2_quant = ...) does not include not_dynamic. relu2 also returns an Fp4QuantizedTensor and hits the same alpha path -- should relu2 also gate on not_dynamic, or is the gelu not_dynamic over-cautious? Unifying the gates would force this to be decided in one place.

is_sf_swizzled=True,
)

def _fused_gelu_tanh_quant(self, x: torch.Tensor) -> Fp4QuantizedTensor:

[Nit · DRY] _fused_gelu_tanh_quant is identical to _fused_relu2_quant (just above) except for the op symbol (fused_gelu_tanh_quantize vs fused_relu2_quantize). Worth collapsing into one op-parameterized helper:

def _fused_act_quant(self, x, op):
    orig = x.shape
    xf = x.view(-1, x.shape[-1]).contiguous()
    if xf.dtype not in (torch.float16, torch.bfloat16):
        xf = xf.to(torch.bfloat16)
    fp4, sf = op(xf, self.down_proj.input_scale, 16)
    if len(orig) > 2:
        fp4 = fp4.view(*orig[:-1], orig[-1] // 2)
    return Fp4QuantizedTensor(fp4, sf, is_sf_swizzled=True)

This mirrors GatedMLP._fused_gate_up_swiglu(x, fp4_out=...) (one parameterized helper). Note the CUDA side was already unified into the templated fusedActQuantizeKernel<T, Act> -- these Python helpers are the remaining duplicated half.

@luyiyun1021 luyiyun1021 left a comment

Test-convention nits: prefer the shared skip_pre_blackwell / skip_blackwell_geforce decorators (tests/unittest/utils/util.py) over hand-rolled getSMVersion() < 100 / _has_sm100() gates, so the Blackwell-version policy stays in one place. The integration-test one also surfaces a question about the < 120 (SM120) exclusion in the production is_nvfp4 gate.

return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "fp4_quantize")


skip_unless_fused_layernorm_and_fp4 = pytest.mark.skipif(

[Nit · follow-existing-patterns] The getSMVersion() < 100 half here re-rolls the repo-standard skip_pre_blackwell (tests/unittest/utils/util.py:113), which is what sibling kernel tests use (test_fused_moe, test_awq_quantization, test_mla_helix, test_modeling_kimi_k25, the dsa tests, ...). The op-availability AND (not fused_layernorm_quantize_available()) is a legitimate extra condition, but the SM part should compose the shared decorator instead of duplicating the < 100 literal:

from tests.unittest.utils.util import skip_pre_blackwell

@skip_pre_blackwell
@pytest.mark.skipif(
    not (fused_layernorm_quantize_available() and fp4_quantize_available()),
    reason="requires trtllm fused_layernorm_quantize + fp4_quantize ops",
)
def test_...():
    ...

Same applies to the gelu gates in test_fused_activation_quant.py (skip_unless_fused_gelu_tanh_*). Keeps the Blackwell-version policy in one place so a future SM-range change doesn't have to be chased across several hand-rolled getSMVersion() < 100 literals.

# that specific failure; anything else (e.g., ImportError) should
# propagate so it isn't masked.
return False
return 100 <= sm < 120

[Nit · follow-existing-patterns + a question for the code] _has_sm100() reimplements two decorators that already exist in tests/unittest/utils/util.py: 100 <= sm is skip_pre_blackwell (:113), and the < 120 upper bound is exactly skip_blackwell_geforce (getSMVersion() == 120, :119). Consider composing them instead of a hand-rolled helper (drop-in, since skip_if_no_sm100 is used as a plain decorator):

from tests.unittest.utils.util import skip_pre_blackwell, skip_blackwell_geforce

@skip_pre_blackwell
@skip_blackwell_geforce
@skip_if_no_checkpoint
def test_...():
    ...

The < 120 bound mirrors the production gate in layer_norm.py (is_nvfp4 is cleared unless 100 <= sm < 120), so excluding SM120 here is consistent with the code under test. But it raises a real question about the code itself: why is SM120 excluded from the fused path at all? SM120 is Blackwell and supports cvt.rn.satfinite.e2m1x2.f32, so the kernel should run there. If the exclusion is intentional (e.g., the kernel was only validated on SM100/SM10x), a one-line comment at the is_nvfp4 gate would help; if not, both the gate and this test should allow sm >= 100.
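
One way to sanity-check the drop-in claim is to compare the two skip policies as plain predicates. This is a sketch under the assumption that the shared decorators behave as described above (skip_pre_blackwell skips sm < 100, skip_blackwell_geforce skips sm == 120); the function names and the sample SM values are illustrative only.

```python
def skipped_by_has_sm100(sm: int) -> bool:
    # Hand-rolled gate from the test: run only when 100 <= sm < 120.
    return not (100 <= sm < 120)


def skipped_by_composed(sm: int) -> bool:
    # Assumed semantics of the shared decorators, as described in the review:
    # skip_pre_blackwell skips sm < 100, skip_blackwell_geforce skips sm == 120.
    return sm < 100 or sm == 120


# SM values where the two policies disagree, over a few illustrative samples.
differences = [
    sm
    for sm in (80, 90, 100, 103, 120, 121)
    if skipped_by_has_sm100(sm) != skipped_by_composed(sm)
]
```

Under these assumptions the two policies agree up to and including SM120 and diverge only above it, where the composed decorators would run the test and the hand-rolled range would skip it. Whether that matters depends on the answer to the SM120 question above.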
