Skip to content

Introduce Mega-C++ to reduce CPU overhead#3099

Open
zhongbozhu wants to merge 8 commits into
NVIDIA:mainfrom
zhongbozhu:main_megacpp_grouped_mlp
Open

Introduce Mega-C++ to reduce CPU overhead#3099
zhongbozhu wants to merge 8 commits into
NVIDIA:mainfrom
zhongbozhu:main_megacpp_grouped_mlp

Conversation

@zhongbozhu

@zhongbozhu zhongbozhu commented Jun 6, 2026

Copy link
Copy Markdown
Collaborator

Description

Assistant: GPT5.5 codex

Issue: #2897

Get rid of CPU overhead whenever CUDA Graph is not applicable. Guarded by NVTE_MEGACPP_GROUPED_LINEAR.

Drop-in replace grouped MLP, ie. FC1 - act - FC2. Target BF16 grouped gemm with cublas grouped gemm backend.

In the future, we can extend to mxfp8 / nvfp4 with cublas backend or even cuteDSL grouped gemm and call cute.jit in C++: NVIDIA/cutlass#3289

Recommend CUDA >= 13.2.1

Dependency of merge: #3132 => this PR is rebased on top of this branch.

TODO:

  • E2E training with some multimodal THD packing workloads
  • Develop a better fused scaled swiglu kernel to replace torch ops
  • Attach before & after screenshots

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 6, 2026
@zhongbozhu

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

m.def("te_general_grouped_gemm_for_discrete_out",
&transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out,
"Grouped GEMM for discrete output list");
m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should expose these functions within the tex.grouped_mlp_experimental submodule:

// Experimental fused grouped MLP
auto grouped_mlp_experimental = m.def_submodule(
"grouped_mlp_experimental",
"Experimental helpers for the fused grouped MLP (unstable, may change or disappear).");

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make more sense to organize:

csrc/
├── extensions/
│   ├── grouped_mlp_experimental/
│   │   ├── megacpp.cpp
│   │   └── grouped_mlp_experimental.cpp
│   ├── pybind.cpp
│   └── ...

If we implement more mega-C++ impls in the future, I don't see a reason why they would be more similar to each other than to the block they are fusing.

name: str
is_scaled: bool
is_gated: bool
glu_interleave_size: int

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth supporting GLU interleaving in the mega-C++ path? The only benefit is to support the fused GEMM+GLU kernel, and otherwise the unnecessary memory-bound kernel means perf is a lost cause. If we can simplify our optimized code paths, then it's worth it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only benefit is to support the fused GEMM+GLU kernel

I do hope in the future we can launch CuteDSL fused kernels in C++ with some TVM-FFI tricks, otherwise we are forced to choose either better kernel fusions or less CPU overhead. Currently the CuteDSL fusion path is very CPU bounded for small models and we rely on CUDA graph and paged stashing for it to work well

@timmoon10 timmoon10 Jun 10, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not make a separate fused op for mega-C++ CuTe DSL? It'll make the implementations less entangled, so there are fewer edge cases or complications that an agent might misunderstand.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, if there is a cutedsl version with mega C++, it's gonna take its own code path with zero code reuse since it's gonna be agent-assisted coding anyway

Comment on lines +377 to +380
# Explicit env opt-in gives megacpp first chance. Unsupported recipes intentionally
# return the ops unchanged so lower-priority recipe-specific fusers remain the
# fallback path.
register_forward_fusion(fuse_forward_megacpp_ops, prepend=True)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GEMM+act fusions provide better GPU perf, so I think they should take higher priority than mega-C++. Basically, I see mega-C++ as "we can't do any better on GPU than the unfused impl, but at least we can make the CPU overhead very small".

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current order is follows:

  1. check env var
  2. env var = 1, then check supported recipe for mega-C++, so bf16 is supported, not mxfp8 / nvfp4
  3. then for mxfp8, nvfp4, mega-C++ does fallback and check for the next fusion.

The reasoning is that, I do not want the compromise of either better fusion or less host bound, so for future mxfp8 support, we can do the following two things:

  1. directly do cuteDSL integration directly with tvm-ffi and do cublas as a backup plan
  2. maybe add a new value to NVTE_MEGACPP_GROUPED_LINEAR=forced, so for users who cannot enable cuda graph for some reason, they can enforce C++ when they know that their training is more host bound

@zhongbozhu zhongbozhu force-pushed the main_megacpp_grouped_mlp branch 2 times, most recently from 7ab8bc6 to 08a5800 Compare June 11, 2026 17:47
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu

zhongbozhu commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator Author

E2E test - Qwen3.5 VL 35B-A3B, BSHD layout - about 6% E2E
image

@zhongbozhu zhongbozhu force-pushed the main_megacpp_grouped_mlp branch from 08a5800 to 9d91d47 Compare June 16, 2026 07:26
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu zhongbozhu force-pushed the main_megacpp_grouped_mlp branch from 23ae840 to 07b2836 Compare June 16, 2026 07:32
@zhongbozhu zhongbozhu marked this pull request as ready for review June 16, 2026 07:32
@greptile-apps

greptile-apps Bot commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces "Mega-C++" (NVTE_MEGACPP_GROUPED_LINEAR env-gate), a new C++ path for grouped MLP (FC1 → activation → FC2) that replaces the existing Python/PyTorch op-loop to reduce CPU overhead on non-CUDA-graph paths. The implementation targets BF16/FP16 grouped GEMM via cuBLAS grouped-GEMM backend, with support for SwiGLU, ClampedSwiGLU, and SReLU activations.

  • New ForwardGroupedMLP_MegaCpp / BackwardGroupedMLP_MegaCpp Python fused-op classes plus the C++ binding megacpp_grouped_mlp_forward/backward handle the full FC1-activation-FC2 pipeline in a single dispatch, with Megatron main_grad and paged-stashing support.
  • New nvte_scaled_swiglu, nvte_scaled_clamped_swiglu, nvte_scaled_srelu and their backward counterparts add vectorized, interleave-aware CUDA kernels with optional per-row scale-gradient reduction (one block per row, warp-reduce).
  • A per-stream scratch-buffer cache (_cached_grouped_gemm_scratch) is introduced for cuBLAS workspace reuse, gated by the CUDA stream handle.

Confidence Score: 4/5

The functional path — forward GEMM → scaled activation → backward GEMM → wgrad — is well-implemented and covered by integration tests. All findings are quality/efficiency concerns rather than correctness failures in the happy path.

The core BF16 grouped MLP pipeline is logically correct and backed by both C++ unit tests and Python integration tests. Issues found: dead-code branch that can confuse future recipe additions; unbounded per-stream scratch cache that can accumulate GPU allocations; overly broad input_requires_grad causing unnecessary dx computation in frozen-input scenarios; and a delay_wgrad incompatibility caught only at backward time. None affect correctness for the supported BF16/FP16, no-delay_wgrad configuration.

forward_grouped_mlp_megacpp.py warrants a second look for the recipe predicate, LRU cache bound, and input_requires_grad assignment. backward_grouped_mlp_megacpp.py should move the delay_wgrad guard earlier in the fusion lifecycle.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py Core Python forward fuser: dead-code in _megacpp_supports_recipe, unbounded LRU cache keyed on CUDA stream handles, and overly-broad input_requires_grad assignment.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py Backward fuser with correct wgrad ownership policies and bias-grad assembly; delay_wgrad error is only raised during backward, not at fusion time.
transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp C++ binding layer for forward/backward; base_offsets is intentionally voided in backward; wgrad ownership cases are well-documented and correct.
transformer_engine/common/activation/scaled_activation.cu New CUDA kernels for scaled SwiGLU/ClampedSwiGLU/SReLU; vectorized with GLU-interleave support; warp-reduce for scale-gradient path is numerically correct.
transformer_engine/common/include/transformer_engine/activation.h Public C API additions for nvte_scaled_swiglu, nvte_scaled_clamped_swiglu, nvte_scaled_srelu and their backward variants; documentation is clear and consistent.
tests/pytorch/megacpp/test_grouped_mlp.py Comprehensive integration tests covering wgrad storage, split dtype/device, activation variants, bias, zero-expert edge cases, and delay_wgrad error path.
tests/cpp/operator/test_scaled_activation.cu C++ unit tests with CPU reference implementations for SwiGLU, ClampedSwiGLU, and SReLU forward/backward; covers interleave=0 and interleave=32 across multiple shapes.
transformer_engine/pytorch/ops/fused/init.py Correctly imports and exposes ForwardGroupedMLP_MegaCpp and BackwardGroupedMLP_MegaCpp; registration handled inside submodules via prepend=True.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Py as Python Autograd
    participant Fwd as ForwardGroupedMLP_MegaCpp
    participant Bwd as BackwardGroupedMLP_MegaCpp
    participant CPP as C++ forward/backward
    participant GEMM as cuBLAS Grouped GEMM
    participant Act as Scaled Activation Kernels

    Py->>Fwd: fuser_forward(input, split_sizes, act_scales)
    Fwd->>Fwd: resolve weights / bias / scratch
    Fwd->>CPP: tex.megacpp_grouped_mlp_forward
    CPP->>GEMM: FC1 grouped GEMM
    CPP->>Act: nvte_scaled_swiglu / clamped_swiglu / srelu
    CPP->>GEMM: FC2 grouped GEMM
    CPP-->>Fwd: output, offsets, fc1_preact, fc2_x
    Fwd->>Fwd: save_for_backward
    Fwd-->>Py: fc2_out

    Py->>Bwd: fuser_backward(grad_output)
    Bwd->>CPP: tex.megacpp_grouped_mlp_backward
    CPP->>GEMM: FC2 wgrad
    CPP->>GEMM: FC2 dgrad
    CPP->>Act: nvte_scaled_dswiglu / dsrelu
    CPP->>GEMM: FC1 wgrad
    CPP->>GEMM: FC1 dgrad
    CPP-->>Bwd: grad_input, fc1_dy, grad_act_scales, wgrads
    Bwd->>Bwd: compute_grouped_dbias (Triton)
    Bwd-->>Py: grad_input, grad_params, grad_extra_inputs
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant Py as Python Autograd
    participant Fwd as ForwardGroupedMLP_MegaCpp
    participant Bwd as BackwardGroupedMLP_MegaCpp
    participant CPP as C++ forward/backward
    participant GEMM as cuBLAS Grouped GEMM
    participant Act as Scaled Activation Kernels

    Py->>Fwd: fuser_forward(input, split_sizes, act_scales)
    Fwd->>Fwd: resolve weights / bias / scratch
    Fwd->>CPP: tex.megacpp_grouped_mlp_forward
    CPP->>GEMM: FC1 grouped GEMM
    CPP->>Act: nvte_scaled_swiglu / clamped_swiglu / srelu
    CPP->>GEMM: FC2 grouped GEMM
    CPP-->>Fwd: output, offsets, fc1_preact, fc2_x
    Fwd->>Fwd: save_for_backward
    Fwd-->>Py: fc2_out

    Py->>Bwd: fuser_backward(grad_output)
    Bwd->>CPP: tex.megacpp_grouped_mlp_backward
    CPP->>GEMM: FC2 wgrad
    CPP->>GEMM: FC2 dgrad
    CPP->>Act: nvte_scaled_dswiglu / dsrelu
    CPP->>GEMM: FC1 wgrad
    CPP->>GEMM: FC1 dgrad
    CPP-->>Bwd: grad_input, fc1_dy, grad_act_scales, wgrads
    Bwd->>Bwd: compute_grouped_dbias (Triton)
    Bwd-->>Py: grad_input, grad_params, grad_extra_inputs
Loading

Reviews (1): Last reviewed commit: "integrate fused scaled swiglu and srelu" | Re-trigger Greptile

Comment on lines +33 to +46
def _megacpp_supports_recipe(recipe: Optional[Recipe]) -> bool:
"""Whether megacpp is a valid candidate for the active quantization recipe.

Today the C++ implementation is BF16/FP16-only, so only the no-recipe path
is supported. Returning False for FP8 recipes is intentional: it leaves the
op list unchanged so the existing MXFP8/NVFP4 CuTe DSL fusers can match.
Future MXFP8/NVFP4 support should be enabled by changing this predicate,
not by reordering fusion registrations.
"""
if recipe is None:
return True
if recipe.mxfp8() or recipe.nvfp4():
return False
return False

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The if recipe.mxfp8() or recipe.nvfp4(): return False branch is dead code — any non-None recipe that is not mxfp8/nvfp4 falls through to the identical final return False. The function is equivalent to return recipe is None; the current form misleads readers about which cases are being explicitly handled and can confuse future engineers adding support for a new recipe type.

Suggested change
def _megacpp_supports_recipe(recipe: Optional[Recipe]) -> bool:
"""Whether megacpp is a valid candidate for the active quantization recipe.
Today the C++ implementation is BF16/FP16-only, so only the no-recipe path
is supported. Returning False for FP8 recipes is intentional: it leaves the
op list unchanged so the existing MXFP8/NVFP4 CuTe DSL fusers can match.
Future MXFP8/NVFP4 support should be enabled by changing this predicate,
not by reordering fusion registrations.
"""
if recipe is None:
return True
if recipe.mxfp8() or recipe.nvfp4():
return False
return False
def _megacpp_supports_recipe(recipe: Optional[Recipe]) -> bool:
"""Whether megacpp is a valid candidate for the active quantization recipe.
Today the C++ implementation is BF16/FP16-only, so only the no-recipe path
is supported. Any quantized recipe (MXFP8, NVFP4, or future types) returns
False, leaving the op list unchanged so recipe-specific fusers can match.
Future quantized-compute support should be enabled by adding an explicit
`return True` here for the new recipe type, not by reordering fusions.
"""
return recipe is None

Comment on lines +49 to +54
@functools.lru_cache(maxsize=None)
def _cached_grouped_gemm_scratch(
num_groups: int,
device_index: int,
_stream_handle: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The @functools.lru_cache(maxsize=None) with a CUDA stream handle as part of the key creates an unbounded cache. Each unique (num_groups, device_index, stream_handle) triplet permanently holds three CUDA tensors. In pipeline-parallel configurations or test suites that construct many torch.cuda.Stream() objects over time, destroyed streams' allocations are never freed because lru_cache holds the only live reference. Bounding the cache (e.g., maxsize=64) caps the worst-case retained GPU memory.

Suggested change
@functools.lru_cache(maxsize=None)
def _cached_grouped_gemm_scratch(
num_groups: int,
device_index: int,
_stream_handle: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@functools.lru_cache(maxsize=64)
def _cached_grouped_gemm_scratch(
num_groups: int,
device_index: int,
_stream_handle: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

Comment on lines +276 to +279
requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs)
input_requires_grad = requires_grad
fc1_weight_requires_grad = requires_grad and fc1_weight_param.requires_grad
fc2_weight_requires_grad = requires_grad and fc2_weight_param.requires_grad

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 input_requires_grad is set to the generic requires_grad flag, which is True whenever any of the three op contexts requires a gradient — including the weight-only case. When the input tensor is frozen, the C++ backward still computes the full FC1 dgrad GEMM and discards the result. Using fc1_ctx.input_requires_grad as an additional gate matches the existing fuser convention and avoids the wasted computation.

Suggested change
requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs)
input_requires_grad = requires_grad
fc1_weight_requires_grad = requires_grad and fc1_weight_param.requires_grad
fc2_weight_requires_grad = requires_grad and fc2_weight_param.requires_grad
requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs)
input_requires_grad = requires_grad and fc1_ctx.input_requires_grad
fc1_weight_requires_grad = requires_grad and fc1_weight_param.requires_grad
fc2_weight_requires_grad = requires_grad and fc2_weight_param.requires_grad

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +91 to +92
if _delay_wgrad(fc_op, ctx):
raise ValueError("megacpp grouped MLP does not support delay_wgrad_compute=True.")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 delay_wgrad_compute=True incompatibility is detected inside fuser_backward, which runs during .backward() — after the forward pass has already executed. A user who constructs a megacpp-fused model with delay_wgrad_compute=True will get through an entire forward step before hitting the ValueError. Moving this check to ForwardGroupedMLP_MegaCpp.__init__ or fuse_forward_megacpp_ops would surface the error at model-construction or fusion time instead.

@zhongbozhu

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants