
[TRTLLM-12288][feat] Support Nemotron-H nvfp4 ckpt on Hopper#14775

Open
JadoTu wants to merge 2 commits into
NVIDIA:mainfrom
JadoTu:support_nemotron_nvfp4_on_hopper

Conversation

@JadoTu
Collaborator

@JadoTu JadoTu commented May 30, 2026

Summary by CodeRabbit

  • New Features

    • Added W4A16 NVFP4 quantization support with automatic fallback handling for Hopper architecture.
  • Tests

    • Added accuracy validation test for W4A16 NVFP4 quantization on 4-GPU Hopper systems.


Description

Enables NVFP4 Nemotron-H checkpoints to run on Hopper GPUs, which lack a native NVFP4 tensor-core GEMM. We add a W4A16 path that loads the NVFP4 weights and dequantizes them on the fly at each forward step via a Triton kernel. The path is CUDA-graph capturable, and the Blackwell paths are untouched.

  1. This approach makes the different MoE backends fall back to a single backend, which is a new path for NVFP4 on Hopper.
  2. The NVFP4 dequantization is done in Triton and needs further improvement.
  3. Functional tests pass with Nemotron Ultra/Super on H100/H200 with MTP enabled.
  4. Performance is not yet tuned. As an example, the Ultra model runs the full GSM8K dataset on 8xH100 in 13 minutes.
  5. A new CI test is added.
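
To make the on-the-fly dequantization step concrete, here is a minimal pure-Python reference sketch. It is not the PR's Triton kernel. It assumes two E2M1 codes per byte with the low nibble first, one scale per 16-element block (already converted from FP8 to a Python float), and a single per-tensor global scale. The function names `decode_e2m1` and `dequant_nvfp4_row` are hypothetical.

```python
# E2M1 magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
SF_VEC = 16  # number of elements that share one block scale


def decode_e2m1(code):
    """Map one 4-bit E2M1 code to its float value."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[code & 0x7]


def dequant_nvfp4_row(packed, block_scales, global_scale):
    """Dequantize one row of packed FP4 weights to Python floats.

    packed: bytes, two codes per byte (ASSUMPTION: low nibble first).
    block_scales: one float per SF_VEC elements (stored as FP8 in a checkpoint).
    global_scale: per-tensor second-level scale.
    """
    num_elems = len(packed) * 2
    if num_elems % SF_VEC != 0 or len(block_scales) != num_elems // SF_VEC:
        raise ValueError("need exactly one block scale per 16 elements")
    out = []
    for byte in packed:
        for code in (byte & 0x0F, byte >> 4):
            # Each element is scaled by its block scale times the global scale.
            scale = block_scales[len(out) // SF_VEC] * global_scale
            out.append(decode_e2m1(code) * scale)
    return out


# Two 16-element blocks: 8 bytes of 0x21 followed by 8 bytes of 0x9F.
row = dequant_nvfp4_row(bytes([0x21] * 8 + [0x9F] * 8), [2.0, 0.5], 0.25)
print(row[0], row[1], row[16], row[17])
```

A GPU kernel would do the same lookup and two-level scaling per tile and write the result in the target dtype (for example bf16) before the GEMM.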

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
@JadoTu JadoTu requested review from a team as code owners May 30, 2026 04:18
@JadoTu
Collaborator Author

JadoTu commented May 30, 2026

/bot run

@JadoTu JadoTu requested review from Wanli-Jiang and nv-guomingz May 30, 2026 04:20
@tensorrt-cicd
Collaborator

PR_Github #51160 [ run ] triggered by Bot. Commit: 047bc21 Link to invocation

@JadoTu JadoTu requested a review from tijyojwad May 30, 2026 04:30
@coderabbitai
Contributor

coderabbitai Bot commented May 30, 2026

Walkthrough

This PR implements a W4A16 quantization fallback for NVFP4 on Hopper GPUs, enabling dequantization-based execution when FP4 hardware is unavailable. The changes span Triton kernels, quantization method classes, MoE and linear execution paths, model-level integration, and validation tests.

Changes

Hopper W4A16 NVFP4 Fallback

| Layer / File(s) | Summary |
| --- | --- |
| Triton NVFP4 dequantization kernels: tensorrt_llm/_torch/modules/fused_moe/triton_dequant_nvfp4.py | Adds Triton-based dequantization for MoE (active-expert-only 3D path with active mask) and linear (2D path with per-tensor scale), with E2M1 FP4 codebook lookup, per-block FP8 scale conversion, and CUDA-graph-safe active expert masking via scatter_. |
| MoE W4A16 execution path: tensorrt_llm/_torch/modules/fused_moe/quantization.py, tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py | Introduces W4A16NVFP4CutlassFusedMoEMethod for weight unswizzling and active-expert dequantization, plus CutlassFusedMoE dispatch that bypasses FP4 quantization and delegates to _run_moe_w4a16_nvfp4, which dequantizes active experts and calls torch.ops.trtllm.fused_moe with empty scales. |
| Linear W4A16 execution method: tensorrt_llm/_torch/modules/linear.py | Adds W4A16NVFP4LinearMethod for weight-scale preparation, FP8 input inversion, and on-the-fly NVFP4 weight dequantization via Triton, executing GEMM through cublas_mm or F.linear with output reshaping. |
| Nemotron-H model Hopper fallback wiring: tensorrt_llm/_torch/models/modeling_nemotron_h.py | Implements FP4-hardware detection in NemotronHLayer, disables NVFP4 on Mamba mixer for SM<100, and adds context manager patching of quantization methods at NemotronHForCausalLM initialization and post-initialization to apply W4A16 fallback during model construction. |
| Hopper W4A16 accuracy test: tests/integration/defs/accuracy/test_llm_api_pytorch.py, tests/integration/test_lists/test-db/l0_dgx_h100.yml | Adds test_nvfp4_4gpus_hopper_w4a16 test for Nemotron-3-Super-120B-A12B-NVFP4 on H100 with W4A16 fallback semantics, registered in the pre-merge H100 4-GPU test suite. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • danielafrimi
  • StanleySun639
  • nv-guomingz
  • lfr-0531
  • yuxianq
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 45.16% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and concisely identifies the main feature: adding NVFP4 Nemotron-H checkpoint support on Hopper GPUs. |
| Description check | ✅ Passed | The description explains the motivation (NVFP4 on Hopper lacks native GEMM), the solution approach (W4A16 dequantization path), key design notes, testing status, and performance context. However, it lacks specific details about Test Coverage as required by the template. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/models/modeling_nemotron_h.py`:
- Around line 793-797: The code unconditionally sets
TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT="0" process-wide; instead, capture the
previous os.environ.get("TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT"), then set the
env var only within the same context manager used for the class-level patches
(the existing patch context used in this module) and ensure you restore the
original value in a finally block so other model inits don't inherit the change;
reference the env var name TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT, the logger call
using get_sm_version(), and wrap the assignment/usage in try/finally (or use the
existing patch context manager) to guarantee restoration.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py`:
- Around line 2905-2919: process_weights_after_loading currently unswizzles only
resident expert scales (module.w3_w1_weight_scale, module.w2_weight_scale) after
calling super(), but super().process_weights_after_loading() finalizes and
deletes shared EPLB buffers (local_shared_*_scale_tensors) so migrated experts
remain swizzled and later dequant_active_experts_to_hp() misbehaves; fix by
handling shared EPLB scale buffers before super() is called (e.g., detect and
call the existing _unswizzle_inplace on module.local_shared_*_scale_tensors /
any module.local_shared_{w1,w2,w3}_weight_scale if present) or alternatively
mark eplb_support_status = UNSUPPORTED for this class so EPLB paths are disabled
for this override. Ensure references to process_weights_after_loading,
_unswizzle_inplace, module.local_shared_*_scale_tensors,
module.w3_w1_weight_scale, module.w2_weight_scale, eplb_support_status and
dequant_active_experts_to_hp are used to locate and update the code.

In `@tensorrt_llm/_torch/modules/fused_moe/triton_dequant_nvfp4.py`:
- Around line 155-194: The wrapper that launches _dequant_nvfp4_active_kernel
must validate tensor contiguity/strides and companion shapes before launching:
assert that the innermost (K-packed) dimension is unit-stride for packed_weight
(packed_weight.stride(2)==1) and for scale_linear's last dimension
(scale_linear.stride(-1)==1), ensure active_mask has unit stride/contiguous
layout for its indexing (e.g., active_mask.is_contiguous() or
active_mask.stride(0)==1), and validate weight_scale_2 is scalar (numel()==1)
when the 2D path expects a single element; add equivalent checks in the other
wrapper (the 290-333 block that launches the non-active kernel) so kernels
cannot silently read from wrong addresses. Ensure assertions include descriptive
messages naming the offending tensor and expected constraint.

In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 1891-1896: Replace the current assert with a real guard that
ensures static scales were actually loaded and dynamic quantization is not
forced: when handling FP8 input (input.dtype == torch.float8_e4m3fn) check
module.force_dynamic_quantization and module.inv_input_scale explicitly—if
force_dynamic_quantization is false but module.inv_input_scale is None, raise a
clear RuntimeError (fail fast); only perform the division by
module.inv_input_scale when module.inv_input_scale is present and
module.force_dynamic_quantization is false. Reference
NVFP4LinearMethod.create_weights(), module.inv_input_scale and
module.force_dynamic_quantization to locate and update the logic around the FP8
input handling.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b2225cc5-e534-4968-a8fd-7877c046fab5

📥 Commits

Reviewing files that changed from the base of the PR and between 15bb791 and 047bc21.

📒 Files selected for processing (7)
  • tensorrt_llm/_torch/models/modeling_nemotron_h.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
  • tensorrt_llm/_torch/modules/fused_moe/triton_dequant_nvfp4.py
  • tensorrt_llm/_torch/modules/linear.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tests/integration/test_lists/test-db/l0_dgx_h100.yml

Comment on lines +793 to +797
        if os.environ.get("TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT") != "0":
            logger.warning(
                f"Nemotron-H SM{get_sm_version()}: TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT=0"
            )
            os.environ["TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT"] = "0"

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't leave the attention fallback env var set process-wide.

This constructor flips TRTLLM_ENABLE_ATTENTION_NVFP4_OUTPUT to 0 and never restores it. A later model init in the same worker will inherit bf16 attention output even when its FP4 path should stay enabled. Please scope this override with the same context manager used for the class-level patches and restore the previous value in finally.
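
One way to scope the override, sketched with the standard library only: a context manager that records the previous value and restores it on exit. The helper name `scoped_env` and the `DEMO_FLAG` variable are hypothetical, and how this would be merged with the module's existing patch context manager is an assumption.

```python
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(name, value):
    """Temporarily set an environment variable, then restore the old state."""
    previous = os.environ.get(name)  # None means the variable was unset
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous


# Usage with a placeholder variable name (hypothetical, for illustration).
os.environ["DEMO_FLAG"] = "1"
with scoped_env("DEMO_FLAG", "0"):
    inside = os.environ["DEMO_FLAG"]
after = os.environ["DEMO_FLAG"]
```

Because the restore runs in `finally`, a failure during model construction would not leave the override in place for later model inits in the same worker.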


Comment on lines +2905 to +2919
    def process_weights_after_loading(self, module: torch.nn.Module):
        super().process_weights_after_loading(module)

        # Scale buffer: int32-packed FP8, viewed as uint8 has shape
        # [E, pad_up(N, 128), pad_up(K/sf_vec, 4)] -- the 3D layout
        # block_scale_interleave_reverse accepts.
        def _unswizzle_inplace(scale_param: torch.nn.Parameter):
            sf_view = scale_param.data.view(float4_sf_dtype)
            E, pad_rows, pad_cols = (sf_view.shape[0], sf_view.shape[1],
                                     sf_view.shape[2])
            linear = torch.ops.trtllm.block_scale_interleave_reverse(sf_view)
            scale_param.data.view(float4_sf_dtype).copy_(linear)

        _unswizzle_inplace(module.w3_w1_weight_scale)
        _unswizzle_inplace(module.w2_weight_scale)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Unswizzle the shared EPLB scale buffers too, or disable EPLB for this method.

super().process_weights_after_loading() finalizes and deletes local_shared_*_scale_tensors before this override runs, so only the resident experts get converted back to linear layout here. The class still inherits eplb_support_status = SUPPORTED, though, and dequant_active_experts_to_hp() later treats migrated experts as if their scales were already unswizzled. Online EPLB will therefore dequantize migrated experts incorrectly on the Hopper fallback path.


Comment on lines +155 to +194
    assert packed_weight.dim() == 3, "packed_weight must be 3D [E, N, K/2]"
    assert sf_vec_size == 16, "NVFP4 fixed at 16-element blocks"
    assert block_k % sf_vec_size == 0, (
        f"block_k={block_k} must be a multiple of sf_vec_size={sf_vec_size}"
    )

    E, N, K_packed = packed_weight.shape
    K = K_packed * 2
    device = packed_weight.device

    if active_mask.dtype != torch.uint8:
        active_mask = active_mask.to(torch.uint8)

    out = torch.empty(E, N, K, dtype=target_dtype, device=device)
    e2m1_table = _get_e2m1_codebook(device)

    grid = (E, triton.cdiv(N, block_n), triton.cdiv(K, block_k))
    _dequant_nvfp4_active_kernel[grid](
        packed_weight,
        scale_linear,
        weight_scale_2,
        active_mask,
        e2m1_table,
        out,
        # strides (in elements)
        packed_weight.stride(0),
        packed_weight.stride(1),
        scale_linear.stride(0),
        scale_linear.stride(1),
        out.stride(0),
        out.stride(1),
        # shapes
        N,
        K,
        # constexpr
        SF_VEC=sf_vec_size,
        BLOCK_N=block_n,
        BLOCK_K=block_k,
    )
    return out

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Validate the wrapper contract before launching Triton.

Both kernels index the innermost dimension as if it were unit-stride, but these wrappers only pass the leading strides and do not check the companion tensors’ shapes. A sliced/transposed packed_weight or weight_scale, a non-contiguous active_mask, or a multi-element weight_scale_2 in the 2D path will dequantize from the wrong addresses instead of failing fast.

Suggested guardrails
 def dequant_nvfp4_active_triton(
@@
     E, N, K_packed = packed_weight.shape
     K = K_packed * 2
     device = packed_weight.device
+
+    if packed_weight.stride(-1) != 1:
+        raise ValueError("packed_weight must be contiguous in the last dimension")
+    if scale_linear.dim() != 3 or scale_linear.shape[0] != E or scale_linear.shape[1] < N:
+        raise ValueError("scale_linear must have shape [E, N_pad, K_sf_pad] with matching E/N")
+    if scale_linear.stride(-1) != 1:
+        raise ValueError("scale_linear must be contiguous in the last dimension")
+    if active_mask.dim() != 1 or active_mask.numel() != E or active_mask.stride(0) != 1:
+        raise ValueError("active_mask must be a contiguous [E] tensor")
+    if weight_scale_2.dim() != 1 or weight_scale_2.numel() != E or weight_scale_2.stride(0) != 1:
+        raise ValueError("weight_scale_2 must be a contiguous [E] tensor")
@@
 def dequant_nvfp4_2d_triton(
@@
     N, K_packed = packed_weight.shape
     K = K_packed * 2
     device = packed_weight.device
+
+    if packed_weight.stride(-1) != 1:
+        raise ValueError("packed_weight must be contiguous in the last dimension")
@@
     elif weight_scale.dim() != 2:
         raise ValueError(f"weight_scale must be 1D or 2D, got shape {tuple(weight_scale.shape)}")
+    if weight_scale.shape[0] < N or weight_scale.shape[1] < K // sf_vec_size:
+        raise ValueError("weight_scale is smaller than the required padded scale grid")
+    if weight_scale.stride(-1) != 1:
+        raise ValueError("weight_scale must be contiguous in the last dimension")
+    if weight_scale_2.numel() != 1:
+        raise ValueError("weight_scale_2 must contain exactly one element")

Also applies to: 290-333
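
The failure mode behind this comment can be shown without a GPU. The sketch below uses a plain Python list as a flat buffer and compares a read that uses all strides against one that, like a kernel receiving only the leading stride, assumes the innermost stride is 1. The helper names `element_offset` and `require_unit_inner_stride` are hypothetical and stand in for the suggested guardrails.

```python
def element_offset(index, strides):
    """Offset (in elements) of a multi-dimensional index given per-dim strides."""
    return sum(i * s for i, s in zip(index, strides))


def require_unit_inner_stride(name, strides):
    """Fail fast when the innermost dimension is not unit-stride."""
    if strides[-1] != 1:
        raise ValueError(
            f"{name} must be contiguous in the last dimension, got strides {strides}"
        )


# A 2x3 row-major buffer: logical element [r][c] lives at offset r * 3 + c.
buffer = [10, 11, 12, 20, 21, 22]

# Its transposed 3x2 view shares the buffer but has strides (1, 3).
transposed_strides = (1, 3)

# Correct read of transposed[2][1], which is original[1][2].
correct = buffer[element_offset((2, 1), transposed_strides)]

# A kernel that only receives the leading stride and assumes an inner stride
# of 1 computes row * stride0 + col and silently reads a different element.
assumed = buffer[2 * transposed_strides[0] + 1]

print(correct, assumed)
```

Calling `require_unit_inner_stride("packed_weight", transposed_strides)` raises a `ValueError` here, which is the fail-fast behavior the comment asks the wrappers to adopt.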


Comment on lines +1891 to +1896
        ## FP8 input from upstream FMHA pre-quant: invert by / module.inv_input_scale.
        if input.dtype == torch.float8_e4m3fn:
            assert module.inv_input_scale is not None, \
                "W4A16NVFP4LinearMethod: FP8 input requires static inv_input_scale"
            input = (input.to(module.dtype) / module.inv_input_scale).to(
                module.dtype)

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Use a real static-scale guard for FP8 input.

NVFP4LinearMethod.create_weights() always allocates module.inv_input_scale, so the assert on Line 1893 never trips. If this branch is hit without loaded static scales, Line 1895 divides by uninitialized data instead of failing fast. Gate this on a condition that actually proves static scales were loaded, plus not module.force_dynamic_quantization.
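
A minimal sketch of such a guard, using plain Python objects rather than the real module class. The flag `static_input_scale_loaded` is a hypothetical attribute that a weight loader would have to set; it is not an existing attribute in the PR, and the helper name is likewise an assumption.

```python
from types import SimpleNamespace


def resolve_static_inv_input_scale(module):
    """Return inv_input_scale only when a static scale was really loaded.

    ASSUMPTION: `static_input_scale_loaded` is a hypothetical flag set by the
    weight loader; checking `inv_input_scale is not None` is not enough if the
    buffer is always allocated.
    """
    if getattr(module, "force_dynamic_quantization", False):
        raise RuntimeError(
            "FP8 input is not supported when dynamic quantization is forced"
        )
    if not getattr(module, "static_input_scale_loaded", False):
        raise RuntimeError("FP8 input requires a loaded static inv_input_scale")
    return module.inv_input_scale


# Stand-in for a module whose static scale was loaded from the checkpoint.
loaded = SimpleNamespace(
    inv_input_scale=0.5,
    static_input_scale_loaded=True,
    force_dynamic_quantization=False,
)
scale = resolve_static_inv_input_scale(loaded)
```

Raising `RuntimeError` instead of using `assert` keeps the check active when Python runs with `-O`, which strips assertions.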

🧰 Tools
🪛 Ruff (0.15.14)

[error] 1895-1895: Variable input is shadowing a Python builtin

(A001)


@tensorrt-cicd
Collaborator

PR_Github #51160 [ run ] completed with state SUCCESS. Commit: 047bc21
/LLM/main/L0_MergeRequest_PR pipeline #40594 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@JadoTu
Collaborator Author

JadoTu commented May 30, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #51202 [ run ] triggered by Bot. Commit: b23cffd Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #51202 [ run ] completed with state FAILURE. Commit: b23cffd
/LLM/main/L0_MergeRequest_PR pipeline #40630 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@JadoTu
Collaborator Author

JadoTu commented May 31, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #51219 [ run ] triggered by Bot. Commit: b23cffd Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #51219 [ run ] completed with state SUCCESS. Commit: b23cffd
/LLM/main/L0_MergeRequest_PR pipeline #40642 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@JadoTu
Collaborator Author

JadoTu commented May 31, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #51242 [ run ] triggered by Bot. Commit: b23cffd Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #51242 [ run ] completed with state SUCCESS. Commit: b23cffd
/LLM/main/L0_MergeRequest_PR pipeline #40665 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@JadoTu
Collaborator Author

JadoTu commented May 31, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #51254 [ run ] triggered by Bot. Commit: b23cffd Link to invocation

@JadoTu
Collaborator Author

JadoTu commented May 31, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #51259 [ run ] triggered by Bot. Commit: b23cffd Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #51254 [ run ] completed with state ABORTED. Commit: b23cffd

Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #51259 [ run ] completed with state SUCCESS. Commit: b23cffd
/LLM/main/L0_MergeRequest_PR pipeline #40680 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation


2 participants