[TRTLLM-12288][feat] Support Nemotron-H nvfp4 ckpt on Hopper#14775
[TRTLLM-12288][feat] Support Nemotron-H nvfp4 ckpt on Hopper#14775JadoTu wants to merge 2 commits into
Conversation
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
|
/bot run |
|
PR_Github #51160 [ run ] triggered by Bot. Commit: |
📝 WalkthroughWalkthroughThis PR implements a W4A16 quantization fallback for NVFP4 on Hopper GPUs, enabling dequantization-based execution when FP4 hardware is unavailable. The changes span Triton kernels, quantization method classes, MoE and linear execution paths, model-level integration, and validation tests. ChangesHopper W4A16 NVFP4 Fallback
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/models/modeling_nemotron_h.py`:
- Around line 793-797: The code unconditionally sets
TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT="0" process-wide; instead, capture the
previous os.environ.get("TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT"), then set the
env var only within the same context manager used for the class-level patches
(the existing patch context used in this module) and ensure you restore the
original value in a finally block so other model inits don't inherit the change;
reference the env var name TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT, the logger call
using get_sm_version(), and wrap the assignment/usage in try/finally (or use the
existing patch context manager) to guarantee restoration.
In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py`:
- Around line 2905-2919: process_weights_after_loading currently unswizzles only
resident expert scales (module.w3_w1_weight_scale, module.w2_weight_scale) after
calling super(), but super().process_weights_after_loading() finalizes and
deletes shared EPLB buffers (local_shared_*_scale_tensors) so migrated experts
remain swizzled and later dequant_active_experts_to_hp() misbehaves; fix by
handling shared EPLB scale buffers before super() is called (e.g., detect and
call the existing _unswizzle_inplace on module.local_shared_*_scale_tensors /
any module.local_shared_{w1,w2,w3}_weight_scale if present) or alternatively
mark eplb_support_status = UNSUPPORTED for this class so EPLB paths are disabled
for this override. Ensure references to process_weights_after_loading,
_unswizzle_inplace, module.local_shared_*_scale_tensors,
module.w3_w1_weight_scale, module.w2_weight_scale, eplb_support_status and
dequant_active_experts_to_hp are used to locate and update the code.
In `@tensorrt_llm/_torch/modules/fused_moe/triton_dequant_nvfp4.py`:
- Around line 155-194: The wrapper that launches _dequant_nvfp4_active_kernel
must validate tensor contiguity/strides and companion shapes before launching:
assert that the innermost (K-packed) dimension is unit-stride for packed_weight
(packed_weight.stride(2)==1) and for scale_linear's last dimension
(scale_linear.stride(1)==1), ensure active_mask has unit stride/contiguous
layout for its indexing (e.g., active_mask.is_contiguous() or
active_mask.stride(0)==1), and validate weight_scale_2 is scalar (numel()==1)
when the 2D path expects a single element; add equivalent checks in the other
wrapper (the 290-333 block that launches the non-active kernel) so kernels
cannot silently read from wrong addresses. Ensure assertions include descriptive
messages naming the offending tensor and expected constraint.
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 1891-1896: Replace the current assert with a real guard that
ensures static scales were actually loaded and dynamic quantization is not
forced: when handling FP8 input (input.dtype == torch.float8_e4m3fn) check
module.force_dynamic_quantization and module.inv_input_scale explicitly—if
force_dynamic_quantization is false but module.inv_input_scale is None, raise a
clear RuntimeError (fail fast); only perform the division by
module.inv_input_scale when module.inv_input_scale is present and
module.force_dynamic_quantization is false. Reference
NVFP4LinearMethod.create_weights(), module.inv_input_scale and
module.force_dynamic_quantization to locate and update the logic around the FP8
input handling.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b2225cc5-e534-4968-a8fd-7877c046fab5
📒 Files selected for processing (7)
tensorrt_llm/_torch/models/modeling_nemotron_h.pytensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.pytensorrt_llm/_torch/modules/fused_moe/quantization.pytensorrt_llm/_torch/modules/fused_moe/triton_dequant_nvfp4.pytensorrt_llm/_torch/modules/linear.pytests/integration/defs/accuracy/test_llm_api_pytorch.pytests/integration/test_lists/test-db/l0_dgx_h100.yml
| if os.environ.get("TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT") != "0": | ||
| logger.warning( | ||
| f"Nemotron-H SM{get_sm_version()}: TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT=0" | ||
| ) | ||
| os.environ["TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT"] = "0" |
There was a problem hiding this comment.
Don't leave the attention fallback env var set process-wide.
This constructor flips TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT to 0 and never restores it. A later model init in the same worker will inherit bf16 attention output even when its FP4 path should stay enabled. Please scope this override with the same context manager used for the class-level patches and restore the previous value in finally.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/models/modeling_nemotron_h.py` around lines 793 - 797,
The code unconditionally sets TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT="0"
process-wide; instead, capture the previous
os.environ.get("TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT"), then set the env var
only within the same context manager used for the class-level patches (the
existing patch context used in this module) and ensure you restore the original
value in a finally block so other model inits don't inherit the change;
reference the env var name TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT, the logger call
using get_sm_version(), and wrap the assignment/usage in try/finally (or use the
existing patch context manager) to guarantee restoration.
| def process_weights_after_loading(self, module: torch.nn.Module): | ||
| super().process_weights_after_loading(module) | ||
|
|
||
| # Scale buffer: int32-packed FP8, viewed as uint8 has shape | ||
| # [E, pad_up(N, 128), pad_up(K/sf_vec, 4)] -- the 3D layout | ||
| # block_scale_interleave_reverse accepts. | ||
| def _unswizzle_inplace(scale_param: torch.nn.Parameter): | ||
| sf_view = scale_param.data.view(float4_sf_dtype) | ||
| E, pad_rows, pad_cols = (sf_view.shape[0], sf_view.shape[1], | ||
| sf_view.shape[2]) | ||
| linear = torch.ops.trtllm.block_scale_interleave_reverse(sf_view) | ||
| scale_param.data.view(float4_sf_dtype).copy_(linear) | ||
|
|
||
| _unswizzle_inplace(module.w3_w1_weight_scale) | ||
| _unswizzle_inplace(module.w2_weight_scale) |
There was a problem hiding this comment.
Unswizzle the shared EPLB scale buffers too, or disable EPLB for this method.
super().process_weights_after_loading() finalizes and deletes local_shared_*_scale_tensors before this override runs, so only the resident experts get converted back to linear layout here. The class still inherits eplb_support_status = SUPPORTED, though, and dequant_active_experts_to_hp() later treats migrated experts as if their scales were already unswizzled. Online EPLB will therefore dequantize migrated experts incorrectly on the Hopper fallback path.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 2905 -
2919, process_weights_after_loading currently unswizzles only resident expert
scales (module.w3_w1_weight_scale, module.w2_weight_scale) after calling
super(), but super().process_weights_after_loading() finalizes and deletes
shared EPLB buffers (local_shared_*_scale_tensors) so migrated experts remain
swizzled and later dequant_active_experts_to_hp() misbehaves; fix by handling
shared EPLB scale buffers before super() is called (e.g., detect and call the
existing _unswizzle_inplace on module.local_shared_*_scale_tensors / any
module.local_shared_{w1,w2,w3}_weight_scale if present) or alternatively mark
eplb_support_status = UNSUPPORTED for this class so EPLB paths are disabled for
this override. Ensure references to process_weights_after_loading,
_unswizzle_inplace, module.local_shared_*_scale_tensors,
module.w3_w1_weight_scale, module.w2_weight_scale, eplb_support_status and
dequant_active_experts_to_hp are used to locate and update the code.
| assert packed_weight.dim() == 3, "packed_weight must be 3D [E, N, K/2]" | ||
| assert sf_vec_size == 16, "NVFP4 fixed at 16-element blocks" | ||
| assert block_k % sf_vec_size == 0, ( | ||
| f"block_k={block_k} must be a multiple of sf_vec_size={sf_vec_size}" | ||
| ) | ||
|
|
||
| E, N, K_packed = packed_weight.shape | ||
| K = K_packed * 2 | ||
| device = packed_weight.device | ||
|
|
||
| if active_mask.dtype != torch.uint8: | ||
| active_mask = active_mask.to(torch.uint8) | ||
|
|
||
| out = torch.empty(E, N, K, dtype=target_dtype, device=device) | ||
| e2m1_table = _get_e2m1_codebook(device) | ||
|
|
||
| grid = (E, triton.cdiv(N, block_n), triton.cdiv(K, block_k)) | ||
| _dequant_nvfp4_active_kernel[grid]( | ||
| packed_weight, | ||
| scale_linear, | ||
| weight_scale_2, | ||
| active_mask, | ||
| e2m1_table, | ||
| out, | ||
| # strides (in elements) | ||
| packed_weight.stride(0), | ||
| packed_weight.stride(1), | ||
| scale_linear.stride(0), | ||
| scale_linear.stride(1), | ||
| out.stride(0), | ||
| out.stride(1), | ||
| # shapes | ||
| N, | ||
| K, | ||
| # constexpr | ||
| SF_VEC=sf_vec_size, | ||
| BLOCK_N=block_n, | ||
| BLOCK_K=block_k, | ||
| ) | ||
| return out |
There was a problem hiding this comment.
Validate the wrapper contract before launching Triton.
Both kernels index the innermost dimension as if it were unit-stride, but these wrappers only pass the leading strides and do not check the companion tensors’ shapes. A sliced/transposed packed_weight or weight_scale, a non-contiguous active_mask, or a multi-element weight_scale_2 in the 2D path will dequantize from the wrong addresses instead of failing fast.
Suggested guardrails
def dequant_nvfp4_active_triton(
@@
E, N, K_packed = packed_weight.shape
K = K_packed * 2
device = packed_weight.device
+
+ if packed_weight.stride(-1) != 1:
+ raise ValueError("packed_weight must be contiguous in the last dimension")
+ if scale_linear.dim() != 3 or scale_linear.shape[0] != E or scale_linear.shape[1] < N:
+ raise ValueError("scale_linear must have shape [E, N_pad, K_sf_pad] with matching E/N")
+ if scale_linear.stride(-1) != 1:
+ raise ValueError("scale_linear must be contiguous in the last dimension")
+ if active_mask.dim() != 1 or active_mask.numel() != E or active_mask.stride(0) != 1:
+ raise ValueError("active_mask must be a contiguous [E] tensor")
+ if weight_scale_2.dim() != 1 or weight_scale_2.numel() != E or weight_scale_2.stride(0) != 1:
+ raise ValueError("weight_scale_2 must be a contiguous [E] tensor")
@@
def dequant_nvfp4_2d_triton(
@@
N, K_packed = packed_weight.shape
K = K_packed * 2
device = packed_weight.device
+
+ if packed_weight.stride(-1) != 1:
+ raise ValueError("packed_weight must be contiguous in the last dimension")
@@
elif weight_scale.dim() != 2:
raise ValueError(f"weight_scale must be 1D or 2D, got shape {tuple(weight_scale.shape)}")
+ if weight_scale.shape[0] < N or weight_scale.shape[1] < K // sf_vec_size:
+ raise ValueError("weight_scale is smaller than the required padded scale grid")
+ if weight_scale.stride(-1) != 1:
+ raise ValueError("weight_scale must be contiguous in the last dimension")
+ if weight_scale_2.numel() != 1:
+ raise ValueError("weight_scale_2 must contain exactly one element")Also applies to: 290-333
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/modules/fused_moe/triton_dequant_nvfp4.py` around lines
155 - 194, The wrapper that launches _dequant_nvfp4_active_kernel must validate
tensor contiguity/strides and companion shapes before launching: assert that the
innermost (K-packed) dimension is unit-stride for packed_weight
(packed_weight.stride(2)==1) and for scale_linear's last dimension
(scale_linear.stride(1)==1), ensure active_mask has unit stride/contiguous
layout for its indexing (e.g., active_mask.is_contiguous() or
active_mask.stride(0)==1), and validate weight_scale_2 is scalar (numel()==1)
when the 2D path expects a single element; add equivalent checks in the other
wrapper (the 290-333 block that launches the non-active kernel) so kernels
cannot silently read from wrong addresses. Ensure assertions include descriptive
messages naming the offending tensor and expected constraint.
| ## FP8 input from upstream FMHA pre-quant: invert by / module.inv_input_scale. | ||
| if input.dtype == torch.float8_e4m3fn: | ||
| assert module.inv_input_scale is not None, \ | ||
| "W4A16NVFP4LinearMethod: FP8 input requires static inv_input_scale" | ||
| input = (input.to(module.dtype) / module.inv_input_scale).to( | ||
| module.dtype) |
There was a problem hiding this comment.
Use a real static-scale guard for FP8 input.
NVFP4LinearMethod.create_weights() always allocates module.inv_input_scale, so the assert on Line 1893 never trips. If this branch is hit without loaded static scales, Line 1895 divides by uninitialized data instead of failing fast. Gate this on a condition that actually proves static scales were loaded, plus not module.force_dynamic_quantization.
🧰 Tools
🪛 Ruff (0.15.14)
[error] 1895-1895: Variable input is shadowing a Python builtin
(A001)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/modules/linear.py` around lines 1891 - 1896, Replace the
current assert with a real guard that ensures static scales were actually loaded
and dynamic quantization is not forced: when handling FP8 input (input.dtype ==
torch.float8_e4m3fn) check module.force_dynamic_quantization and
module.inv_input_scale explicitly—if force_dynamic_quantization is false but
module.inv_input_scale is None, raise a clear RuntimeError (fail fast); only
perform the division by module.inv_input_scale when module.inv_input_scale is
present and module.force_dynamic_quantization is false. Reference
NVFP4LinearMethod.create_weights(), module.inv_input_scale and
module.force_dynamic_quantization to locate and update the logic around the FP8
input handling.
|
PR_Github #51160 [ run ] completed with state
|
|
/bot run |
|
PR_Github #51202 [ run ] triggered by Bot. Commit: |
|
PR_Github #51202 [ run ] completed with state
|
|
/bot run |
|
PR_Github #51219 [ run ] triggered by Bot. Commit: |
|
PR_Github #51219 [ run ] completed with state
|
|
/bot run |
|
PR_Github #51242 [ run ] triggered by Bot. Commit: |
|
PR_Github #51242 [ run ] completed with state
|
|
/bot run |
|
PR_Github #51254 [ run ] triggered by Bot. Commit: |
|
/bot run |
|
PR_Github #51259 [ run ] triggered by Bot. Commit: |
|
PR_Github #51254 [ run ] completed with state |
|
PR_Github #51259 [ run ] completed with state |
Summary by CodeRabbit
New Features
Tests
Description
Enables NVFP4 Nemotron-H checkpoints to run on Hopper GPUs, which lack a native NVFP4 tensor-core GEMM. We add a W4A16 path that loads the NVFP4 weights and dequantizes them on-the-fly per forward step via Triton kernel. CUDA-graph capturable; Blackwell paths are untouched.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.