Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import pytest
import random

import torch
import torch.nn as nn
from torch.nn import Parameter

from transformer_engine.pytorch.quantization import (
FP8GlobalStateManager,
get_align_size_for_quantization,
)
import torch
import torch.nn as nn
from torch.nn import Parameter

Comment on lines +15 to +18
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Is there a reason we're reordering? If the import order causes problems, then that's a bug we need to fix. Otherwise, this reordering seems unmotivated and haphazard. It's also considered good Python style to put third-party imports before local imports (PEP 8).

from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
Expand Down Expand Up @@ -2860,7 +2860,8 @@ def _make_grouped_tensor_uniform(
@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"])
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate) -> None:
@pytest.mark.parametrize("use_bias_scale", [False, True])
def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate, use_bias_scale) -> None:
if tex.get_cublasLt_version() < 130300:
pytest.skip("Grouped GEMM requires cuBLAS 13.3+.")
if torch.cuda.get_device_capability() < (10, 0):
Expand Down Expand Up @@ -2914,12 +2915,23 @@ def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate) -> No
if case != "discrete_out"
else None
)
bias_scale = None
if use_bias_scale and bias is not None and layout != "NT":
bias_scale = torch.randn(m, device="cuda", dtype=torch.float32)
# Bias add in grouped kernel accumulates in FP32 for BF16/FP16.
out_ref = (
[(o.float() + b.float()).to(dtype) for o, b in zip(out_ref_no_bias, bias)]
if bias is not None
else out_ref_no_bias
)
if bias is not None:
if bias_scale is not None:
offset = 0
out_ref = []
for i in range(z):
ms = m_sizes[i]
s = bias_scale[offset : offset + ms].unsqueeze(-1)
out_ref.append((out_ref_no_bias[i].float() + bias[i].float() * s).to(dtype))
offset += ms
else:
out_ref = [(o.float() + b.float()).to(dtype) for o, b in zip(out_ref_no_bias, bias)]
else:
out_ref = out_ref_no_bias
# Create grouped tensors based on case
device = A[0].device
grouped_A = A
Expand Down Expand Up @@ -2983,6 +2995,7 @@ def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate) -> No
layout=layout,
accumulate=accumulate,
bias=grouped_bias,
bias_scale=bias_scale,
)
out_grouped_no_bias = (
grouped_out_no_bias
Expand All @@ -2995,11 +3008,23 @@ def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate) -> No
else grouped_out_bias.split_into_quantized_tensors()
)

out_grouped_manual_bias = (
[(o.float() + b.float()).to(dtype) for o, b in zip(out_grouped_no_bias, bias)]
if bias is not None
else out_grouped_no_bias
)
if bias is not None:
if bias_scale is not None:
out_grouped_manual_bias = []
offset = 0
for i in range(z):
ms = m_sizes[i]
s = bias_scale[offset : offset + ms].unsqueeze(-1)
out_grouped_manual_bias.append(
(out_grouped_no_bias[i].float() + bias[i].float() * s).to(dtype)
)
offset += ms
else:
out_grouped_manual_bias = [
(o.float() + b.float()).to(dtype) for o, b in zip(out_grouped_no_bias, bias)
]
else:
out_grouped_manual_bias = out_grouped_no_bias
tols = dtype_tols(dtype)
for o, o_ref in zip(out_grouped_no_bias, out_ref_no_bias):
torch.testing.assert_close(o, o_ref, **tols)
Expand Down
193 changes: 147 additions & 46 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <algorithm>
#include <cstdint>
#include <type_traits>
#include <vector>

#include "../common.h"
Expand Down Expand Up @@ -845,43 +846,102 @@ __forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorSha
}
}

// Kernel that performs bias addition to the Grouped GEMM output tensors.
// Bias itself is a grouped tensor with the collections of same number of tensors
// as the output tensors.
template <typename T, int kVec>
__global__ void grouped_bias_add_kernel(char *d_base, const char *bias_base, TensorShapeInfo d_meta,
TensorShapeInfo bias_meta, size_t num_tensors) {
const size_t tensor_idx = blockIdx.x;
if (tensor_idx >= num_tensors) return;
// Kernel that performs (optionally scaled) bias addition to Grouped GEMM output tensors.
// 2D grid: blockIdx.x = row chunk, blockIdx.y = column chunk.
// Each block loads bias once for its column chunk and sweeps its rows
// with direct vectorized load-add-store on d.
template <typename T, int kVec, bool UseScale, int kBlockDim, int kRowsPerBlock>
__global__ void grouped_bias_add_kernel(char *__restrict__ d_base,
const char *__restrict__ bias_base,
const float *__restrict__ scale_base,
TensorShapeInfo d_meta, int n, int total_rows,
int num_tensors) {
using VecStorage = transformer_engine::VectorizedStorage<T, kVec>;
using VecType = typename VecStorage::LType;

const int64_t m = d_meta.first_dims ? d_meta.first_dims[tensor_idx] : d_meta.uniform_first;
const int64_t n = d_meta.last_dims ? d_meta.last_dims[tensor_idx] : d_meta.uniform_last;
constexpr int kMaxTensors = 257;
__shared__ int cumsum[kMaxTensors];
Comment on lines +862 to +863
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable name is wrong.

Suggested change
constexpr int kMaxTensors = 257;
__shared__ int cumsum[kMaxTensors];
constexpr int kMaxTensors = 256;
__shared__ int cumsum[kMaxTensors + 1];


const int64_t d_offset = compute_grouped_tensor_offset(d_meta, tensor_idx);
const int64_t bias_offset = compute_grouped_tensor_offset(bias_meta, tensor_idx);
const int tid = static_cast<int>(threadIdx.x);
const int block_dim = static_cast<int>(blockDim.x);
const int row_bid = static_cast<int>(blockIdx.x);
const int col_bid = static_cast<int>(blockIdx.y);

auto *d_ptr = reinterpret_cast<T *>(d_base + d_offset * sizeof(T));
const auto *bias_ptr = reinterpret_cast<const T *>(bias_base + bias_offset * sizeof(T));
const int row_start = row_bid * kRowsPerBlock;
const int row_end = min(row_start + kRowsPerBlock, total_rows);
if (row_start >= total_rows) return;

const int block_cols = block_dim * kVec;
const int col = col_bid * block_cols + tid * kVec;
if (col >= n) return;

// Build cumulative row prefix-sum in shared memory.
if (tid == 0) cumsum[0] = 0;
for (int i = tid; i < num_tensors; i += block_dim) {
cumsum[i + 1] =
static_cast<int>(d_meta.first_dims ? d_meta.first_dims[i] : d_meta.uniform_first);
}
__syncthreads();
if (tid == 0) {
for (int t = 1; t <= num_tensors; t++) cumsum[t] += cumsum[t - 1];
}
__syncthreads();

T *__restrict__ d = reinterpret_cast<T *>(d_base);
const T *__restrict__ bias = reinterpret_cast<const T *>(bias_base);

// Binary search for the starting row's tensor.
int tensor_idx;
{
int lo = 0, hi = num_tensors;
while (lo < hi) {
int mid = (lo + hi) >> 1;
if (cumsum[mid + 1] <= row_start)
lo = mid + 1;
else
hi = mid;
}
tensor_idx = lo;
}
int bias_idx = tensor_idx * n;
Comment on lines +893 to +906
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we benchmarked whether this binary search is any better than simply scanning through the tensors? Computing the cumulative sums is still O(n), so we're not improving the asymptotics, and we're also introducing thread syncs and shared-memory accesses.


const int64_t elements = m * n;
const int64_t vec_count = elements / kVec;
using VecStorage = transformer_engine::VectorizedStorage<T, kVec>;
using VecType = typename VecStorage::LType;
transformer_engine::VectorizedLoader<T, kVec, true> loader(d_ptr, elements);
transformer_engine::VectorizedStorer<T, kVec, true> storer(d_ptr, elements);
const int64_t vec_id = static_cast<int64_t>(blockIdx.y) * blockDim.x + threadIdx.x;
if (vec_id >= vec_count) return;
const int64_t vec_start = vec_id * kVec;
const int64_t col = vec_start % n;
loader.load(vec_id, elements);
const auto *b_vec = reinterpret_cast<const VecType *>(bias_ptr + col);
VecStorage b_in;
b_in.scratch_.aligned = *b_vec;
b_in.scratch_.aligned = *reinterpret_cast<const VecType *>(bias + bias_idx + col);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This value is immediately wiped out in the loop. I guess the compiler might be smart enough not to do an unnecessary memory access, but it makes the code harder to read.


// Walk tensor segments within this block's row range.
int seg_start = row_start;
while (seg_start < row_end) {
while (tensor_idx < num_tensors - 1 && cumsum[tensor_idx + 1] <= seg_start) {
tensor_idx++;
bias_idx += n;
}
b_in.scratch_.aligned = *reinterpret_cast<const VecType *>(bias + bias_idx + col);
const int seg_end = min(cumsum[tensor_idx + 1], row_end);

for (int row = seg_start; row < seg_end; row++) {
T *d_ptr = d + row * n + col;
VecStorage d_in;
d_in.scratch_.aligned = *reinterpret_cast<const VecType *>(d_ptr);

[[maybe_unused]] float s_val;
if constexpr (UseScale) s_val = scale_base[row];

#pragma unroll
for (int i = 0; i < kVec; ++i) {
storer.separate()[i] = loader.separate()[i] + b_in.scratch_.separate[i];
for (int i = 0; i < kVec; ++i) {
if constexpr (UseScale) {
d_in.scratch_.separate[i] =
static_cast<T>(fmaf(static_cast<float>(b_in.scratch_.separate[i]), s_val,
static_cast<float>(d_in.scratch_.separate[i])));
} else {
d_in.scratch_.separate[i] = static_cast<T>(static_cast<float>(d_in.scratch_.separate[i]) +
static_cast<float>(b_in.scratch_.separate[i]));
}
}
*reinterpret_cast<VecType *>(d_ptr) = d_in.scratch_.aligned;
}

seg_start = seg_end;
}
storer.store(vec_id, elements);
}

// Single kernel that sets up all GEMM parameters.
Expand Down Expand Up @@ -1308,12 +1368,13 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa,
}

void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias,
cudaStream_t stream) {
const NVTETensor scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_grouped_bias_add);
using namespace transformer_engine;

const GroupedTensor *outputD = convertNVTEGroupedTensorCheck(output);
const GroupedTensor *bias_tensor = convertNVTEGroupedTensorCheck(bias);
const Tensor *scale_tensor = convertNVTETensorCheck(scale);

NVTE_CHECK(outputD->num_tensors >= 1, "Grouped bias add: number of tensors must be at least 1");
NVTE_CHECK(outputD->num_tensors == bias_tensor->num_tensors,
Expand All @@ -1330,27 +1391,67 @@ void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTens
"Grouped bias add requires uniform last dim for output and bias");
NVTE_CHECK(outputD->get_common_last_dim() == bias_tensor->get_common_last_dim(),
"Grouped bias add: output and bias last dims must match");
constexpr int kVec = 4;
NVTE_CHECK(outputD->get_common_last_dim() % kVec == 0,
"Grouped bias add requires last dim divisible by ", kVec);

const float *scale_ptr = nullptr;
if (scale_tensor->data.dptr != nullptr) {
NVTE_CHECK(scale_tensor->dtype() == DType::kFloat32, "Grouped bias add: scale must be float32");
NVTE_CHECK(scale_tensor->data.shape.size() == 1, "Grouped bias add: scale must be 1D, got ",
scale_tensor->data.shape.size(), "D");
const size_t total_rows = static_cast<size_t>(outputD->logical_shape.data[0]);
NVTE_CHECK(scale_tensor->data.shape[0] == total_rows, "Grouped bias add: scale size (",
scale_tensor->data.shape[0], ") must equal total rows (", total_rows, ")");
scale_ptr = static_cast<const float *>(scale_tensor->data.dptr);
}

const TensorShapeInfo d_meta = TensorShapeInfo::from_tensor(outputD);
const TensorShapeInfo bias_meta = TensorShapeInfo::from_tensor(bias_tensor);

const DType dtype = outputD->dtype();
constexpr int kThreads = 256;
const size_t total_elements = static_cast<size_t>(outputD->logical_shape.data[0]) *
static_cast<size_t>(outputD->logical_shape.data[1]);
const size_t total_vec_count = (total_elements + kVec - 1) / kVec;
int blocks_per_tensor = static_cast<int>((total_vec_count + kThreads - 1) / kThreads);
const dim3 grid(outputD->num_tensors, blocks_per_tensor);

const int num_tensors = static_cast<int>(outputD->num_tensors);
NVTE_CHECK(num_tensors <= 256, "Grouped bias add supports at most 256 tensors, got ",
num_tensors);
const int total_rows = static_cast<int>(outputD->logical_shape.data[0]);
const int n = static_cast<int>(outputD->get_common_last_dim());

// Use 128-bit vector loads: kVec=8 for 2-byte types (bf16/fp16), kVec=4 for fp32.
const size_t elem_size = typeToSize(dtype);
const int kVec = (elem_size <= 2) ? 8 : 4;
NVTE_CHECK(n % kVec == 0, "Grouped bias add requires last dim divisible by ", kVec);

constexpr int kRowsPerBlock = 8;
const int block_cols = kThreads * kVec;
const int col_blocks = (n + block_cols - 1) / block_cols;
const int row_blocks = (total_rows + kRowsPerBlock - 1) / kRowsPerBlock;
const dim3 grid(row_blocks, col_blocks);
const dim3 block(kThreads);

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, T, {
grouped_bias_add_kernel<T, kVec><<<grid, block, 0, stream>>>(
static_cast<char *>(outputD->data.dptr), static_cast<const char *>(bias_tensor->data.dptr),
d_meta, bias_meta, outputD->num_tensors);
});
auto launch = [&](auto use_scale_tag) {
constexpr bool kUseScale = decltype(use_scale_tag)::value;
if (elem_size <= 2) {
constexpr int kV = 8;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, T, {
grouped_bias_add_kernel<T, kV, kUseScale, kThreads, kRowsPerBlock>
<<<grid, block, 0, stream>>>(static_cast<char *>(outputD->data.dptr),
static_cast<const char *>(bias_tensor->data.dptr),
scale_ptr, d_meta, n, total_rows, num_tensors);
});
} else {
constexpr int kV = 4;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, T, {
grouped_bias_add_kernel<T, kV, kUseScale, kThreads, kRowsPerBlock>
<<<grid, block, 0, stream>>>(static_cast<char *>(outputD->data.dptr),
static_cast<const char *>(bias_tensor->data.dptr),
scale_ptr, d_meta, n, total_rows, num_tensors);
});
}
};

if (scale_ptr != nullptr) {
launch(std::true_type{});
} else {
launch(std::false_type{});
}

NVTE_CHECK_CUDA(cudaGetLastError());
}
Expand Down Expand Up @@ -1392,7 +1493,7 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa,
}

void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias,
cudaStream_t stream) {
const NVTETensor scale, cudaStream_t stream) {
NVTE_ERROR("nvte_grouped_bias_add requires cuBLAS 13.3+, but compile-time cuBLAS version is ",
CUBLAS_VERSION, ". Please upgrade to cuBLAS 13.3 (shipped with CUDA 13.2) or newer.");
}
Expand Down
6 changes: 4 additions & 2 deletions transformer_engine/common/include/transformer_engine/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,14 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa,
NVTETensor workspace_setup, NVTETensor workspace_cublas,
NVTEGroupedMatmulConfig config, cudaStream_t stream);

/*! \brief Grouped bias add for grouped GEMM outputs.
/*! \brief Grouped Bias add for grouped GEMM outputs.
*
* When \p scale is a valid tensor: output[row,col] += bias[col] * scale[row],
* When \p scale is empty/null: output[row,col] += bias[col].
* Requires uniform last-dimension across all output tensors and bias tensors.
*/
void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias,
cudaStream_t stream);
const NVTETensor scale, cudaStream_t stream);
Comment on lines 438 to +439
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes more sense to create a separate API for nvte_grouped_scaled_bias_add. Grouped bias is a natural generalization of linear-layer biases, but grouped scaled bias is less intuitive (especially since the biases are per-group while the scales are per-token) and should be treated as more exotic.


#ifdef __cplusplus
} // extern "C"
Expand Down
8 changes: 8 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def general_grouped_gemm_for_grouped_tensor(
accumulate: bool = False,
use_split_accumulator: bool = False,
bias=None,
bias_scale: Optional[torch.Tensor] = None,
grad: bool = False,
alpha: Optional[torch.Tensor] = None,
beta: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -365,6 +366,9 @@ def general_grouped_gemm_for_grouped_tensor(
"Apply bias manually after the GEMM."
)

if bias_scale is not None and bias is None:
raise ValueError("bias_scale requires bias to be provided.")

num_tensors = B.num_tensors
rowwise = B.rowwise_data
device = rowwise.device if rowwise is not None else B.columnwise_data.device
Expand Down Expand Up @@ -394,13 +398,17 @@ def general_grouped_gemm_for_grouped_tensor(
sm_count = get_sm_count()
sm_count = sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count)))

if bias_scale is None:
bias_scale = torch.empty(0, dtype=torch.float32, device=device)

Comment on lines +401 to +403
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can avoid this overhead by making the tex function take an optional argument.

Suggested change
if bias_scale is None:
bias_scale = torch.empty(0, dtype=torch.float32, device=device)

return grouped_gemm_impl(
A,
transa,
B,
transb,
out,
bias,
bias_scale,
alpha,
beta,
workspace_setup,
Expand Down
Loading