
Scaled Bias Add support after CUBLAS GGEMM #2885

Open
vthumbe1503 wants to merge 6 commits into NVIDIA:main from vthumbe1503:scaled_bias_ggemm

Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…imized and uses scales now

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title Scaled Bias Add support at the end of CUBLAS GGEMM Scaled Bias Add support after CUBLAS GGEMM Apr 15, 2026
@greptile-apps
Contributor

greptile-apps bot commented Apr 15, 2026

Greptile Summary

This PR adds optional per-row scale support to nvte_grouped_bias_add, enabling output[row,col] += bias[col] * scale[row] after a cuBLAS grouped GEMM. The kernel is refactored from a tensor-per-block grid to a 2D row-chunk × col-chunk grid with a shared-memory prefix-sum for tensor boundary detection, and a compile-time UseScale template flag avoids runtime branching in the inner loop. The C++ and Python extension layers are consistently updated across all three grouped-GEMM variants (grouped_tensor, discrete_in, discrete_out).
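To make the summarized semantics concrete, here is a minimal NumPy sketch of what the review describes for `nvte_grouped_bias_add`: each group's rows of the packed output get `bias[col]`, optionally multiplied by a per-row `scale[row]`. The function name and the `first_dims` parameter are illustrative stand-ins for the kernel's metadata, not the library's API.

```python
import numpy as np

def grouped_bias_add_reference(d, first_dims, bias, scale=None):
    """Reference semantics sketch: d is the packed (sum(first_dims), n)
    output of the grouped GEMM, bias has shape (n,), and scale (if given)
    has one entry per packed row."""
    row = 0
    for rows in first_dims:  # one GEMM output per group
        block = d[row:row + rows]  # view into the packed output
        if scale is None:
            block += bias  # d[row, col] += bias[col]
        else:
            # d[row, col] += bias[col] * scale[row]
            block += bias[None, :] * scale[row:row + rows, None]
        row += rows
    return d
```

Note that `scale` is indexed by the global packed row, matching the `scale[row]` indexing in the summary rather than a per-tensor scale.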

Confidence Score: 5/5

Safe to merge; all remaining findings are P2 style/improvement suggestions with no runtime risk for current callers.

The core scaled-bias logic is correctly implemented: fmaf argument order matches documented semantics, shared-memory cumsum is correctly initialized and synchronized, and the empty-tensor sentinel correctly disables scaling. The two P2 findings (dead pre-loop bias load; missing tensor_offsets guard) do not affect current callers since bias GroupedTensors are always packed in the Python bindings.

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu — dead pre-loop load and missing tensor_offsets guard in nvte_grouped_bias_add.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | Rewrites grouped_bias_add_kernel to support optional per-row scale; new 2D grid with shared-memory cumsum for tensor boundaries. Two P2 issues: dead pre-loop bias load, and bias offset computed as tensor_idx*n without validating absence of explicit tensor_offsets. |
| transformer_engine/pytorch/csrc/extensions/gemm.cpp | Adds bias_scale parameter to all three grouped-tensor GEMM entry points; correctly wraps it as an NVTE tensor and passes it to nvte_grouped_bias_add after GIL release, consistently across all three variants. |
| transformer_engine/pytorch/cpp_extensions/gemm.py | Adds bias_scale parameter to general_grouped_gemm_for_grouped_tensor; defaults to empty tensor when None, validates bias_scale requires bias, and passes through to the C++ extension. |
| transformer_engine/common/include/transformer_engine/gemm.h | Adds const NVTETensor scale parameter to nvte_grouped_bias_add declaration with accurate docstring describing conditional application. Clean change. |
| transformer_engine/pytorch/csrc/extensions.h | Adds at::Tensor bias_scale parameter to the three grouped GEMM extension function declarations, consistent with gemm.cpp implementation. |
| transformer_engine/pytorch/csrc/extensions/swizzle.cpp | Adds NVTE_NVFP4_1D_SCALING to the maybe_swizzle_grouped_tensor guard so FP4 grouped tensors are swizzled; mechanically correct but unrelated to the PR's stated feature. |
| tests/pytorch/test_numerics.py | Adds use_bias_scale parametrize to test_grouped_gemm_grouped_tensor; reference and manual-bias paths updated to compute d += bias * scale. NT layout silently excluded from bias_scale testing without explanation. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as Python (gemm.py)
    participant Ext as C++ Extension (gemm.cpp)
    participant NVTE as nvte_grouped_gemm
    participant BiasKernel as nvte_grouped_bias_add

    Py->>Ext: general_grouped_gemm_for_grouped_tensor(A, B, out, bias, bias_scale)
    Ext->>Ext: prepare_grouped_gemm_config(alpha, beta, ...)
    Ext->>NVTE: nvte_grouped_gemm(A, B, C=D, D, alpha, beta, ...)
    NVTE-->>Ext: D = alpha * A @ B + beta * C
    alt bias is not None
        Ext->>BiasKernel: nvte_grouped_bias_add(D, bias, scale)
        Note over BiasKernel: Build shared cumsum for row-to-tensor map
        BiasKernel->>BiasKernel: grouped_bias_add_kernel UseScale=true/false
        Note over BiasKernel: D[row,col] += bias[col] * scale[row]
        BiasKernel-->>Ext: D updated in-place
    end
    Ext-->>Py: D (updated)
```
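The "shared cumsum for row-to-tensor map" noted in the diagram can be emulated on the host: an inclusive prefix sum over the per-tensor row counts gives group boundaries, and locating a packed row's owning tensor is then a search over those boundaries. This is an illustrative sketch with hypothetical names, not the kernel's actual code.

```python
from itertools import accumulate
from bisect import bisect_right

def build_row_to_tensor_map(first_dims):
    """Inclusive prefix sums: boundaries[i] is the first packed row
    past tensor i (what the kernel builds in shared memory)."""
    return list(accumulate(first_dims))

def tensor_of_row(boundaries, row):
    """Which tensor owns a packed row: index of the first boundary
    strictly greater than the row."""
    return bisect_right(boundaries, row)
```

In the kernel this lookup happens per row chunk; the boundaries array only needs to be built once per block, which is why shared memory is a natural fit.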

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

```cpp
const size_t tensor_idx = blockIdx.y;
if (tensor_idx >= num_tensors) return;

const int64_t n = d_meta.last_dims ? d_meta.last_dims[0] : d_meta.uniform_last;
```

P2 Hardcoded index [0] instead of [tensor_idx]

d_meta.last_dims[0] works only because the pre-launch NVTE_CHECK(outputD->all_same_last_dim() ...) enforces a uniform last dimension. Using the hardcoded index removes the per-tensor correctness at a glance — a future reader (or a refactor that relaxes the uniform check) would not immediately see why [0] is used instead of [tensor_idx]. A comment linking this to the uniformity invariant would make this self-documenting.

Suggested change

```diff
-const int64_t n = d_meta.last_dims ? d_meta.last_dims[0] : d_meta.uniform_last;
+const int64_t n = d_meta.last_dims ? d_meta.last_dims[0]  // uniform across tensors (checked)
+                                   : d_meta.uniform_last;
```
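The invariant that makes `last_dims[0]` valid can be stated in one line; this is a hypothetical Python mirror of the `all_same_last_dim()` check the review mentions, shown only to make the uniformity precondition explicit.

```python
def all_same_last_dim(last_dims):
    """True when every tensor in the group shares one last dimension,
    which is the only case where indexing last_dims[0] for all tensors
    is correct."""
    return len(set(last_dims)) <= 1
```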

Comment on lines +873 to +882
```cpp
int64_t scale_row_offset = 0;
if constexpr (UseScale) {
  if (d_meta.first_dims) {
    for (size_t i = 0; i < tensor_idx; i++) {
      scale_row_offset += d_meta.first_dims[i];
    }
  } else {
    scale_row_offset = static_cast<int64_t>(tensor_idx) * d_meta.uniform_first;
  }
}
```

P2 Redundant per-thread scale_row_offset loop

Every thread in the block (all 256 of them) independently computes scale_row_offset by iterating up to tensor_idx times over d_meta.first_dims. Since tensor_idx == blockIdx.y, all threads in a block produce the same value. For large num_tensors, moving this into shared memory (computed once by thread 0 and shared) would avoid the redundant iterations. The broadcast access pattern through L1 is benign for small num_tensors, but is worth noting for scalability.
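The suggested fix amounts to replacing a per-thread linear scan with an offset table computed once. A Python sketch of the two approaches (names are illustrative, not the kernel's):

```python
from itertools import accumulate

def scale_row_offset_linear(first_dims, tensor_idx):
    """What each of the 256 threads does today: an O(tensor_idx)
    loop, repeated identically by every thread in the block."""
    return sum(first_dims[:tensor_idx])

def scale_row_offsets_precomputed(first_dims):
    """Exclusive prefix sum built once (e.g. by thread 0 into shared
    memory): offsets[i] is the first packed row of tensor i, so each
    thread's lookup becomes O(1)."""
    return [0] + list(accumulate(first_dims))[:-1]
```

Whether this matters in practice depends on `num_tensors`; as the finding notes, the broadcast L1 access pattern is benign for small groups.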

Comment on lines 341 to 347
```diff
 std::optional<SwizzledGroupedScales> maybe_swizzle_grouped_tensor(GroupedTensorWrapper &input,
                                                                   bool rowwise_usage,
                                                                   bool columnwise_usage) {
-  if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) {
+  if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
+      input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
     return std::nullopt;
   }
```

P2 Unrelated FP4 swizzle change — should be documented

This guard extension (adding NVTE_NVFP4_1D_SCALING) is a separate fix that enables grouped-tensor scale swizzling for FP4 inputs; it is unrelated to the Scaled Bias Add feature described in the PR title. nvte_swizzle_grouped_scaling_factors does handle FP4 in swizzle.cu, so the change is mechanically correct, but it would be helpful to document the motivation in the PR description or add a comment here explaining why FP4 also needs this path.

