[PyTorch] Minor optimizations in fused grouped MLP#2888

Merged
ksivaman merged 2 commits into NVIDIA:main from ksivaman:minor_opts_fused_mlp
Apr 16, 2026

Conversation

@ksivaman
Member

Description

Small performance and CPU-overhead improvements in the fused grouped MLP.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring
  • Performance

Changes

  • Reduce the number of torch .to casts wherever possible.
  • Use a fused kernel to calculate offsets from splits.
  • Calculate split points by indexing into the offsets, removing an additional cumsum.
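The offsets logic can be illustrated with a plain-Python sketch. The real PR calls a fused tex.splits_to_offsets GPU kernel; everything else here, including the stand-in function, is illustrative:

```python
from itertools import accumulate

def splits_to_offsets(splits):
    # Prefix sum with a leading zero: stands in for the fused
    # tex.splits_to_offsets kernel, which replaces a torch.cumsum
    # followed by a torch.cat with a single GPU kernel launch.
    return [0] + list(accumulate(splits))

# Per-group token counts for a grouped MLP with 4 experts.
splits = [3, 5, 2, 4]
offsets = splits_to_offsets(splits)   # [0, 3, 8, 10, 14]

# Split points are just the interior offsets; indexing into the
# already-computed offsets avoids running a second cumsum.
split_points = offsets[1:-1]          # [3, 8, 10]
```

The key saving is that offsets and split points come from one prefix-sum pass instead of two cumulative sums plus a concatenation.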

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 April 15, 2026 20:49
@ksivaman
Member Author

/te-ci pytorch L0

@greptile-apps
Contributor

greptile-apps bot commented Apr 15, 2026

Greptile Summary

This PR introduces minor CPU-overhead and performance optimizations in the fused grouped MLP forward and backward paths. The key changes are:

  • Replacing two torch.cumsum + torch.cat calls with a single tex.splits_to_offsets fused GPU kernel.
  • Eliminating redundant no-op .to() casts for split_sizes/split_points in the backward pass, since the forward already saves them in the correct dtype.
  • Hoisting scales.detach().to(torch.float32) to avoid a duplicate cast in the scale_bias branch.
  • Pulling bias-grad .to(dtype) casts outside if/else branches to apply once on the full tensor.
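The cast hoisting described above (compute in full precision inside the branches, cast once afterward) can be sketched with NumPy standing in for torch; the function and variable names are illustrative, not the PR's actual code:

```python
import numpy as np

def grouped_bias_grad(grads, per_group, dtype=np.float32):
    # Old pattern: each branch called .astype(dtype) on its own result.
    # New pattern: build the full-precision result in whichever branch
    # applies, then cast once on the full tensor outside the if/else.
    if per_group:
        # One bias-grad row per group.
        out = np.stack([g.sum(axis=0) for g in grads])
    else:
        # Single bias grad over all tokens.
        out = np.concatenate(grads).sum(axis=0, keepdims=True)
    return out.astype(dtype)  # one hoisted cast instead of one per branch

grads = [np.ones((2, 3)), np.ones((4, 3))]
per_group_grad = grouped_bias_grad(grads, per_group=True)   # shape (2, 3)
merged_grad = grouped_bias_grad(grads, per_group=False)     # shape (1, 3)
```

The result is identical to casting inside each branch; the refactor only removes a duplicated op from the hot path.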

Confidence Score: 5/5

Safe to merge — all changes are semantically equivalent refactors with no correctness impact.

All optimizations are either no-ops being removed (backward casts on tensors already in the right dtype), or logically equivalent transformations (fused kernel vs. cumsum+cat, hoisted casts). The C++ implementation of splits_to_offsets was verified to produce int64 output matching the original base_offsets computation. No new code paths, no behavior changes for any dtype or shape.

No files require special attention.

Important Files Changed

  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py: Replaces two cumsum + cat ops with a single fused splits_to_offsets kernel and derives split_points by indexing into the result; semantically equivalent and correct.
  • transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py: Removes redundant backward casts (already correct dtype from forward), hoists the float32 scale conversion, and consolidates dtype casts for bias grads; all semantically equivalent refactors.

Reviews (1): Last reviewed commit: "Merge branch 'main' into minor_opts_fuse..."

Collaborator

@timmoon10 timmoon10 left a comment


LGTM

@ksivaman ksivaman merged commit c9035a4 into NVIDIA:main Apr 16, 2026
10 of 13 checks passed
@ksivaman ksivaman deleted the minor_opts_fused_mlp branch April 16, 2026 12:49