From 18d4d2c5c32519861d8faa122581d00a874b3426 Mon Sep 17 00:00:00 2001 From: zhongboz Date: Fri, 12 Jun 2026 17:28:54 -0700 Subject: [PATCH 1/8] support scaled swiglu, scaled srelu and scaled clamp swiglu Signed-off-by: zhongboz --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_scaled_activation.cu | 328 ++++++++++ transformer_engine/common/CMakeLists.txt | 2 + .../common/activation/scaled_activation.cu | 567 ++++++++++++++++++ .../include/transformer_engine/activation.h | 103 ++++ 5 files changed, 1001 insertions(+) create mode 100644 tests/cpp/operator/test_scaled_activation.cu create mode 100644 transformer_engine/common/activation/scaled_activation.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 9b67c09f34..d5c446fb48 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -24,6 +24,7 @@ add_executable(test_operator test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu + test_scaled_activation.cu test_normalization.cu test_normalization_mxfp8.cu test_memset.cu diff --git a/tests/cpp/operator/test_scaled_activation.cu b/tests/cpp/operator/test_scaled_activation.cu new file mode 100644 index 0000000000..1cb630a0bc --- /dev/null +++ b/tests/cpp/operator/test_scaled_activation.cu @@ -0,0 +1,328 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +enum class ScaledActivationCase { + kSwiGLU, + kClampedSwiGLU, + kSReLU, +}; + +constexpr float kClampedLimit = 1.3f; +constexpr float kClampedAlpha = 1.702f; +constexpr float kClampedLinearOffset = 0.5f; + +const char *activation_name(ScaledActivationCase activation) { + switch (activation) { + case ScaledActivationCase::kSwiGLU: + return "scaled_swiglu"; + case ScaledActivationCase::kClampedSwiGLU: + return "scaled_clamped_swiglu"; + case ScaledActivationCase::kSReLU: + return "scaled_srelu"; + } + return "unknown"; +} + +inline float sigmoid(const float x) { return 1.0f / (1.0f + expf(-x)); } + +inline float qgelu_alpha(const float x, const float alpha) { return x * sigmoid(alpha * x); } + +inline float dqgelu_alpha(const float x, const float alpha) { + const float sig = sigmoid(alpha * x); + return alpha * x * sig * (1.0f - sig) + sig; +} + +inline float silu_ref(const float x) { return x * sigmoid(x); } + +inline float dsilu_ref(const float x) { + const float sig = sigmoid(x); + return x * sig * (1.0f - sig) + sig; +} + +inline float srelu_ref(const float x) { return x > 0.0f ? x * x : 0.0f; } + +inline float dsrelu_ref(const float x) { return fmaxf(0.0f, 2.0f * x); } + +inline void glu_indices(const size_t row, const size_t col, const size_t hidden, + const int64_t interleave, size_t *act_idx, size_t *linear_idx) { + if (interleave > 0) { + const size_t block = col / static_cast(interleave); + const size_t lane = col % static_cast(interleave); + const size_t base = row * hidden * 2 + block * static_cast(interleave) * 2 + lane; + *act_idx = base; + *linear_idx = base + static_cast(interleave); + } else { + const size_t base = row * hidden * 2; + *act_idx = base + col; + *linear_idx = base + hidden + col; + } +} + +inline float gated_unscaled(const ScaledActivationCase activation, const float act_in, + const float linear_in) { + switch (activation) { + case ScaledActivationCase::kSwiGLU: + return silu_ref(act_in) * linear_in; + case ScaledActivationCase::kClampedSwiGLU: { + const float act = qgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha); + const float linear = + fminf(fmaxf(-kClampedLimit, linear_in), kClampedLimit) + kClampedLinearOffset; + return act * linear; + } + case ScaledActivationCase::kSReLU: + return srelu_ref(act_in); + } + return 0.0f; +} + +inline void gated_grads(const ScaledActivationCase activation, const float act_in, + const float linear_in, float *dact, float *dlinear, float *unscaled) { + switch (activation) { + case ScaledActivationCase::kSwiGLU: { + const float act = silu_ref(act_in); + *unscaled = act * linear_in; + *dact = dsilu_ref(act_in) * linear_in; + *dlinear = act; + return; + } + case ScaledActivationCase::kClampedSwiGLU: { + const bool dlinear_mask = linear_in <= kClampedLimit && linear_in >= -kClampedLimit; + const float act = qgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha); + const float dact_base = + act_in <= kClampedLimit ? dqgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha) + : 0.0f; + const float linear = + fminf(fmaxf(-kClampedLimit, linear_in), kClampedLimit) + kClampedLinearOffset; + *unscaled = act * linear; + *dact = dact_base * linear; + *dlinear = dlinear_mask ? act : 0.0f; + return; + } + case ScaledActivationCase::kSReLU: + *unscaled = srelu_ref(act_in); + *dact = dsrelu_ref(act_in); + *dlinear = 0.0f; + return; + } +} + +template +void compute_reference(ScaledActivationCase activation, const DataT *input, const ScaleT *scales, + const DataT *grad_output, DataT *output, DataT *grad_input, + DataT *grad_scales, const size_t rows, const size_t hidden, + const int64_t interleave, const bool compute_grad_scales) { + const bool is_gated = activation != ScaledActivationCase::kSReLU; + const size_t input_cols = is_gated ? hidden * 2 : hidden; + std::fill(grad_input, grad_input + rows * input_cols, static_cast(0.0f)); + + for (size_t row = 0; row < rows; ++row) { + const float scale = static_cast(scales[row]); + float scale_grad = 0.0f; + for (size_t col = 0; col < hidden; ++col) { + const size_t out_idx = row * hidden + col; + float unscaled = 0.0f; + float dact = 0.0f; + float dlinear = 0.0f; + if (is_gated) { + size_t act_idx = 0; + size_t linear_idx = 0; + glu_indices(row, col, hidden, interleave, &act_idx, &linear_idx); + const float act_in = static_cast(input[act_idx]); + const float linear_in = static_cast(input[linear_idx]); + unscaled = gated_unscaled(activation, act_in, linear_in); + gated_grads(activation, act_in, linear_in, &dact, &dlinear, &unscaled); + + const float scaled_grad = static_cast(grad_output[out_idx]) * scale; + grad_input[act_idx] = static_cast(scaled_grad * dact); + grad_input[linear_idx] = static_cast(scaled_grad * dlinear); + } else { + const float x = static_cast(input[out_idx]); + unscaled = srelu_ref(x); + const float scaled_grad = static_cast(grad_output[out_idx]) * scale; + grad_input[out_idx] = static_cast(scaled_grad * dsrelu_ref(x)); + } + + output[out_idx] = static_cast(unscaled * scale); + scale_grad += static_cast(grad_output[out_idx]) * unscaled; + } + if (compute_grad_scales) { + grad_scales[row] = static_cast(scale_grad); + } + } +} + +template +void run_scaled_activation_test(ScaledActivationCase activation, const size_t rows, + const size_t hidden, const int64_t interleave, + const bool compute_grad_scales) { + using namespace test; + const DType data_type = TypeInfo::dtype; + const DType scale_type = TypeInfo::dtype; + const bool is_gated = activation != ScaledActivationCase::kSReLU; + const size_t input_cols = is_gated ? hidden * 2 : hidden; + + Tensor input("input", std::vector{rows, input_cols}, data_type); + Tensor scales("act_scales", std::vector{rows}, scale_type); + Tensor output("output", std::vector{rows, hidden}, data_type); + Tensor grad_output("grad_output", std::vector{rows, hidden}, data_type); + Tensor grad_input("grad_input", std::vector{rows, input_cols}, data_type); + Tensor grad_scales("grad_scales", std::vector{rows}, data_type); + + fillUniform(&input); + fillUniform(&scales); + fillUniform(&grad_output); + + std::unique_ptr ref_output = std::make_unique(rows * hidden); + std::unique_ptr ref_grad_input = std::make_unique(rows * input_cols); + std::unique_ptr ref_grad_scales = std::make_unique(rows); + + compute_reference(activation, input.rowwise_cpu_dptr(), scales.rowwise_cpu_dptr(), + grad_output.rowwise_cpu_dptr(), ref_output.get(), + ref_grad_input.get(), ref_grad_scales.get(), rows, hidden, interleave, + compute_grad_scales); + + switch (activation) { + case ScaledActivationCase::kSwiGLU: + nvte_scaled_swiglu(input.data(), scales.data(), output.data(), interleave, 0); + nvte_scaled_dswiglu(grad_output.data(), input.data(), scales.data(), grad_input.data(), + compute_grad_scales ? grad_scales.data() : nullptr, interleave, 0); + break; + case ScaledActivationCase::kClampedSwiGLU: + nvte_scaled_clamped_swiglu(input.data(), scales.data(), output.data(), kClampedLimit, + kClampedAlpha, kClampedLinearOffset, interleave, 0); + nvte_scaled_clamped_dswiglu( + grad_output.data(), input.data(), scales.data(), grad_input.data(), + compute_grad_scales ? grad_scales.data() : nullptr, kClampedLimit, kClampedAlpha, + kClampedLinearOffset, interleave, 0); + break; + case ScaledActivationCase::kSReLU: + nvte_scaled_srelu(input.data(), scales.data(), output.data(), 0); + nvte_scaled_dsrelu(grad_output.data(), input.data(), scales.data(), grad_input.data(), + compute_grad_scales ? grad_scales.data() : nullptr, 0); + break; + } + + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(data_type); + if (data_type == DType::kFloat32) { + atol = 5e-5; + rtol = 5e-5; + } + compareResults("scaled_activation_output", output, ref_output.get(), atol, rtol); + compareResults("scaled_activation_grad_input", grad_input, ref_grad_input.get(), atol, rtol); + if (compute_grad_scales) { + compareResults("scaled_activation_grad_scales", grad_scales, ref_grad_scales.get(), atol, rtol); + } +} + +class ScaledActivationTest + : public ::testing::TestWithParam< + std::tuple, int64_t, + bool>> { +}; + +std::string test_name_generator( + const testing::TestParamInfo &info) { + const auto activation = std::get<0>(info.param); + const auto data_type = std::get<1>(info.param); + const auto scale_type = std::get<2>(info.param); + const auto shape = std::get<3>(info.param); + const auto interleave = std::get<4>(info.param); + const auto compute_grad_scales = std::get<5>(info.param); + return std::string(activation_name(activation)) + "_data_" + test::typeName(data_type) + + "_scale_" + test::typeName(scale_type) + "_m_" + std::to_string(shape.first) + "_h_" + + std::to_string(shape.second) + "_interleave_" + std::to_string(interleave) + + (compute_grad_scales ? "_with_scale_grad" : "_no_scale_grad"); +} + +} // namespace + +TEST_P(ScaledActivationTest, ForwardBackward) { + const auto activation = std::get<0>(GetParam()); + const auto data_type = std::get<1>(GetParam()); + const auto scale_type = std::get<2>(GetParam()); + const auto shape = std::get<3>(GetParam()); + const auto interleave = std::get<4>(GetParam()); + const auto compute_grad_scales = std::get<5>(GetParam()); + + if (activation == ScaledActivationCase::kSReLU && interleave != 0) { + GTEST_SKIP() << "SReLU is not a GLU activation."; + } + if (activation != ScaledActivationCase::kSReLU && interleave > 0 && + shape.second % static_cast(interleave) != 0) { + GTEST_SKIP() << "Hidden size must be divisible by GLU interleave."; + } + + using namespace test; + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(data_type, DataT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(scale_type, ScaleT, { + run_scaled_activation_test(activation, shape.first, shape.second, interleave, + compute_grad_scales); + }); + }); +} + +// Test axes (the six tuple elements consumed by ScaledActivationTest): +// 1. Activation : SwiGLU and ClampedSwiGLU are gated (input is [M, 2H]); +// SReLU is unary (input is [M, H], no gate split). +// 2. Data dtype : dtype of the activation input/output tensors. +// 3. Scale dtype : dtype of act_scales / grad_act_scales. +// 4. Shape {rows, hidden}: rows = M (tokens), hidden = H (output width; gated input is 2H). +// 5. GLU interleave : 0 = contiguous [a | b]; 32 = interleaved a/b blocks. Only valid +// for gated activations with hidden % 32 == 0; SReLU skips != 0. +// 6. compute_grad_scales : whether the backward also reduces grad_act_scales. + +// Regular shapes: hidden is a multiple of 32, so the interleaved (32) layout is exercised +// alongside the contiguous (0) layout. +// Regular shapes (hidden % 32 == 0) and weird/irregular shapes (tiny, prime, non-32-aligned) +// share one instantiation. Interleave is swept over {0, 32}; invalid combinations -- SReLU with +// any nonzero interleave, or a gated activation whose hidden is not divisible by the interleave -- +// are skipped at runtime by the GTEST_SKIP guards in the test body. +INSTANTIATE_TEST_SUITE_P( + OperatorTest_ScaledActivation, ScaledActivationTest, + ::testing::Combine( + ::testing::Values(ScaledActivationCase::kSwiGLU, ScaledActivationCase::kClampedSwiGLU, + ScaledActivationCase::kSReLU), + ::testing::Values(DType::kFloat32, DType::kBFloat16), // data dtype + ::testing::Values(DType::kFloat32, DType::kBFloat16), // scale dtype + ::testing::Values(std::pair{17, 64}, // odd rows, aligned hidden + std::pair{8, 96}, // 96 = 3 * 32 + std::pair{32, 32}, // minimal aligned square + std::pair{128, 128}, // square + std::pair{64, 256}, // wide hidden + std::pair{256, 64}, // many rows, narrow hidden + std::pair{128, 512}, // FFN-ish width + std::pair{1, 1}, // single element + std::pair{1, 96}, // single row + std::pair{96, 1}, // single hidden column + std::pair{3, 7}, // tiny primes + std::pair{13, 100}, // non-power-of-two + std::pair{7, 257}, // prime, odd hidden + std::pair{33, 65}, // odd dims + std::pair{129, 31}), // odd rows, hidden < 32 + ::testing::Values(0, 32), // contiguous + interleaved + ::testing::Values(false, true)), // grad_act_scales off / on + test_name_generator); diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8f96432ed8..b4ba17e048 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -255,6 +255,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources activation/relu_dbias.cu activation/relu_grouped.cu activation/relu_grouped_dbias.cu + activation/scaled_activation.cu activation/swiglu.cu activation/swiglu_dbias.cu activation/swiglu_grouped.cu @@ -513,6 +514,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) activation/relu_dbias.cu activation/relu_grouped.cu activation/relu_grouped_dbias.cu + activation/scaled_activation.cu activation/swiglu.cu activation/swiglu_dbias.cu activation/swiglu_grouped.cu diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu new file mode 100644 index 0000000000..176253edc2 --- /dev/null +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -0,0 +1,567 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* Scaled activations: apply an activation, multiply by a per-row scale + * (act_scales[row]), do all math in fp32, and cast once at the store. The + * backward path optionally also reduces the gradient of the per-row scale. + * + * The six __global__ kernels below: + * + * # | Kernel | Activation | Dir | grad_act_scales | Launch + * ---+-----------------------------------------------+------------------------+-----+-----------------+-------------------- + * 1 | scaled_gated_forward_kernel | SwiGLU / ClampedSwiGLU | fwd | -- | flat element grid + * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | flat element grid + * 3 | scaled_gated_backward_kernel | SwiGLU / ClampedSwiGLU | bwd | no | flat element grid + * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | flat element grid + * 5 | scaled_gated_backward_with_scale_grad_kernel | SwiGLU / ClampedSwiGLU | bwd | yes | one block per row + * 6 | scaled_srelu_backward_with_scale_grad_kernel | SReLU | bwd | yes | one block per row + * + * The "with scale grad" variants compute grad_act_scales[row] = sum_j dY * unscaled, + * a per-row reduction that requires the one-block-per-row launch; when + * grad_act_scales is null the cheaper flat element-wise grid is used instead. + */ + +#include + +#include + +#include "../common.h" +#include "../util/math.h" + +namespace transformer_engine { +namespace { + +enum class ScaledActivation { + kSwiGLU, + kClampedSwiGLU, + kSReLU, +}; + +__device__ __forceinline__ void glu_input_indices(const size_t row, const size_t col, + const size_t hidden, + const int64_t glu_interleave_size, + size_t *act_idx, size_t *linear_idx) { + if (glu_interleave_size > 0) { + const size_t interleave = static_cast(glu_interleave_size); + const size_t block = col / interleave; + const size_t lane = col % interleave; + const size_t base = row * hidden * 2 + block * interleave * 2 + lane; + *act_idx = base; + *linear_idx = base + interleave; + } else { + const size_t base = row * hidden * 2; + *act_idx = base + col; + *linear_idx = base + hidden + col; + } +} + +template +__device__ __forceinline__ float gated_forward_value(const float act_in, const float linear_in, + const ClampedSwiGLUParam ¶m) { + if constexpr (Act == ScaledActivation::kSwiGLU) { + Empty empty = {}; + return silu(act_in, empty) * linear_in; + } else { + const float linear = + fminf(fmaxf(-param.limit, linear_in), param.limit) + param.glu_linear_offset; + return clamped_silu(act_in, param) * linear; + } +} + +template +__device__ __forceinline__ void gated_backward_values(const float act_in, const float linear_in, + const ClampedSwiGLUParam ¶m, + float *dact, float *dlinear, + float *unscaled) { + if constexpr (Act == ScaledActivation::kSwiGLU) { + Empty empty = {}; + const float act = silu(act_in, empty); + *unscaled = act * linear_in; + *dact = dsilu(act_in, empty) * linear_in; + *dlinear = act; + } else { + const bool dlinear_mask = linear_in <= param.limit && linear_in >= -param.limit; + const float linear = + fminf(fmaxf(-param.limit, linear_in), param.limit) + param.glu_linear_offset; + const float act = clamped_silu(act_in, param); + *unscaled = act * linear; + *dact = clamped_dsilu(act_in, param) * linear; + *dlinear = dlinear_mask ? act : 0.0f; + } +} + +template +__global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *act_scales, + OutputT *output, const size_t rows, + const size_t hidden, + const int64_t glu_interleave_size, + const ClampedSwiGLUParam param) { + const size_t total = rows * hidden; + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; + idx += gridDim.x * blockDim.x) { + const size_t row = idx / hidden; + const size_t col = idx % hidden; + size_t act_idx = 0; + size_t linear_idx = 0; + glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); + + const float unscaled = gated_forward_value(static_cast(input[act_idx]), + static_cast(input[linear_idx]), param); + const float scale = static_cast(act_scales[row]); + output[idx] = static_cast(unscaled * scale); + } +} + +template +__global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *act_scales, + OutputT *output, const size_t rows, + const size_t hidden) { + const size_t total = rows * hidden; + Empty empty = {}; + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; + idx += gridDim.x * blockDim.x) { + const size_t row = idx / hidden; + const float unscaled = srelu(static_cast(input[idx]), empty); + const float scale = static_cast(act_scales[row]); + output[idx] = static_cast(unscaled * scale); + } +} + +template +__global__ void scaled_gated_backward_kernel(const GradT *grad_output, const InputT *input, + const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden, + const int64_t glu_interleave_size, + const ClampedSwiGLUParam param) { + const size_t total = rows * hidden; + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; + idx += gridDim.x * blockDim.x) { + const size_t row = idx / hidden; + const size_t col = idx % hidden; + size_t act_idx = 0; + size_t linear_idx = 0; + glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); + + float dact = 0.0f; + float dlinear = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(input[act_idx]), + static_cast(input[linear_idx]), param, &dact, &dlinear, + &unscaled); + (void)unscaled; + const float scale = static_cast(act_scales[row]); + const float grad = static_cast(grad_output[idx]) * scale; + grad_input[act_idx] = static_cast(grad * dact); + grad_input[linear_idx] = static_cast(grad * dlinear); + } +} + +template +__global__ void scaled_srelu_backward_kernel(const GradT *grad_output, const InputT *input, + const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden) { + const size_t total = rows * hidden; + Empty empty = {}; + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; + idx += gridDim.x * blockDim.x) { + const size_t row = idx / hidden; + const float scale = static_cast(act_scales[row]); + const float grad = static_cast(grad_output[idx]) * scale; + grad_input[idx] = + static_cast(grad * dsrelu(static_cast(input[idx]), empty)); + } +} + +template +__global__ void scaled_gated_backward_with_scale_grad_kernel( + const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, + const int64_t glu_interleave_size, const ClampedSwiGLUParam param) { + constexpr int kThreads = 256; + __shared__ float smem[kThreads]; + const size_t row = blockIdx.x; + float scale_grad = 0.0f; + + for (size_t col = threadIdx.x; col < hidden; col += blockDim.x) { + const size_t grad_idx = row * hidden + col; + size_t act_idx = 0; + size_t linear_idx = 0; + glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); + + float dact = 0.0f; + float dlinear = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(input[act_idx]), + static_cast(input[linear_idx]), param, &dact, &dlinear, + &unscaled); + const float grad = static_cast(grad_output[grad_idx]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + grad_input[act_idx] = static_cast(scaled_grad * dact); + grad_input[linear_idx] = static_cast(scaled_grad * dlinear); + } + + smem[threadIdx.x] = scale_grad; + __syncthreads(); + for (int offset = kThreads / 2; offset > 0; offset >>= 1) { + if (threadIdx.x < offset) { + smem[threadIdx.x] += smem[threadIdx.x + offset]; + } + __syncthreads(); + } + if (threadIdx.x == 0) { + grad_act_scales[row] = static_cast(smem[0]); + } +} + +template +__global__ void scaled_srelu_backward_with_scale_grad_kernel( + const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden) { + constexpr int kThreads = 256; + __shared__ float smem[kThreads]; + const size_t row = blockIdx.x; + float scale_grad = 0.0f; + Empty empty = {}; + + for (size_t col = threadIdx.x; col < hidden; col += blockDim.x) { + const size_t idx = row * hidden + col; + const float unscaled = srelu(static_cast(input[idx]), empty); + const float grad = static_cast(grad_output[idx]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + const float dact = dsrelu(static_cast(input[idx]), empty); + grad_input[idx] = static_cast(scaled_grad * dact); + } + + smem[threadIdx.x] = scale_grad; + __syncthreads(); + for (int offset = kThreads / 2; offset > 0; offset >>= 1) { + if (threadIdx.x < offset) { + smem[threadIdx.x] += smem[threadIdx.x + offset]; + } + __syncthreads(); + } + if (threadIdx.x == 0) { + grad_act_scales[row] = static_cast(smem[0]); + } +} + +void check_scale_tensor(const Tensor *act_scales, const size_t rows, const char *api_name) { + NVTE_CHECK(act_scales->numel() == rows, api_name, ": act_scales must have one value per row."); +} + +void check_gated_forward_tensors(const Tensor *input, const Tensor *act_scales, + const Tensor *output, const int64_t glu_interleave_size, + const char *api_name, size_t *rows, size_t *hidden) { + const auto input_dims = input->flat_2d_dims(); + const auto output_dims = output->flat_2d_dims(); + NVTE_CHECK(input_dims[0] == output_dims[0], api_name, ": input/output row mismatch."); + NVTE_CHECK(input_dims[1] == output_dims[1] * 2, api_name, + ": gated input last dimension must be twice output last dimension."); + NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); + if (glu_interleave_size > 0) { + NVTE_CHECK(output_dims[1] % static_cast(glu_interleave_size) == 0, api_name, + ": output last dimension must be divisible by glu_interleave_size."); + } + check_scale_tensor(act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = output_dims[1]; +} + +void check_unary_forward_tensors(const Tensor *input, const Tensor *act_scales, + const Tensor *output, const char *api_name, size_t *rows, + size_t *hidden) { + const auto input_dims = input->flat_2d_dims(); + const auto output_dims = output->flat_2d_dims(); + NVTE_CHECK(input_dims[0] == output_dims[0] && input_dims[1] == output_dims[1], api_name, + ": input/output shapes must match."); + check_scale_tensor(act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = output_dims[1]; +} + +void check_grad_scale_tensor(const Tensor *grad_act_scales, const size_t rows, + const char *api_name) { + if (grad_act_scales != nullptr) { + NVTE_CHECK(grad_act_scales->numel() == rows, api_name, + ": grad_act_scales must have one value per row."); + } +} + +void check_gated_backward_tensors(const Tensor *grad_output, const Tensor *input, + const Tensor *act_scales, const Tensor *grad_input, + const Tensor *grad_act_scales, + const int64_t glu_interleave_size, const char *api_name, + size_t *rows, size_t *hidden) { + const auto grad_dims = grad_output->flat_2d_dims(); + const auto input_dims = input->flat_2d_dims(); + const auto grad_input_dims = grad_input->flat_2d_dims(); + NVTE_CHECK(grad_dims[0] == input_dims[0] && input_dims[0] == grad_input_dims[0], api_name, + ": input/grad row mismatch."); + NVTE_CHECK(input_dims[1] == grad_dims[1] * 2 && grad_input_dims[1] == input_dims[1], api_name, + ": gated backward dimensions are inconsistent."); + NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); + if (glu_interleave_size > 0) { + NVTE_CHECK(grad_dims[1] % static_cast(glu_interleave_size) == 0, api_name, + ": grad last dimension must be divisible by glu_interleave_size."); + } + check_scale_tensor(act_scales, input_dims[0], api_name); + check_grad_scale_tensor(grad_act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = grad_dims[1]; +} + +void check_unary_backward_tensors(const Tensor *grad_output, const Tensor *input, + const Tensor *act_scales, const Tensor *grad_input, + const Tensor *grad_act_scales, const char *api_name, + size_t *rows, size_t *hidden) { + const auto grad_dims = grad_output->flat_2d_dims(); + const auto input_dims = input->flat_2d_dims(); + const auto grad_input_dims = grad_input->flat_2d_dims(); + NVTE_CHECK(grad_dims[0] == input_dims[0] && input_dims[0] == grad_input_dims[0], api_name, + ": input/grad row mismatch."); + NVTE_CHECK(grad_dims[1] == input_dims[1] && input_dims[1] == grad_input_dims[1], api_name, + ": unary backward dimensions are inconsistent."); + check_scale_tensor(act_scales, input_dims[0], api_name); + check_grad_scale_tensor(grad_act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = grad_dims[1]; +} + +template +void launch_scaled_gated_forward(const NVTETensor nvte_input, const NVTETensor nvte_act_scales, + NVTETensor nvte_output, const int64_t glu_interleave_size, + const ClampedSwiGLUParam param, cudaStream_t stream, + const char *api_name) { + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *output = convertNVTETensorCheck(nvte_output); + size_t rows = 0; + size_t hidden = 0; + check_gated_forward_tensors(input, act_scales, output, glu_interleave_size, api_name, &rows, + &hidden); + if (rows == 0 || hidden == 0) return; + + constexpr int threads = 256; + const int blocks = static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { + scaled_gated_forward_kernel + <<>>( + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(output->data.dptr), rows, hidden, glu_interleave_size, + param); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor nvte_act_scales, + NVTETensor nvte_output, cudaStream_t stream, + const char *api_name) { + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *output = convertNVTETensorCheck(nvte_output); + size_t rows = 0; + size_t hidden = 0; + check_unary_forward_tensors(input, act_scales, output, api_name, &rows, &hidden); + if (rows == 0 || hidden == 0) return; + + constexpr int threads = 256; + const int blocks = static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { + scaled_srelu_forward_kernel + <<>>( + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(output->data.dptr), rows, hidden); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTETensor nvte_input, + const NVTETensor nvte_act_scales, NVTETensor nvte_grad_input, + NVTETensor nvte_grad_act_scales, + const int64_t glu_interleave_size, + const ClampedSwiGLUParam param, cudaStream_t stream, + const char *api_name) { + const Tensor *grad_output = convertNVTETensorCheck(nvte_grad_output); + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *grad_input = convertNVTETensorCheck(nvte_grad_input); + Tensor *grad_act_scales = + nvte_grad_act_scales == nullptr ? nullptr : convertNVTETensorCheck(nvte_grad_act_scales); + size_t rows = 0; + size_t hidden = 0; + check_gated_backward_tensors(grad_output, input, act_scales, grad_input, grad_act_scales, + glu_interleave_size, api_name, &rows, &hidden); + if (rows == 0 || hidden == 0) return; + + constexpr int threads = 256; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + if (grad_act_scales == nullptr) { + const int blocks = + static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); + scaled_gated_backward_kernel + <<>>( + reinterpret_cast(grad_output->data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(grad_input->data.dptr), rows, hidden, + glu_interleave_size, param); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { + scaled_gated_backward_with_scale_grad_kernel + <<(rows), threads, 0, stream>>>( + reinterpret_cast(grad_output->data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(grad_input->data.dptr), + reinterpret_cast(grad_act_scales->data.dptr), rows, hidden, + glu_interleave_size, param); + }); + } + }); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTETensor nvte_input, + const NVTETensor nvte_act_scales, NVTETensor nvte_grad_input, + NVTETensor nvte_grad_act_scales, cudaStream_t stream, + const char *api_name) { + const Tensor *grad_output = convertNVTETensorCheck(nvte_grad_output); + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *grad_input = convertNVTETensorCheck(nvte_grad_input); + Tensor *grad_act_scales = + nvte_grad_act_scales == nullptr ? nullptr : convertNVTETensorCheck(nvte_grad_act_scales); + size_t rows = 0; + size_t hidden = 0; + check_unary_backward_tensors(grad_output, input, act_scales, grad_input, grad_act_scales, + api_name, &rows, &hidden); + if (rows == 0 || hidden == 0) return; + + constexpr int threads = 256; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + if (grad_act_scales == nullptr) { + const int blocks = + static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); + scaled_srelu_backward_kernel + <<>>( + reinterpret_cast(grad_output->data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(grad_input->data.dptr), rows, hidden); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { + scaled_srelu_backward_with_scale_grad_kernel + <<(rows), threads, 0, stream>>>( + reinterpret_cast(grad_output->data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(grad_input->data.dptr), + reinterpret_cast(grad_act_scales->data.dptr), rows, hidden); + }); + } + }); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace +} // namespace transformer_engine + +void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + int64_t glu_interleave_size, cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_swiglu); + using namespace transformer_engine; + Empty empty = {}; + (void)empty; + ClampedSwiGLUParam param = {}; + launch_scaled_gated_forward( + input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_swiglu"); +} + +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, int64_t glu_interleave_size, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_dswiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {}; + launch_scaled_gated_backward( + grad, input, act_scales, grad_input, grad_act_scales, glu_interleave_size, param, stream, + "nvte_scaled_dswiglu"); +} + +void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales, + NVTETensor output, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_clamped_swiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; + launch_scaled_gated_forward( + input, act_scales, output, glu_interleave_size, param, stream, + "nvte_scaled_clamped_swiglu"); +} + +void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_clamped_dswiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; + launch_scaled_gated_backward( + grad, input, act_scales, grad_input, grad_act_scales, glu_interleave_size, param, stream, + "nvte_scaled_clamped_dswiglu"); +} + +void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_srelu); + using namespace transformer_engine; + launch_scaled_srelu_forward(input, act_scales, output, stream, "nvte_scaled_srelu"); +} + +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_dsrelu); + using namespace transformer_engine; + launch_scaled_srelu_backward(grad, input, act_scales, grad_input, grad_act_scales, stream, + "nvte_scaled_dsrelu"); +} diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 4ed083740d..f1485057ec 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -368,6 +368,41 @@ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha, float glu_linear_offset, cudaStream_t stream); +/*! \brief Computes ScaledSwiGLU without materializing GLU deinterleave. + * + * Computes output = SwiGLU(input) * act_scales[:, None]. + * If glu_interleave_size > 0, input is interpreted as interleaved + * [activation_block, linear_block] chunks of that size. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] output Output tensor of shape [N, H]. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + int64_t glu_interleave_size, cudaStream_t stream); + +/*! \brief Computes ScaledClampedSwiGLU without materializing GLU deinterleave. + * + * Computes output = ClampedSwiGLU(input) * act_scales[:, None]. + * This uses the same clamping, alpha, and linear-offset semantics as + * nvte_clamped_swiglu_v2. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] output Output tensor of shape [N, H]. + * \param[in] limit Clipping limit. + * \param[in] alpha Activation sigmoid alpha. + * \param[in] glu_linear_offset Offset added to linear component after clamping. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales, + NVTETensor output, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream); + /*! \brief Computes the gated ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -473,6 +508,46 @@ void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTE float limit, float alpha, float glu_linear_offset, cudaStream_t stream); +/*! \brief Computes ScaledSwiGLU backward without materializing GLU deinterleave. + * + * The optional grad_act_scales tensor may be null. When present, it receives + * sum(dY * SwiGLU(input), dim=-1). + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] grad_input Outgoing gradient of shape [N, H * 2]. + * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, int64_t glu_interleave_size, + cudaStream_t stream); + +/*! \brief Computes ScaledClampedSwiGLU backward without materializing GLU deinterleave. + * + * The optional grad_act_scales tensor may be null. When present, it receives + * sum(dY * ClampedSwiGLU(input), dim=-1). + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] grad_input Outgoing gradient of shape [N, H * 2]. + * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. + * \param[in] limit Clipping limit. + * \param[in] alpha Activation sigmoid alpha. + * \param[in] glu_linear_offset Offset added to linear component after clamping. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream); + /*! \brief Computes the gated ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -509,6 +584,34 @@ void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes ScaledSReLU. + * + * Computes output = SReLU(input) * act_scales[:, None]. + * + * \param[in] input Input tensor for activation. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + cudaStream_t stream); + +/*! \brief Computes ScaledSReLU backward. + * + * The optional grad_act_scales tensor may be null. When present, it receives + * sum(dY * SReLU(input), dim=-1). + * + * \param[in] grad Incoming gradient. + * \param[in] input Forward input tensor. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] grad_input Outgoing input gradient. + * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif From 953c46903c47ac65e8e99cd2520761c2ae593240 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Jun 2026 17:34:59 -0700 Subject: [PATCH 2/8] vectorized loading improvement Signed-off-by: Zhongbo Zhu --- .../common/activation/scaled_activation.cu | 652 +++++++++++++----- 1 file changed, 463 insertions(+), 189 deletions(-) diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu index 176253edc2..7053e06241 100644 --- a/transformer_engine/common/activation/scaled_activation.cu +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -12,16 +12,47 @@ * * # | Kernel | Activation | Dir | grad_act_scales | Launch * ---+-----------------------------------------------+------------------------+-----+-----------------+-------------------- - * 1 | scaled_gated_forward_kernel | SwiGLU / ClampedSwiGLU | fwd | -- | flat element grid - * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | flat element grid - * 3 | scaled_gated_backward_kernel | SwiGLU / ClampedSwiGLU | bwd | no | flat element grid - * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | flat element grid - * 5 | scaled_gated_backward_with_scale_grad_kernel | SwiGLU / ClampedSwiGLU | bwd | yes | one block per row - * 6 | scaled_srelu_backward_with_scale_grad_kernel | SReLU | bwd | yes | one block per row + * 1 | scaled_gated_forward_kernel | SwiGLU / ClampedSwiGLU | fwd | -- | vectorized row segments + * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | vectorized flat grid + * 3 | scaled_gated_backward_kernel | SwiGLU / ClampedSwiGLU | bwd | no | vectorized row segments + * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | vectorized flat grid + * 5 | scaled_gated_backward_with_scale_grad_kernel | SwiGLU / ClampedSwiGLU | bwd | yes | vectorized, one block per row + * 6 | scaled_srelu_backward_with_scale_grad_kernel | SReLU | bwd | yes | vectorized, one block per row * * The "with scale grad" variants compute grad_act_scales[row] = sum_j dY * unscaled, * a per-row reduction that requires the one-block-per-row launch; when * grad_act_scales is null the cheaper flat element-wise grid is used instead. + * + * Vectorization model: + * + * Gated activations consume two FC1 streams per row: an activation stream and a + * gate stream. With no GLU interleave, the row is laid out as: + * + * [ act[0:H] | gate[0:H] ] + * + * With GLU interleave, e.g. interleave=32, the row is laid out as independent + * act/gate segments: + * + * [ act[0:32] | gate[0:32] | act[32:64] | gate[32:64] | ... ] + * + * Vector loads: + * + * interleave=0: + * input [ act0 | act1 | ... | actN | gate0 | gate1 | ... | gateN ] + * | | + * v v + * load act vector i gate vector i + * store output vector i = activation(act vector i) * gate vector i * scale[row] + * + * interleave=32: + * input [ act0 | gate0 | act1 | gate1 | ... | actN | gateN ] + * | | | | + * v v v v + * load act0 gate0 act1 gate1 + * store output vector i = activation(act vector i) * gate vector i * scale[row] + * + * Only fully aligned segments use vector loads. Everything else uses the same + * kernels with nvec=1, i.e. regular elementwise loads/stores. */ #include @@ -30,6 +61,7 @@ #include "../common.h" #include "../util/math.h" +#include "../util/vectorized_pointwise.h" namespace transformer_engine { namespace { @@ -40,172 +72,306 @@ enum class ScaledActivation { kSReLU, }; -__device__ __forceinline__ void glu_input_indices(const size_t row, const size_t col, - const size_t hidden, - const int64_t glu_interleave_size, - size_t *act_idx, size_t *linear_idx) { - if (glu_interleave_size > 0) { - const size_t interleave = static_cast(glu_interleave_size); - const size_t block = col / interleave; - const size_t lane = col % interleave; - const size_t base = row * hidden * 2 + block * interleave * 2 + lane; - *act_idx = base; - *linear_idx = base + interleave; - } else { - const size_t base = row * hidden * 2; - *act_idx = base + col; - *linear_idx = base + hidden + col; - } -} - template -__device__ __forceinline__ float gated_forward_value(const float act_in, const float linear_in, +__device__ __forceinline__ float gated_forward_value(const float act_in, const float gate_in, const ClampedSwiGLUParam ¶m) { if constexpr (Act == ScaledActivation::kSwiGLU) { Empty empty = {}; - return silu(act_in, empty) * linear_in; + return silu(act_in, empty) * gate_in; } else { - const float linear = - fminf(fmaxf(-param.limit, linear_in), param.limit) + param.glu_linear_offset; - return clamped_silu(act_in, param) * linear; + const float gate = fminf(fmaxf(-param.limit, gate_in), param.limit) + param.glu_linear_offset; + return clamped_silu(act_in, param) * gate; } } template -__device__ __forceinline__ void gated_backward_values(const float act_in, const float linear_in, - const ClampedSwiGLUParam ¶m, - float *dact, float *dlinear, +__device__ __forceinline__ void gated_backward_values(const float act_in, const float gate_in, + const ClampedSwiGLUParam ¶m, float *dact, + float *dgate, float *unscaled) { if constexpr (Act == ScaledActivation::kSwiGLU) { Empty empty = {}; const float act = silu(act_in, empty); - *unscaled = act * linear_in; - *dact = dsilu(act_in, empty) * linear_in; - *dlinear = act; + *unscaled = act * gate_in; + *dact = dsilu(act_in, empty) * gate_in; + *dgate = act; } else { - const bool dlinear_mask = linear_in <= param.limit && linear_in >= -param.limit; - const float linear = - fminf(fmaxf(-param.limit, linear_in), param.limit) + param.glu_linear_offset; + const bool dgate_mask = gate_in <= param.limit && gate_in >= -param.limit; + const float gate = fminf(fmaxf(-param.limit, gate_in), param.limit) + param.glu_linear_offset; const float act = clamped_silu(act_in, param); - *unscaled = act * linear; - *dact = clamped_dsilu(act_in, param) * linear; - *dlinear = dlinear_mask ? act : 0.0f; + *unscaled = act * gate; + *dact = clamped_dsilu(act_in, param) * gate; + *dgate = dgate_mask ? act : 0.0f; + } +} + +constexpr int kThreads = unary_kernel_threads; + +template +constexpr int vector_width() { + return 32 / static_cast(sizeof(T)); +} + +inline int launch_blocks(const size_t work_items) { + return static_cast( + std::min(DIVUP(work_items, static_cast(kThreads)), 65535)); +} + +template +Alignment row_vector_alignment(const size_t lead_dim, const int nvec, const Ptrs... ptrs) { + if (nvec == 1) { + return Alignment::SAME_ALIGNED; + } + // GLU interleave is handled as independent row-local segments. Keep the scalar + // fallback for odd segment widths or unaligned pointers so vector stores never + // cross from an activation segment into its paired gate segment. + if (lead_dim % static_cast(nvec) != 0) { + return Alignment::DIFFERENT; + } + const auto align = CheckAlignment(lead_dim, nvec, ptrs...); + return align == Alignment::SAME_ALIGNED ? Alignment::SAME_ALIGNED : Alignment::DIFFERENT; +} + +template +__device__ __forceinline__ bool vector_lane_index(const size_t vector_idx, const int lane, + const int alignment, const size_t length, + size_t *index) { + size_t idx = vector_idx * static_cast(nvec) + static_cast(lane); + if constexpr (!aligned) { + if (idx < static_cast(alignment)) { + return false; + } + idx -= static_cast(alignment); + } + if (idx >= length) { + return false; } + *index = idx; + return true; } -template +template __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *act_scales, OutputT *output, const size_t rows, - const size_t hidden, - const int64_t glu_interleave_size, + const size_t hidden, const size_t segment_size, + const size_t num_segments, + const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { - const size_t total = rows * hidden; - for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; - idx += gridDim.x * blockDim.x) { - const size_t row = idx / hidden; - const size_t col = idx % hidden; - size_t act_idx = 0; - size_t linear_idx = 0; - glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); - - const float unscaled = gated_forward_value(static_cast(input[act_idx]), - static_cast(input[linear_idx]), param); + const size_t total_vectors = rows * num_segments * num_vectors_per_segment; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_segment; + const size_t segment = (tid / num_vectors_per_segment) % num_segments; + const size_t row = tid / (num_vectors_per_segment * num_segments); + const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; + const size_t output_segment_offset = row * hidden + segment * segment_size; + + VectorizedLoader act_loader(input + input_segment_offset, + segment_size); + VectorizedLoader gate_loader( + input + input_segment_offset + segment_size, segment_size); + VectorizedStorer output_storer(output + output_segment_offset, + segment_size); + if (vector_idx >= act_loader.num_aligned_elements()) { + continue; + } + + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); const float scale = static_cast(act_scales[row]); - output[idx] = static_cast(unscaled * scale); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, act_loader.alignment(), + segment_size, &col)) { + const float unscaled = + gated_forward_value(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param); + output_storer.separate()[lane] = static_cast(unscaled * scale); + } + } + output_storer.store(vector_idx, segment_size); } } -template +template __global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *act_scales, - OutputT *output, const size_t rows, - const size_t hidden) { - const size_t total = rows * hidden; + OutputT *output, const size_t total, + const size_t hidden, + const size_t num_vectors) { Empty empty = {}; - for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; - idx += gridDim.x * blockDim.x) { - const size_t row = idx / hidden; - const float unscaled = srelu(static_cast(input[idx]), empty); - const float scale = static_cast(act_scales[row]); - output[idx] = static_cast(unscaled * scale); + VectorizedLoader input_loader(input, total); + VectorizedStorer output_storer(output, total); + for (size_t vector_idx = blockIdx.x * blockDim.x + threadIdx.x; vector_idx < num_vectors; + vector_idx += gridDim.x * blockDim.x) { + if (vector_idx >= input_loader.num_aligned_elements()) { + continue; + } + input_loader.load(vector_idx, total); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t idx = 0; + if (vector_lane_index(vector_idx, lane, input_loader.alignment(), total, + &idx)) { + const size_t row = idx / hidden; + const float unscaled = srelu(static_cast(input_loader.separate()[lane]), + empty); + const float scale = static_cast(act_scales[row]); + output_storer.separate()[lane] = static_cast(unscaled * scale); + } + } + output_storer.store(vector_idx, total); } } -template -__global__ void scaled_gated_backward_kernel(const GradT *grad_output, const InputT *input, - const ScaleT *act_scales, OutputT *grad_input, - const size_t rows, const size_t hidden, - const int64_t glu_interleave_size, - const ClampedSwiGLUParam param) { - const size_t total = rows * hidden; - for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; - idx += gridDim.x * blockDim.x) { - const size_t row = idx / hidden; - const size_t col = idx % hidden; - size_t act_idx = 0; - size_t linear_idx = 0; - glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); - - float dact = 0.0f; - float dlinear = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(input[act_idx]), - static_cast(input[linear_idx]), param, &dact, &dlinear, - &unscaled); - (void)unscaled; +template +__global__ void scaled_gated_backward_kernel( + const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, + const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { + const size_t total_vectors = rows * num_segments * num_vectors_per_segment; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_segment; + const size_t segment = (tid / num_vectors_per_segment) % num_segments; + const size_t row = tid / (num_vectors_per_segment * num_segments); + const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; + const size_t output_segment_offset = row * hidden + segment * segment_size; + + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, + segment_size); + VectorizedLoader gate_loader( + input + input_segment_offset + segment_size, segment_size); + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( + grad_input + input_segment_offset + segment_size, segment_size); + if (vector_idx >= act_loader.num_aligned_elements()) { + continue; + } + + grad_loader.load(vector_idx, segment_size); + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); const float scale = static_cast(act_scales[row]); - const float grad = static_cast(grad_output[idx]) * scale; - grad_input[act_idx] = static_cast(grad * dact); - grad_input[linear_idx] = static_cast(grad * dlinear); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, act_loader.alignment(), + segment_size, &col)) { + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + (void)unscaled; + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + act_storer.separate()[lane] = static_cast(grad * dact); + gate_storer.separate()[lane] = static_cast(grad * dgate); + } + } + act_storer.store(vector_idx, segment_size); + gate_storer.store(vector_idx, segment_size); } } -template +template __global__ void scaled_srelu_backward_kernel(const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - const size_t rows, const size_t hidden) { - const size_t total = rows * hidden; + const size_t total, const size_t hidden, + const size_t num_vectors) { Empty empty = {}; - for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; - idx += gridDim.x * blockDim.x) { - const size_t row = idx / hidden; - const float scale = static_cast(act_scales[row]); - const float grad = static_cast(grad_output[idx]) * scale; - grad_input[idx] = - static_cast(grad * dsrelu(static_cast(input[idx]), empty)); + VectorizedLoader grad_loader(grad_output, total); + VectorizedLoader input_loader(input, total); + VectorizedStorer grad_input_storer(grad_input, total); + for (size_t vector_idx = blockIdx.x * blockDim.x + threadIdx.x; vector_idx < num_vectors; + vector_idx += gridDim.x * blockDim.x) { + if (vector_idx >= input_loader.num_aligned_elements()) { + continue; + } + grad_loader.load(vector_idx, total); + input_loader.load(vector_idx, total); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t idx = 0; + if (vector_lane_index(vector_idx, lane, input_loader.alignment(), total, + &idx)) { + const size_t row = idx / hidden; + const float scale = static_cast(act_scales[row]); + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + grad_input_storer.separate()[lane] = + static_cast(grad * dsrelu( + static_cast(input_loader.separate()[lane]), + empty)); + } + } + grad_input_storer.store(vector_idx, total); } } -template +template __global__ void scaled_gated_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, - const int64_t glu_interleave_size, const ClampedSwiGLUParam param) { - constexpr int kThreads = 256; + const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, + const ClampedSwiGLUParam param) { __shared__ float smem[kThreads]; const size_t row = blockIdx.x; + (void)rows; float scale_grad = 0.0f; - for (size_t col = threadIdx.x; col < hidden; col += blockDim.x) { - const size_t grad_idx = row * hidden + col; - size_t act_idx = 0; - size_t linear_idx = 0; - glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); - - float dact = 0.0f; - float dlinear = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(input[act_idx]), - static_cast(input[linear_idx]), param, &dact, &dlinear, - &unscaled); - const float grad = static_cast(grad_output[grad_idx]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - grad_input[act_idx] = static_cast(scaled_grad * dact); - grad_input[linear_idx] = static_cast(scaled_grad * dlinear); + for (size_t segment = 0; segment < num_segments; ++segment) { + const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; + const size_t output_segment_offset = row * hidden + segment * segment_size; + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, + segment_size); + VectorizedLoader gate_loader( + input + input_segment_offset + segment_size, segment_size); + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( + grad_input + input_segment_offset + segment_size, segment_size); + + for (size_t vector_idx = threadIdx.x; vector_idx < num_vectors_per_segment; + vector_idx += blockDim.x) { + if (vector_idx >= act_loader.num_aligned_elements()) { + continue; + } + grad_loader.load(vector_idx, segment_size); + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, act_loader.alignment(), + segment_size, &col)) { + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + act_storer.separate()[lane] = static_cast(scaled_grad * dact); + gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); + } + } + act_storer.store(vector_idx, segment_size); + gate_storer.store(vector_idx, segment_size); + } } smem[threadIdx.x] = scale_grad; @@ -221,26 +387,46 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( } } -template +template __global__ void scaled_srelu_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - GradScaleT *grad_act_scales, const size_t rows, const size_t hidden) { - constexpr int kThreads = 256; + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, + const size_t num_vectors_per_row) { __shared__ float smem[kThreads]; const size_t row = blockIdx.x; + (void)rows; float scale_grad = 0.0f; Empty empty = {}; - for (size_t col = threadIdx.x; col < hidden; col += blockDim.x) { - const size_t idx = row * hidden + col; - const float unscaled = srelu(static_cast(input[idx]), empty); - const float grad = static_cast(grad_output[idx]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - const float dact = dsrelu(static_cast(input[idx]), empty); - grad_input[idx] = static_cast(scaled_grad * dact); + VectorizedLoader grad_loader(grad_output + row * hidden, hidden); + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); + for (size_t vector_idx = threadIdx.x; vector_idx < num_vectors_per_row; + vector_idx += blockDim.x) { + if (vector_idx >= input_loader.num_aligned_elements()) { + continue; + } + grad_loader.load(vector_idx, hidden); + input_loader.load(vector_idx, hidden); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, input_loader.alignment(), hidden, + &col)) { + const float unscaled = + srelu(static_cast(input_loader.separate()[lane]), empty); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + const float dact = + dsrelu(static_cast(input_loader.separate()[lane]), empty); + grad_input_storer.separate()[lane] = static_cast(scaled_grad * dact); + } + } + grad_input_storer.store(vector_idx, hidden); } smem[threadIdx.x] = scale_grad; @@ -270,6 +456,8 @@ void check_gated_forward_tensors(const Tensor *input, const Tensor *act_scales, ": gated input last dimension must be twice output last dimension."); NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); if (glu_interleave_size > 0) { + NVTE_CHECK(glu_interleave_size % 32 == 0, api_name, + ": nonzero glu_interleave_size must be a multiple of 32."); NVTE_CHECK(output_dims[1] % static_cast(glu_interleave_size) == 0, api_name, ": output last dimension must be divisible by glu_interleave_size."); } @@ -312,6 +500,8 @@ void check_gated_backward_tensors(const Tensor *grad_output, const Tensor *input ": gated backward dimensions are inconsistent."); NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); if (glu_interleave_size > 0) { + NVTE_CHECK(glu_interleave_size % 32 == 0, api_name, + ": nonzero glu_interleave_size must be a multiple of 32."); NVTE_CHECK(grad_dims[1] % static_cast(glu_interleave_size) == 0, api_name, ": grad last dimension must be divisible by glu_interleave_size."); } @@ -352,17 +542,34 @@ void launch_scaled_gated_forward(const NVTETensor nvte_input, const NVTETensor n &hidden); if (rows == 0 || hidden == 0) return; - constexpr int threads = 256; - const int blocks = static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { - scaled_gated_forward_kernel - <<>>( - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(output->data.dptr), rows, hidden, glu_interleave_size, - param); + constexpr int nvec = + sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto output_ptr = reinterpret_cast(output->data.dptr); + const size_t segment_size = + glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; + const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; + const auto align = + row_vector_alignment(segment_size, nvec, input_ptr, input_ptr + segment_size, + output_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) + : segment_size; + const int blocks = launch_blocks(rows * num_segments * num_vectors); + if (use_vector) { + scaled_gated_forward_kernel + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, + segment_size, num_segments, num_vectors, param); + } else { + scaled_gated_forward_kernel<1, true, InputT, ScaleT, OutputT, Act> + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, + segment_size, num_segments, segment_size, param); + } }); }); }); @@ -380,16 +587,29 @@ void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor n check_unary_forward_tensors(input, act_scales, output, api_name, &rows, &hidden); if (rows == 0 || hidden == 0) return; - constexpr int threads = 256; - const int blocks = static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { - scaled_srelu_forward_kernel - <<>>( - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(output->data.dptr), rows, hidden); + constexpr int nvec = + sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto output_ptr = reinterpret_cast(output->data.dptr); + const size_t total = rows * hidden; + const auto align = CheckAlignment(total, nvec, input_ptr, output_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, total, nvec, sizeof(InputT)) : total; + const int blocks = launch_blocks(num_vectors); + if (use_vector) { + scaled_srelu_forward_kernel + <<>>(input_ptr, scale_ptr, output_ptr, total, hidden, + num_vectors); + } else { + scaled_srelu_forward_kernel<1, true, InputT, ScaleT, OutputT> + <<>>(input_ptr, scale_ptr, output_ptr, total, hidden, + total); + } }); }); }); @@ -415,32 +635,58 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET glu_interleave_size, api_name, &rows, &hidden); if (rows == 0 || hidden == 0) return; - constexpr int threads = 256; TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && + sizeof(InputT) == sizeof(OutputT) + ? vector_width() + : 1; + const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); + const size_t segment_size = + glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; + const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; + const auto align = row_vector_alignment( + segment_size, nvec, grad_ptr, input_ptr, input_ptr + segment_size, grad_input_ptr, + grad_input_ptr + segment_size); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) + : segment_size; if (grad_act_scales == nullptr) { - const int blocks = - static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); - scaled_gated_backward_kernel - <<>>( - reinterpret_cast(grad_output->data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(grad_input->data.dptr), rows, hidden, - glu_interleave_size, param); + const int blocks = launch_blocks(rows * num_segments * num_vectors); + if (use_vector) { + scaled_gated_backward_kernel + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + rows, hidden, segment_size, num_segments, + num_vectors, param); + } else { + scaled_gated_backward_kernel<1, true, GradT, InputT, ScaleT, OutputT, Act> + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + rows, hidden, segment_size, num_segments, + segment_size, param); + } } else { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { - scaled_gated_backward_with_scale_grad_kernel - <<(rows), threads, 0, stream>>>( - reinterpret_cast(grad_output->data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(grad_input->data.dptr), - reinterpret_cast(grad_act_scales->data.dptr), rows, hidden, - glu_interleave_size, param); + auto grad_act_scales_ptr = + reinterpret_cast(grad_act_scales->data.dptr); + if (use_vector) { + scaled_gated_backward_with_scale_grad_kernel< + nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + <<(rows), kThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, segment_size, num_segments, num_vectors, param); + } else { + scaled_gated_backward_with_scale_grad_kernel< + 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + <<(rows), kThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, segment_size, num_segments, segment_size, param); + } }); } }); @@ -466,30 +712,58 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET api_name, &rows, &hidden); if (rows == 0 || hidden == 0) return; - constexpr int threads = 256; TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && + sizeof(InputT) == sizeof(OutputT) + ? vector_width() + : 1; + const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); if (grad_act_scales == nullptr) { - const int blocks = - static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); - scaled_srelu_backward_kernel - <<>>( - reinterpret_cast(grad_output->data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(grad_input->data.dptr), rows, hidden); + const size_t total = rows * hidden; + const auto align = CheckAlignment(total, nvec, grad_ptr, input_ptr, grad_input_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, total, nvec, sizeof(InputT)) + : total; + const int blocks = launch_blocks(num_vectors); + if (use_vector) { + scaled_srelu_backward_kernel + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + total, hidden, num_vectors); + } else { + scaled_srelu_backward_kernel<1, true, GradT, InputT, ScaleT, OutputT> + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + total, hidden, total); + } } else { + const auto align = row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, + grad_input_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) + : hidden; TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { - scaled_srelu_backward_with_scale_grad_kernel - <<(rows), threads, 0, stream>>>( - reinterpret_cast(grad_output->data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(grad_input->data.dptr), - reinterpret_cast(grad_act_scales->data.dptr), rows, hidden); + auto grad_act_scales_ptr = + reinterpret_cast(grad_act_scales->data.dptr); + if (use_vector) { + scaled_srelu_backward_with_scale_grad_kernel< + nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT> + <<(rows), kThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, num_vectors); + } else { + scaled_srelu_backward_with_scale_grad_kernel< + 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT> + <<(rows), kThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, hidden); + } }); } }); From e3ae293ef8a8996dfee03bc110bf942e20287427 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Jun 2026 19:56:53 -0700 Subject: [PATCH 3/8] fix bug for backward kernel Signed-off-by: Zhongbo Zhu --- .../common/activation/scaled_activation.cu | 78 ++++++++++--------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu index 7053e06241..ad854f3e4e 100644 --- a/transformer_engine/common/activation/scaled_activation.cu +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -106,6 +106,7 @@ __device__ __forceinline__ void gated_backward_values(const float act_in, const } constexpr int kThreads = unary_kernel_threads; +constexpr int kReductionThreads = 256; template constexpr int vector_width() { @@ -322,12 +323,18 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { - __shared__ float smem[kThreads]; + __shared__ float smem[kReductionThreads]; const size_t row = blockIdx.x; (void)rows; float scale_grad = 0.0f; - for (size_t segment = 0; segment < num_segments; ++segment) { + // Flatten (segment, vector) so interleave=32 distributes all row work across + // the block instead of using only a few threads per small act/gate segment. + const size_t row_vectors = num_segments * num_vectors_per_segment; + for (size_t row_vector_idx = threadIdx.x; row_vector_idx < row_vectors; + row_vector_idx += blockDim.x) { + const size_t segment = row_vector_idx / num_vectors_per_segment; + const size_t vector_idx = row_vector_idx % num_vectors_per_segment; const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; const size_t output_segment_offset = row * hidden + segment * segment_size; VectorizedLoader grad_loader(grad_output + output_segment_offset, @@ -341,42 +348,39 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( VectorizedStorer gate_storer( grad_input + input_segment_offset + segment_size, segment_size); - for (size_t vector_idx = threadIdx.x; vector_idx < num_vectors_per_segment; - vector_idx += blockDim.x) { - if (vector_idx >= act_loader.num_aligned_elements()) { - continue; - } - grad_loader.load(vector_idx, segment_size); - act_loader.load(vector_idx, segment_size); - gate_loader.load(vector_idx, segment_size); + if (vector_idx >= act_loader.num_aligned_elements()) { + continue; + } + grad_loader.load(vector_idx, segment_size); + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); #pragma unroll - for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, act_loader.alignment(), - segment_size, &col)) { - float dact = 0.0f; - float dgate = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(act_loader.separate()[lane]), - static_cast(gate_loader.separate()[lane]), param, &dact, - &dgate, &unscaled); - const float grad = static_cast(grad_loader.separate()[lane]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - act_storer.separate()[lane] = static_cast(scaled_grad * dact); - gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); - } + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, act_loader.alignment(), segment_size, + &col)) { + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + act_storer.separate()[lane] = static_cast(scaled_grad * dact); + gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); } - act_storer.store(vector_idx, segment_size); - gate_storer.store(vector_idx, segment_size); } + act_storer.store(vector_idx, segment_size); + gate_storer.store(vector_idx, segment_size); } smem[threadIdx.x] = scale_grad; __syncthreads(); - for (int offset = kThreads / 2; offset > 0; offset >>= 1) { + for (int offset = kReductionThreads / 2; offset > 0; offset >>= 1) { if (threadIdx.x < offset) { smem[threadIdx.x] += smem[threadIdx.x + offset]; } @@ -393,7 +397,7 @@ __global__ void scaled_srelu_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t num_vectors_per_row) { - __shared__ float smem[kThreads]; + __shared__ float smem[kReductionThreads]; const size_t row = blockIdx.x; (void)rows; float scale_grad = 0.0f; @@ -431,7 +435,7 @@ __global__ void scaled_srelu_backward_with_scale_grad_kernel( smem[threadIdx.x] = scale_grad; __syncthreads(); - for (int offset = kThreads / 2; offset > 0; offset >>= 1) { + for (int offset = kReductionThreads / 2; offset > 0; offset >>= 1) { if (threadIdx.x < offset) { smem[threadIdx.x] += smem[threadIdx.x + offset]; } @@ -677,13 +681,13 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET if (use_vector) { scaled_gated_backward_with_scale_grad_kernel< nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> - <<(rows), kThreads, 0, stream>>>( + <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { scaled_gated_backward_with_scale_grad_kernel< 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> - <<(rows), kThreads, 0, stream>>>( + <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, segment_size, param); } @@ -754,13 +758,13 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET if (use_vector) { scaled_srelu_backward_with_scale_grad_kernel< nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT> - <<(rows), kThreads, 0, stream>>>( + <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, num_vectors); } else { scaled_srelu_backward_with_scale_grad_kernel< 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT> - <<(rows), kThreads, 0, stream>>>( + <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, hidden); } From 84cbdec2e9088082abe99184347433c87178011d Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Jun 2026 22:14:15 -0700 Subject: [PATCH 4/8] optimize Signed-off-by: Zhongbo Zhu --- tests/cpp/operator/test_scaled_activation.cu | 19 +- .../common/activation/scaled_activation.cu | 401 ++++++++---------- 2 files changed, 181 insertions(+), 239 deletions(-) diff --git a/tests/cpp/operator/test_scaled_activation.cu b/tests/cpp/operator/test_scaled_activation.cu index 1cb630a0bc..80641e3c5f 100644 --- a/tests/cpp/operator/test_scaled_activation.cu +++ b/tests/cpp/operator/test_scaled_activation.cu @@ -295,12 +295,9 @@ TEST_P(ScaledActivationTest, ForwardBackward) { // for gated activations with hidden % 32 == 0; SReLU skips != 0. // 6. compute_grad_scales : whether the backward also reduces grad_act_scales. -// Regular shapes: hidden is a multiple of 32, so the interleaved (32) layout is exercised -// alongside the contiguous (0) layout. -// Regular shapes (hidden % 32 == 0) and weird/irregular shapes (tiny, prime, non-32-aligned) -// share one instantiation. Interleave is swept over {0, 32}; invalid combinations -- SReLU with -// any nonzero interleave, or a gated activation whose hidden is not divisible by the interleave -- -// are skipped at runtime by the GTEST_SKIP guards in the test body. +// Interleave is swept over {0, 32}; invalid combinations -- SReLU with any nonzero interleave, or +// a gated activation whose hidden is not divisible by the interleave -- are skipped at runtime by +// the GTEST_SKIP guards in the test body. INSTANTIATE_TEST_SUITE_P( OperatorTest_ScaledActivation, ScaledActivationTest, ::testing::Combine( @@ -309,20 +306,14 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16), // data dtype ::testing::Values(DType::kFloat32, DType::kBFloat16), // scale dtype ::testing::Values(std::pair{17, 64}, // odd rows, aligned hidden - std::pair{8, 96}, // 96 = 3 * 32 std::pair{32, 32}, // minimal aligned square std::pair{128, 128}, // square - std::pair{64, 256}, // wide hidden std::pair{256, 64}, // many rows, narrow hidden - std::pair{128, 512}, // FFN-ish width + std::pair{1024, 2048}, // large FFN-ish width std::pair{1, 1}, // single element std::pair{1, 96}, // single row std::pair{96, 1}, // single hidden column - std::pair{3, 7}, // tiny primes - std::pair{13, 100}, // non-power-of-two - std::pair{7, 257}, // prime, odd hidden - std::pair{33, 65}, // odd dims - std::pair{129, 31}), // odd rows, hidden < 32 + std::pair{13, 100}), // non-power-of-two ::testing::Values(0, 32), // contiguous + interleaved ::testing::Values(false, true)), // grad_act_scales off / on test_name_generator); diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu index ad854f3e4e..d056966281 100644 --- a/transformer_engine/common/activation/scaled_activation.cu +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -13,9 +13,9 @@ * # | Kernel | Activation | Dir | grad_act_scales | Launch * ---+-----------------------------------------------+------------------------+-----+-----------------+-------------------- * 1 | scaled_gated_forward_kernel | SwiGLU / ClampedSwiGLU | fwd | -- | vectorized row segments - * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | vectorized flat grid + * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | vectorized row grid * 3 | scaled_gated_backward_kernel | SwiGLU / ClampedSwiGLU | bwd | no | vectorized row segments - * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | vectorized flat grid + * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | vectorized row grid * 5 | scaled_gated_backward_with_scale_grad_kernel | SwiGLU / ClampedSwiGLU | bwd | yes | vectorized, one block per row * 6 | scaled_srelu_backward_with_scale_grad_kernel | SReLU | bwd | yes | vectorized, one block per row * @@ -72,6 +72,10 @@ enum class ScaledActivation { kSReLU, }; +__device__ __forceinline__ float sigmoid_from_float(const float x) { + return 1.0f / (1.0f + expf(-x)); +} + template __device__ __forceinline__ float gated_forward_value(const float act_in, const float gate_in, const ClampedSwiGLUParam ¶m) { @@ -90,23 +94,55 @@ __device__ __forceinline__ void gated_backward_values(const float act_in, const float *dgate, float *unscaled) { if constexpr (Act == ScaledActivation::kSwiGLU) { - Empty empty = {}; - const float act = silu(act_in, empty); + const float sigmoid = sigmoid_from_float(act_in); + const float act = act_in * sigmoid; + const float dact_base = sigmoid + act_in * sigmoid * (1.0f - sigmoid); *unscaled = act * gate_in; - *dact = dsilu(act_in, empty) * gate_in; + *dact = dact_base * gate_in; *dgate = act; } else { const bool dgate_mask = gate_in <= param.limit && gate_in >= -param.limit; const float gate = fminf(fmaxf(-param.limit, gate_in), param.limit) + param.glu_linear_offset; - const float act = clamped_silu(act_in, param); + const bool dact_mask = act_in <= param.limit; + const float clamped_act_in = fminf(act_in, param.limit); + const float sigmoid = sigmoid_from_float(param.alpha * clamped_act_in); + const float act = clamped_act_in * sigmoid; + const float dact_base = + dact_mask ? sigmoid + param.alpha * clamped_act_in * sigmoid * (1.0f - sigmoid) : 0.0f; *unscaled = act * gate; - *dact = clamped_dsilu(act_in, param) * gate; + *dact = dact_base * gate; *dgate = dgate_mask ? act : 0.0f; } } constexpr int kThreads = unary_kernel_threads; constexpr int kReductionThreads = 256; +constexpr int kReductionWarps = kReductionThreads / THREADS_PER_WARP; + +__device__ __forceinline__ float warp_reduce_sum(float value) { +#pragma unroll + for (int offset = THREADS_PER_WARP / 2; offset > 0; offset >>= 1) { + value += __shfl_down_sync(0xffffffff, value, offset); + } + return value; +} + +__device__ __forceinline__ float block_reduce_sum(float value, float *smem) { + const int lane = threadIdx.x % THREADS_PER_WARP; + const int warp = threadIdx.x / THREADS_PER_WARP; + + value = warp_reduce_sum(value); + if (lane == 0) { + smem[warp] = value; + } + __syncthreads(); + + value = threadIdx.x < kReductionWarps ? smem[lane] : 0.0f; + if (warp == 0) { + value = warp_reduce_sum(value); + } + return value; +} template constexpr int vector_width() { @@ -133,26 +169,7 @@ Alignment row_vector_alignment(const size_t lead_dim, const int nvec, const Ptrs return align == Alignment::SAME_ALIGNED ? Alignment::SAME_ALIGNED : Alignment::DIFFERENT; } -template -__device__ __forceinline__ bool vector_lane_index(const size_t vector_idx, const int lane, - const int alignment, const size_t length, - size_t *index) { - size_t idx = vector_idx * static_cast(nvec) + static_cast(lane); - if constexpr (!aligned) { - if (idx < static_cast(alignment)) { - return false; - } - idx -= static_cast(alignment); - } - if (idx >= length) { - return false; - } - *index = idx; - return true; -} - -template +template __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *act_scales, OutputT *output, const size_t rows, const size_t hidden, const size_t segment_size, @@ -168,66 +185,52 @@ __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *a const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; const size_t output_segment_offset = row * hidden + segment * segment_size; - VectorizedLoader act_loader(input + input_segment_offset, - segment_size); - VectorizedLoader gate_loader( + VectorizedLoader act_loader(input + input_segment_offset, segment_size); + VectorizedLoader gate_loader( input + input_segment_offset + segment_size, segment_size); - VectorizedStorer output_storer(output + output_segment_offset, - segment_size); - if (vector_idx >= act_loader.num_aligned_elements()) { - continue; - } - + VectorizedStorer output_storer(output + output_segment_offset, + segment_size); act_loader.load(vector_idx, segment_size); gate_loader.load(vector_idx, segment_size); const float scale = static_cast(act_scales[row]); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, act_loader.alignment(), - segment_size, &col)) { - const float unscaled = - gated_forward_value(static_cast(act_loader.separate()[lane]), - static_cast(gate_loader.separate()[lane]), param); - output_storer.separate()[lane] = static_cast(unscaled * scale); - } + const float unscaled = + gated_forward_value(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param); + output_storer.separate()[lane] = static_cast(unscaled * scale); } output_storer.store(vector_idx, segment_size); } } -template +template __global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *act_scales, - OutputT *output, const size_t total, + OutputT *output, const size_t rows, const size_t hidden, - const size_t num_vectors) { + const size_t num_vectors_per_row) { Empty empty = {}; - VectorizedLoader input_loader(input, total); - VectorizedStorer output_storer(output, total); - for (size_t vector_idx = blockIdx.x * blockDim.x + threadIdx.x; vector_idx < num_vectors; - vector_idx += gridDim.x * blockDim.x) { - if (vector_idx >= input_loader.num_aligned_elements()) { - continue; - } - input_loader.load(vector_idx, total); + const size_t total_vectors = rows * num_vectors_per_row; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_row; + const size_t row = tid / num_vectors_per_row; + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer output_storer(output + row * hidden, hidden); + input_loader.load(vector_idx, hidden); + const float scale = static_cast(act_scales[row]); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t idx = 0; - if (vector_lane_index(vector_idx, lane, input_loader.alignment(), total, - &idx)) { - const size_t row = idx / hidden; - const float unscaled = srelu(static_cast(input_loader.separate()[lane]), - empty); - const float scale = static_cast(act_scales[row]); - output_storer.separate()[lane] = static_cast(unscaled * scale); - } + const float unscaled = + srelu(static_cast(input_loader.separate()[lane]), empty); + output_storer.separate()[lane] = static_cast(unscaled * scale); } - output_storer.store(vector_idx, total); + output_storer.store(vector_idx, hidden); } } -template +template __global__ void scaled_gated_backward_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, @@ -241,92 +244,77 @@ __global__ void scaled_gated_backward_kernel( const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; const size_t output_segment_offset = row * hidden + segment * segment_size; - VectorizedLoader grad_loader(grad_output + output_segment_offset, - segment_size); - VectorizedLoader act_loader(input + input_segment_offset, - segment_size); - VectorizedLoader gate_loader( + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, segment_size); + VectorizedLoader gate_loader( input + input_segment_offset + segment_size, segment_size); - VectorizedStorer act_storer(grad_input + input_segment_offset, - segment_size); - VectorizedStorer gate_storer( + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( grad_input + input_segment_offset + segment_size, segment_size); - if (vector_idx >= act_loader.num_aligned_elements()) { - continue; - } - grad_loader.load(vector_idx, segment_size); act_loader.load(vector_idx, segment_size); gate_loader.load(vector_idx, segment_size); const float scale = static_cast(act_scales[row]); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, act_loader.alignment(), - segment_size, &col)) { - float dact = 0.0f; - float dgate = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(act_loader.separate()[lane]), - static_cast(gate_loader.separate()[lane]), param, &dact, - &dgate, &unscaled); - (void)unscaled; - const float grad = static_cast(grad_loader.separate()[lane]) * scale; - act_storer.separate()[lane] = static_cast(grad * dact); - gate_storer.separate()[lane] = static_cast(grad * dgate); - } + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + (void)unscaled; + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + act_storer.separate()[lane] = static_cast(grad * dact); + gate_storer.separate()[lane] = static_cast(grad * dgate); } act_storer.store(vector_idx, segment_size); gate_storer.store(vector_idx, segment_size); } } -template +template __global__ void scaled_srelu_backward_kernel(const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - const size_t total, const size_t hidden, - const size_t num_vectors) { + const size_t rows, const size_t hidden, + const size_t num_vectors_per_row) { Empty empty = {}; - VectorizedLoader grad_loader(grad_output, total); - VectorizedLoader input_loader(input, total); - VectorizedStorer grad_input_storer(grad_input, total); - for (size_t vector_idx = blockIdx.x * blockDim.x + threadIdx.x; vector_idx < num_vectors; - vector_idx += gridDim.x * blockDim.x) { - if (vector_idx >= input_loader.num_aligned_elements()) { - continue; - } - grad_loader.load(vector_idx, total); - input_loader.load(vector_idx, total); + const size_t total_vectors = rows * num_vectors_per_row; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_row; + const size_t row = tid / num_vectors_per_row; + VectorizedLoader grad_loader(grad_output + row * hidden, hidden); + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); + grad_loader.load(vector_idx, hidden); + input_loader.load(vector_idx, hidden); + const float scale = static_cast(act_scales[row]); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t idx = 0; - if (vector_lane_index(vector_idx, lane, input_loader.alignment(), total, - &idx)) { - const size_t row = idx / hidden; - const float scale = static_cast(act_scales[row]); - const float grad = static_cast(grad_loader.separate()[lane]) * scale; - grad_input_storer.separate()[lane] = - static_cast(grad * dsrelu( - static_cast(input_loader.separate()[lane]), - empty)); - } + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + grad_input_storer.separate()[lane] = + static_cast( + grad * dsrelu(static_cast(input_loader.separate()[lane]), empty)); } - grad_input_storer.store(vector_idx, total); + grad_input_storer.store(vector_idx, hidden); } } -template +template __global__ void scaled_gated_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { - __shared__ float smem[kReductionThreads]; + __shared__ float smem[kReductionWarps]; const size_t row = blockIdx.x; (void)rows; float scale_grad = 0.0f; + const float scale = static_cast(act_scales[row]); // Flatten (segment, vector) so interleave=32 distributes all row work across // the block instead of using only a few threads per small act/gate segment. @@ -337,112 +325,82 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( const size_t vector_idx = row_vector_idx % num_vectors_per_segment; const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; const size_t output_segment_offset = row * hidden + segment * segment_size; - VectorizedLoader grad_loader(grad_output + output_segment_offset, - segment_size); - VectorizedLoader act_loader(input + input_segment_offset, - segment_size); - VectorizedLoader gate_loader( + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, segment_size); + VectorizedLoader gate_loader( input + input_segment_offset + segment_size, segment_size); - VectorizedStorer act_storer(grad_input + input_segment_offset, - segment_size); - VectorizedStorer gate_storer( + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( grad_input + input_segment_offset + segment_size, segment_size); - if (vector_idx >= act_loader.num_aligned_elements()) { - continue; - } grad_loader.load(vector_idx, segment_size); act_loader.load(vector_idx, segment_size); gate_loader.load(vector_idx, segment_size); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, act_loader.alignment(), segment_size, - &col)) { - float dact = 0.0f; - float dgate = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(act_loader.separate()[lane]), - static_cast(gate_loader.separate()[lane]), param, &dact, - &dgate, &unscaled); - const float grad = static_cast(grad_loader.separate()[lane]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - act_storer.separate()[lane] = static_cast(scaled_grad * dact); - gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); - } + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scaled_grad = grad * scale; + act_storer.separate()[lane] = static_cast(scaled_grad * dact); + gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); } act_storer.store(vector_idx, segment_size); gate_storer.store(vector_idx, segment_size); } - smem[threadIdx.x] = scale_grad; - __syncthreads(); - for (int offset = kReductionThreads / 2; offset > 0; offset >>= 1) { - if (threadIdx.x < offset) { - smem[threadIdx.x] += smem[threadIdx.x + offset]; - } - __syncthreads(); - } + scale_grad = block_reduce_sum(scale_grad, smem); if (threadIdx.x == 0) { - grad_act_scales[row] = static_cast(smem[0]); + grad_act_scales[row] = static_cast(scale_grad); } } -template +template __global__ void scaled_srelu_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t num_vectors_per_row) { - __shared__ float smem[kReductionThreads]; + __shared__ float smem[kReductionWarps]; const size_t row = blockIdx.x; (void)rows; float scale_grad = 0.0f; Empty empty = {}; + const float scale = static_cast(act_scales[row]); - VectorizedLoader grad_loader(grad_output + row * hidden, hidden); - VectorizedLoader input_loader(input + row * hidden, hidden); - VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); + VectorizedLoader grad_loader(grad_output + row * hidden, hidden); + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); for (size_t vector_idx = threadIdx.x; vector_idx < num_vectors_per_row; vector_idx += blockDim.x) { - if (vector_idx >= input_loader.num_aligned_elements()) { - continue; - } grad_loader.load(vector_idx, hidden); input_loader.load(vector_idx, hidden); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, input_loader.alignment(), hidden, - &col)) { - const float unscaled = - srelu(static_cast(input_loader.separate()[lane]), empty); - const float grad = static_cast(grad_loader.separate()[lane]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - const float dact = - dsrelu(static_cast(input_loader.separate()[lane]), empty); - grad_input_storer.separate()[lane] = static_cast(scaled_grad * dact); - } + const float unscaled = + srelu(static_cast(input_loader.separate()[lane]), empty); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scaled_grad = grad * scale; + const float dact = + dsrelu(static_cast(input_loader.separate()[lane]), empty); + grad_input_storer.separate()[lane] = static_cast(scaled_grad * dact); } grad_input_storer.store(vector_idx, hidden); } - smem[threadIdx.x] = scale_grad; - __syncthreads(); - for (int offset = kReductionThreads / 2; offset > 0; offset >>= 1) { - if (threadIdx.x < offset) { - smem[threadIdx.x] += smem[threadIdx.x + offset]; - } - __syncthreads(); - } + scale_grad = block_reduce_sum(scale_grad, smem); if (threadIdx.x == 0) { - grad_act_scales[row] = static_cast(smem[0]); + grad_act_scales[row] = static_cast(scale_grad); } } @@ -566,11 +524,11 @@ void launch_scaled_gated_forward(const NVTETensor nvte_input, const NVTETensor n : segment_size; const int blocks = launch_blocks(rows * num_segments * num_vectors); if (use_vector) { - scaled_gated_forward_kernel + scaled_gated_forward_kernel <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { - scaled_gated_forward_kernel<1, true, InputT, ScaleT, OutputT, Act> + scaled_gated_forward_kernel<1, InputT, ScaleT, OutputT, Act> <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, segment_size, num_segments, segment_size, param); } @@ -599,20 +557,19 @@ void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor n const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto output_ptr = reinterpret_cast(output->data.dptr); - const size_t total = rows * hidden; - const auto align = CheckAlignment(total, nvec, input_ptr, output_ptr); + const auto align = row_vector_alignment(hidden, nvec, input_ptr, output_ptr); const bool use_vector = align == Alignment::SAME_ALIGNED; const size_t num_vectors = - use_vector ? get_num_aligned_elements(input_ptr, total, nvec, sizeof(InputT)) : total; - const int blocks = launch_blocks(num_vectors); + use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) : hidden; + const int blocks = launch_blocks(rows * num_vectors); if (use_vector) { - scaled_srelu_forward_kernel - <<>>(input_ptr, scale_ptr, output_ptr, total, hidden, + scaled_srelu_forward_kernel + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, num_vectors); } else { - scaled_srelu_forward_kernel<1, true, InputT, ScaleT, OutputT> - <<>>(input_ptr, scale_ptr, output_ptr, total, hidden, - total); + scaled_srelu_forward_kernel<1, InputT, ScaleT, OutputT> + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, + hidden); } }); }); @@ -664,12 +621,12 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET if (grad_act_scales == nullptr) { const int blocks = launch_blocks(rows * num_segments * num_vectors); if (use_vector) { - scaled_gated_backward_kernel + scaled_gated_backward_kernel <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { - scaled_gated_backward_kernel<1, true, GradT, InputT, ScaleT, OutputT, Act> + scaled_gated_backward_kernel<1, GradT, InputT, ScaleT, OutputT, Act> <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, rows, hidden, segment_size, num_segments, segment_size, param); @@ -680,13 +637,13 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET reinterpret_cast(grad_act_scales->data.dptr); if (use_vector) { scaled_gated_backward_with_scale_grad_kernel< - nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + nvec, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { scaled_gated_backward_with_scale_grad_kernel< - 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + 1, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, segment_size, param); @@ -728,42 +685,36 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); + const auto align = row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, + grad_input_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) + : hidden; if (grad_act_scales == nullptr) { - const size_t total = rows * hidden; - const auto align = CheckAlignment(total, nvec, grad_ptr, input_ptr, grad_input_ptr); - const bool use_vector = align == Alignment::SAME_ALIGNED; - const size_t num_vectors = - use_vector ? get_num_aligned_elements(input_ptr, total, nvec, sizeof(InputT)) - : total; - const int blocks = launch_blocks(num_vectors); + const int blocks = launch_blocks(rows * num_vectors); if (use_vector) { - scaled_srelu_backward_kernel + scaled_srelu_backward_kernel <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, - total, hidden, num_vectors); + rows, hidden, num_vectors); } else { - scaled_srelu_backward_kernel<1, true, GradT, InputT, ScaleT, OutputT> + scaled_srelu_backward_kernel<1, GradT, InputT, ScaleT, OutputT> <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, - total, hidden, total); + rows, hidden, hidden); } } else { - const auto align = row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, - grad_input_ptr); - const bool use_vector = align == Alignment::SAME_ALIGNED; - const size_t num_vectors = - use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) - : hidden; TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { auto grad_act_scales_ptr = reinterpret_cast(grad_act_scales->data.dptr); if (use_vector) { scaled_srelu_backward_with_scale_grad_kernel< - nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT> + nvec, GradT, InputT, ScaleT, OutputT, GradScaleT> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, num_vectors); } else { scaled_srelu_backward_with_scale_grad_kernel< - 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT> + 1, GradT, InputT, ScaleT, OutputT, GradScaleT> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, hidden); From c73c8ea1c91db6b81c5ff2e8421155eaf52f6dd7 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Jun 2026 22:31:36 -0700 Subject: [PATCH 5/8] fix unit test failure Signed-off-by: Zhongbo Zhu --- tests/cpp/operator/test_scaled_activation.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/cpp/operator/test_scaled_activation.cu b/tests/cpp/operator/test_scaled_activation.cu index 80641e3c5f..72a64a3c04 100644 --- a/tests/cpp/operator/test_scaled_activation.cu +++ b/tests/cpp/operator/test_scaled_activation.cu @@ -231,10 +231,12 @@ void run_scaled_activation_test(ScaledActivationCase activation, const size_t ro atol = 5e-5; rtol = 5e-5; } - compareResults("scaled_activation_output", output, ref_output.get(), atol, rtol); - compareResults("scaled_activation_grad_input", grad_input, ref_grad_input.get(), atol, rtol); + compareResults("scaled_activation_output", output, ref_output.get(), true, atol, rtol); + compareResults("scaled_activation_grad_input", grad_input, ref_grad_input.get(), true, atol, + rtol); if (compute_grad_scales) { - compareResults("scaled_activation_grad_scales", grad_scales, ref_grad_scales.get(), atol, rtol); + compareResults("scaled_activation_grad_scales", grad_scales, ref_grad_scales.get(), true, atol, + rtol); } } From 3eb18a6e05b84cba0983a1e754e572ab2d4d1761 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 07:16:09 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/activation/scaled_activation.cu | 131 ++++++++---------- .../include/transformer_engine/activation.h | 12 +- 2 files changed, 63 insertions(+), 80 deletions(-) diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu index d056966281..73df92338c 100644 --- a/transformer_engine/common/activation/scaled_activation.cu +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -91,8 +91,7 @@ __device__ __forceinline__ float gated_forward_value(const float act_in, const f template __device__ __forceinline__ void gated_backward_values(const float act_in, const float gate_in, const ClampedSwiGLUParam ¶m, float *dact, - float *dgate, - float *unscaled) { + float *dgate, float *unscaled) { if constexpr (Act == ScaledActivation::kSwiGLU) { const float sigmoid = sigmoid_from_float(act_in); const float act = act_in * sigmoid; @@ -171,9 +170,8 @@ Alignment row_vector_alignment(const size_t lead_dim, const int nvec, const Ptrs template __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *act_scales, - OutputT *output, const size_t rows, - const size_t hidden, const size_t segment_size, - const size_t num_segments, + OutputT *output, const size_t rows, const size_t hidden, + const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { const size_t total_vectors = rows * num_segments * num_vectors_per_segment; @@ -186,8 +184,8 @@ __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *a const size_t output_segment_offset = row * hidden + segment * segment_size; VectorizedLoader act_loader(input + input_segment_offset, segment_size); - VectorizedLoader gate_loader( - input + input_segment_offset + segment_size, segment_size); + VectorizedLoader gate_loader(input + input_segment_offset + segment_size, + segment_size); VectorizedStorer output_storer(output + output_segment_offset, segment_size); act_loader.load(vector_idx, segment_size); @@ -206,8 +204,7 @@ __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *a template __global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *act_scales, - OutputT *output, const size_t rows, - const size_t hidden, + OutputT *output, const size_t rows, const size_t hidden, const size_t num_vectors_per_row) { Empty empty = {}; const size_t total_vectors = rows * num_vectors_per_row; @@ -231,10 +228,12 @@ __global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *a template -__global__ void scaled_gated_backward_kernel( - const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, - const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { +__global__ void scaled_gated_backward_kernel(const GradT *grad_output, const InputT *input, + const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden, + const size_t segment_size, const size_t num_segments, + const size_t num_vectors_per_segment, + const ClampedSwiGLUParam param) { const size_t total_vectors = rows * num_segments * num_vectors_per_segment; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; tid += gridDim.x * blockDim.x) { @@ -247,8 +246,8 @@ __global__ void scaled_gated_backward_kernel( VectorizedLoader grad_loader(grad_output + output_segment_offset, segment_size); VectorizedLoader act_loader(input + input_segment_offset, segment_size); - VectorizedLoader gate_loader( - input + input_segment_offset + segment_size, segment_size); + VectorizedLoader gate_loader(input + input_segment_offset + segment_size, + segment_size); VectorizedStorer act_storer(grad_input + input_segment_offset, segment_size); VectorizedStorer gate_storer( @@ -295,9 +294,8 @@ __global__ void scaled_srelu_backward_kernel(const GradT *grad_output, const Inp #pragma unroll for (int lane = 0; lane < nvec; ++lane) { const float grad = static_cast(grad_loader.separate()[lane]) * scale; - grad_input_storer.separate()[lane] = - static_cast( - grad * dsrelu(static_cast(input_loader.separate()[lane]), empty)); + grad_input_storer.separate()[lane] = static_cast( + grad * dsrelu(static_cast(input_loader.separate()[lane]), empty)); } grad_input_storer.store(vector_idx, hidden); } @@ -307,8 +305,8 @@ template __global__ void scaled_gated_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, - const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t segment_size, + const size_t num_segments, const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { __shared__ float smem[kReductionWarps]; const size_t row = blockIdx.x; @@ -328,8 +326,8 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( VectorizedLoader grad_loader(grad_output + output_segment_offset, segment_size); VectorizedLoader act_loader(input + input_segment_offset, segment_size); - VectorizedLoader gate_loader( - input + input_segment_offset + segment_size, segment_size); + VectorizedLoader gate_loader(input + input_segment_offset + segment_size, + segment_size); VectorizedStorer act_storer(grad_input + input_segment_offset, segment_size); VectorizedStorer gate_storer( @@ -450,9 +448,8 @@ void check_grad_scale_tensor(const Tensor *grad_act_scales, const size_t rows, void check_gated_backward_tensors(const Tensor *grad_output, const Tensor *input, const Tensor *act_scales, const Tensor *grad_input, - const Tensor *grad_act_scales, - const int64_t glu_interleave_size, const char *api_name, - size_t *rows, size_t *hidden) { + const Tensor *grad_act_scales, const int64_t glu_interleave_size, + const char *api_name, size_t *rows, size_t *hidden) { const auto grad_dims = grad_output->flat_2d_dims(); const auto input_dims = input->flat_2d_dims(); const auto grad_input_dims = grad_input->flat_2d_dims(); @@ -475,8 +472,8 @@ void check_gated_backward_tensors(const Tensor *grad_output, const Tensor *input void check_unary_backward_tensors(const Tensor *grad_output, const Tensor *input, const Tensor *act_scales, const Tensor *grad_input, - const Tensor *grad_act_scales, const char *api_name, - size_t *rows, size_t *hidden) { + const Tensor *grad_act_scales, const char *api_name, size_t *rows, + size_t *hidden) { const auto grad_dims = grad_output->flat_2d_dims(); const auto input_dims = input->flat_2d_dims(); const auto grad_input_dims = grad_input->flat_2d_dims(); @@ -507,17 +504,15 @@ void launch_scaled_gated_forward(const NVTETensor nvte_input, const NVTETensor n TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { - constexpr int nvec = - sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + constexpr int nvec = sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto output_ptr = reinterpret_cast(output->data.dptr); const size_t segment_size = glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; - const auto align = - row_vector_alignment(segment_size, nvec, input_ptr, input_ptr + segment_size, - output_ptr); + const auto align = row_vector_alignment(segment_size, nvec, input_ptr, + input_ptr + segment_size, output_ptr); const bool use_vector = align == Alignment::SAME_ALIGNED; const size_t num_vectors = use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) @@ -552,8 +547,7 @@ void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor n TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { - constexpr int nvec = - sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + constexpr int nvec = sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto output_ptr = reinterpret_cast(output->data.dptr); @@ -567,9 +561,8 @@ void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor n <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, num_vectors); } else { - scaled_srelu_forward_kernel<1, InputT, ScaleT, OutputT> - <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, - hidden); + scaled_srelu_forward_kernel<1, InputT, ScaleT, OutputT><<>>( + input_ptr, scale_ptr, output_ptr, rows, hidden, hidden); } }); }); @@ -581,9 +574,8 @@ template void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTETensor nvte_input, const NVTETensor nvte_act_scales, NVTETensor nvte_grad_input, NVTETensor nvte_grad_act_scales, - const int64_t glu_interleave_size, - const ClampedSwiGLUParam param, cudaStream_t stream, - const char *api_name) { + const int64_t glu_interleave_size, const ClampedSwiGLUParam param, + cudaStream_t stream, const char *api_name) { const Tensor *grad_output = convertNVTETensorCheck(nvte_grad_output); const Tensor *input = convertNVTETensorCheck(nvte_input); const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); @@ -600,8 +592,7 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { - constexpr int nvec = sizeof(GradT) == sizeof(InputT) && - sizeof(InputT) == sizeof(OutputT) + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); @@ -611,9 +602,9 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET const size_t segment_size = glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; - const auto align = row_vector_alignment( - segment_size, nvec, grad_ptr, input_ptr, input_ptr + segment_size, grad_input_ptr, - grad_input_ptr + segment_size); + const auto align = row_vector_alignment(segment_size, nvec, grad_ptr, input_ptr, + input_ptr + segment_size, grad_input_ptr, + grad_input_ptr + segment_size); const bool use_vector = align == Alignment::SAME_ALIGNED; const size_t num_vectors = use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) @@ -633,17 +624,16 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET } } else { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { - auto grad_act_scales_ptr = - reinterpret_cast(grad_act_scales->data.dptr); + auto grad_act_scales_ptr = reinterpret_cast(grad_act_scales->data.dptr); if (use_vector) { - scaled_gated_backward_with_scale_grad_kernel< - nvec, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + scaled_gated_backward_with_scale_grad_kernel <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { - scaled_gated_backward_with_scale_grad_kernel< - 1, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + scaled_gated_backward_with_scale_grad_kernel<1, GradT, InputT, ScaleT, OutputT, + GradScaleT, Act> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, segment_size, param); @@ -677,16 +667,15 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { - constexpr int nvec = sizeof(GradT) == sizeof(InputT) && - sizeof(InputT) == sizeof(OutputT) + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); - const auto align = row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, - grad_input_ptr); + const auto align = + row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, grad_input_ptr); const bool use_vector = align == Alignment::SAME_ALIGNED; const size_t num_vectors = use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) @@ -704,17 +693,16 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET } } else { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { - auto grad_act_scales_ptr = - reinterpret_cast(grad_act_scales->data.dptr); + auto grad_act_scales_ptr = reinterpret_cast(grad_act_scales->data.dptr); if (use_vector) { - scaled_srelu_backward_with_scale_grad_kernel< - nvec, GradT, InputT, ScaleT, OutputT, GradScaleT> + scaled_srelu_backward_with_scale_grad_kernel <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, num_vectors); } else { - scaled_srelu_backward_with_scale_grad_kernel< - 1, GradT, InputT, ScaleT, OutputT, GradScaleT> + scaled_srelu_backward_with_scale_grad_kernel<1, GradT, InputT, ScaleT, OutputT, + GradScaleT> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, hidden); @@ -742,16 +730,15 @@ void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVT input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_swiglu"); } -void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, - const NVTETensor act_scales, NVTETensor grad_input, - NVTETensor grad_act_scales, int64_t glu_interleave_size, - cudaStream_t stream) { +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, + int64_t glu_interleave_size, cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_dswiglu); using namespace transformer_engine; ClampedSwiGLUParam param = {}; - launch_scaled_gated_backward( - grad, input, act_scales, grad_input, grad_act_scales, glu_interleave_size, param, stream, - "nvte_scaled_dswiglu"); + launch_scaled_gated_backward(grad, input, act_scales, grad_input, + grad_act_scales, glu_interleave_size, + param, stream, "nvte_scaled_dswiglu"); } void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales, @@ -762,8 +749,7 @@ void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_sca using namespace transformer_engine; ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; launch_scaled_gated_forward( - input, act_scales, output, glu_interleave_size, param, stream, - "nvte_scaled_clamped_swiglu"); + input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_clamped_swiglu"); } void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, @@ -786,9 +772,8 @@ void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTE launch_scaled_srelu_forward(input, act_scales, output, stream, "nvte_scaled_srelu"); } -void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, - const NVTETensor act_scales, NVTETensor grad_input, - NVTETensor grad_act_scales, cudaStream_t stream) { +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_dsrelu); using namespace transformer_engine; launch_scaled_srelu_backward(grad, input, act_scales, grad_input, grad_act_scales, stream, diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index f1485057ec..ed90428f8c 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -521,10 +521,9 @@ void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTE * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, - const NVTETensor act_scales, NVTETensor grad_input, - NVTETensor grad_act_scales, int64_t glu_interleave_size, - cudaStream_t stream); +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, + int64_t glu_interleave_size, cudaStream_t stream); /*! \brief Computes ScaledClampedSwiGLU backward without materializing GLU deinterleave. * @@ -608,9 +607,8 @@ void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTE * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, - const NVTETensor act_scales, NVTETensor grad_input, - NVTETensor grad_act_scales, cudaStream_t stream); +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, cudaStream_t stream); #ifdef __cplusplus } // extern "C" From 7044104858f3ded88a529551b68b05d858e78a03 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 16 Jun 2026 00:31:09 -0700 Subject: [PATCH 7/8] megacpp grouped mlp Signed-off-by: Zhongbo Zhu --- tests/pytorch/megacpp/test_grouped_mlp.py | 484 +++++++++++ transformer_engine/pytorch/csrc/extensions.h | 23 + .../pytorch/csrc/extensions/pybind.cpp | 19 + .../pytorch/csrc/megacpp/grouped_mlp.cpp | 812 ++++++++++++++++++ .../pytorch/ops/fused/__init__.py | 6 + .../ops/fused/backward_grouped_mlp_megacpp.py | 394 +++++++++ .../ops/fused/forward_grouped_mlp_megacpp.py | 421 +++++++++ 7 files changed, 2159 insertions(+) create mode 100644 tests/pytorch/megacpp/test_grouped_mlp.py create mode 100644 transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp create mode 100644 transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py create mode 100644 transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py diff --git a/tests/pytorch/megacpp/test_grouped_mlp.py b/tests/pytorch/megacpp/test_grouped_mlp.py new file mode 100644 index 0000000000..b056af1978 --- /dev/null +++ b/tests/pytorch/megacpp/test_grouped_mlp.py @@ -0,0 +1,484 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops + + +_HIDDEN_SIZE = 512 +_FFN_HIDDEN_SIZE = 256 + + +def _megacpp_available() -> tuple[bool, str]: + if not torch.cuda.is_available(): + return False, "CUDA is required" + if not te.is_bf16_available(): + return False, "BF16 is required" + if torch.cuda.get_device_capability() < (10, 0): + return False, "megacpp grouped MLP uses SM100 grouped GEMM" + if not te_ops.fused.ForwardGroupedMLP_MegaCpp.is_supported(): + return False, "ForwardGroupedMLP_MegaCpp is not supported" + if not te_ops.fused.BackwardGroupedMLP_MegaCpp.is_supported(): + return False, "BackwardGroupedMLP_MegaCpp is not supported" + return True, "" + + +_AVAILABLE, _SKIP_REASON = _megacpp_available() +pytestmark = pytest.mark.skipif(not _AVAILABLE, reason=_SKIP_REASON) + + +def _make_grouped_mlp( + *, + num_groups: int, + hidden_size: int, + ffn_hidden_size: int, + activation_kind: str, + bias: bool, + delay_wgrad_compute: bool, + accumulate_into_main_grad: bool, + glu_interleave_size: int | None, + single_grouped_param: bool, +) -> te_ops.Sequential: + gated_activation = activation_kind in ("scaled_swiglu", "scaled_clamped_qgeglu") + fc1_out_features = 2 * ffn_hidden_size if gated_activation else ffn_hidden_size + fc1 = te_ops.GroupedLinear( + num_groups, + hidden_size, + fc1_out_features, + bias=bias, + device="cuda", + dtype=torch.bfloat16, + delay_wgrad_compute=delay_wgrad_compute, + accumulate_into_main_grad=accumulate_into_main_grad, + single_grouped_weight=single_grouped_param, + single_grouped_bias=single_grouped_param and bias, + ) + if activation_kind == "scaled_swiglu": + act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + elif activation_kind == "scaled_clamped_qgeglu": + act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + elif activation_kind == "scaled_srelu": + act = te_ops.ScaledSReLU() + else: + raise ValueError(f"Unsupported test activation_kind={activation_kind}.") + fc2 = te_ops.GroupedLinear( + num_groups, + ffn_hidden_size, + hidden_size, + bias=bias, + device="cuda", + dtype=torch.bfloat16, + delay_wgrad_compute=delay_wgrad_compute, + accumulate_into_main_grad=accumulate_into_main_grad, + single_grouped_weight=single_grouped_param, + single_grouped_bias=single_grouped_param and bias, + ) + return te_ops.Sequential(fc1, act, fc2) + + +def _copy_grouped_mlp_params(dst: te_ops.Sequential, src: te_ops.Sequential) -> None: + with torch.no_grad(): + for dst_linear, src_linear in ((dst[0], src[0]), (dst[2], src[2])): + if dst_linear.single_grouped_weight: + dst_linear.weight.rowwise_data.copy_(src_linear.weight.rowwise_data) + if dst_linear.has_bias: + dst_linear.bias.rowwise_data.copy_(src_linear.bias.rowwise_data) + else: + for group_idx in range(dst_linear.num_groups): + getattr(dst_linear, f"weight{group_idx}").copy_( + getattr(src_linear, f"weight{group_idx}") + ) + if dst_linear.has_bias: + getattr(dst_linear, f"bias{group_idx}").copy_( + getattr(src_linear, f"bias{group_idx}") + ) + + +def _init_main_grads(module: te_ops.Sequential, dtype: torch.dtype) -> None: + for linear in (module[0], module[2]): + if linear.single_grouped_weight: + linear.weight.main_grad = torch.zeros( + linear.num_groups, + linear.out_features, + linear.in_features, + device="cuda", + dtype=dtype, + ) + else: + for group_idx in range(linear.num_groups): + weight = getattr(linear, f"weight{group_idx}") + weight.main_grad = torch.zeros_like(weight, dtype=dtype) + + +def _run_grouped_mlp( + module: te_ops.Sequential, + x: torch.Tensor, + split_sizes: torch.Tensor, + act_scales: torch.Tensor, + dy: torch.Tensor, + *, + delay_wgrad_compute: bool, +) -> torch.Tensor: + y = module(x, split_sizes, act_scales, split_sizes) + y.backward(dy) + if delay_wgrad_compute: + module[0].backward_dw() + module[2].backward_dw() + return y + + +def _assert_grouped_mlp_close( + test: te_ops.Sequential, + ref: te_ops.Sequential, + *, + accumulate_into_main_grad: bool, +) -> None: + for test_linear, ref_linear in ((test[0], ref[0]), (test[2], ref[2])): + if test_linear.single_grouped_weight: + if accumulate_into_main_grad: + torch.testing.assert_close( + test_linear.weight.main_grad, + ref_linear.weight.main_grad, + rtol=2e-2, + atol=2e-2, + ) + else: + torch.testing.assert_close( + test_linear.weight.grad, + ref_linear.weight.grad, + rtol=2e-2, + atol=2e-2, + ) + if test_linear.has_bias: + torch.testing.assert_close( + test_linear.bias.grad, + ref_linear.bias.grad, + rtol=2e-2, + atol=2e-2, + ) + continue + for group_idx in range(test_linear.num_groups): + if accumulate_into_main_grad: + torch.testing.assert_close( + getattr(test_linear, f"weight{group_idx}").main_grad, + getattr(ref_linear, f"weight{group_idx}").main_grad, + rtol=2e-2, + atol=2e-2, + ) + else: + torch.testing.assert_close( + getattr(test_linear, f"weight{group_idx}").grad, + getattr(ref_linear, f"weight{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + if test_linear.has_bias: + torch.testing.assert_close( + getattr(test_linear, f"bias{group_idx}").grad, + getattr(ref_linear, f"bias{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + + +def _assert_grouped_mlp_nonzero_expert_grads_close( + test: te_ops.Sequential, + ref: te_ops.Sequential, + split_sizes: list[int], +) -> None: + """Compare only non-empty experts; zero-token expert grads may be unwritten.""" + for test_linear, ref_linear in ((test[0], ref[0]), (test[2], ref[2])): + for group_idx, split_size in enumerate(split_sizes): + if split_size == 0: + continue + torch.testing.assert_close( + getattr(test_linear, f"weight{group_idx}").grad, + getattr(ref_linear, f"weight{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + if test_linear.has_bias: + torch.testing.assert_close( + getattr(test_linear, f"bias{group_idx}").grad, + getattr(ref_linear, f"bias{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + + +def _assert_valid_prefix_close( + test: torch.Tensor, + ref: torch.Tensor, + valid_tokens: int, +) -> None: + """Paged-stashed buffers only guarantee correctness in the valid token prefix.""" + if valid_tokens == 0: + return + torch.testing.assert_close(test[:valid_tokens], ref[:valid_tokens], rtol=2e-2, atol=2e-2) + + +def _make_split_tensor( + split_sizes: list[int], + *, + dtype: torch.dtype = torch.int64, + device: str = "cuda", +) -> torch.Tensor: + return torch.tensor(split_sizes, dtype=dtype, device=device) + + +def _run_megacpp_against_python( + *, + split_sizes_list: list[int], + physical_tokens: int, + split_dtype: torch.dtype, + split_device: str, + bias: bool = True, + glu_interleave_size: int | None = None, + activation_kind: str = "scaled_swiglu", + single_grouped_param: bool = False, + accumulate_into_main_grad: bool = False, + main_grad_dtype: torch.dtype | None = None, + compare_zero_expert_grads: bool = True, + monkeypatch, +) -> None: + num_groups = len(split_sizes_list) + valid_tokens = sum(split_sizes_list) + assert physical_tokens >= valid_tokens + if single_grouped_param: + monkeypatch.setenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "1") + split_sizes = _make_split_tensor(split_sizes_list, dtype=split_dtype, device=split_device) + ref = _make_grouped_mlp( + num_groups=num_groups, + hidden_size=_HIDDEN_SIZE, + ffn_hidden_size=_FFN_HIDDEN_SIZE, + activation_kind=activation_kind, + bias=bias, + delay_wgrad_compute=False, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=glu_interleave_size, + single_grouped_param=single_grouped_param, + ) + test = _make_grouped_mlp( + num_groups=num_groups, + hidden_size=_HIDDEN_SIZE, + ffn_hidden_size=_FFN_HIDDEN_SIZE, + activation_kind=activation_kind, + bias=bias, + delay_wgrad_compute=False, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=glu_interleave_size, + single_grouped_param=single_grouped_param, + ) + _copy_grouped_mlp_params(test, ref) + if accumulate_into_main_grad: + if main_grad_dtype is None: + raise ValueError("main_grad_dtype must be set when using Megatron-owned main_grad.") + _init_main_grads(ref, main_grad_dtype) + _init_main_grads(test, main_grad_dtype) + + # Paged stashing passes a static physical buffer to the op while m_splits + # describe only the valid prefix. Rows after sum(m_splits) are garbage and + # must not affect outputs/gradients for the valid prefix. + x_ref = torch.randn( + physical_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() + x_test = x_ref.detach().clone().requires_grad_() + act_scales_ref = torch.rand( + physical_tokens, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() + act_scales_test = act_scales_ref.detach().clone().requires_grad_() + dy = torch.randn(physical_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16) + + monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "0") + y_ref = _run_grouped_mlp( + ref, + x_ref, + split_sizes, + act_scales_ref, + dy, + delay_wgrad_compute=False, + ) + monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "1") + y_test = _run_grouped_mlp( + test, + x_test, + split_sizes, + act_scales_test, + dy, + delay_wgrad_compute=False, + ) + + fuser = test._module_groups[0] + assert isinstance(fuser._forward_ops[0][0], te_ops.fused.ForwardGroupedMLP_MegaCpp) + assert isinstance(fuser._backward_ops[0][0], te_ops.fused.BackwardGroupedMLP_MegaCpp) + + _assert_valid_prefix_close(y_test, y_ref, valid_tokens) + _assert_valid_prefix_close(x_test.grad, x_ref.grad, valid_tokens) + _assert_valid_prefix_close( + act_scales_test.grad, + act_scales_ref.grad, + valid_tokens, + ) + if valid_tokens == physical_tokens and compare_zero_expert_grads: + _assert_grouped_mlp_close(test, ref, accumulate_into_main_grad=accumulate_into_main_grad) + elif valid_tokens > 0 and not single_grouped_param and not accumulate_into_main_grad: + _assert_grouped_mlp_nonzero_expert_grads_close(test, ref, split_sizes_list) + + +@pytest.mark.parametrize( + "single_grouped_param", + [False, True], + ids=["discrete_weight", "packed_weight"], +) +@pytest.mark.parametrize( + "accumulate_into_main_grad,main_grad_dtype", + [ + pytest.param(False, None, id="cpp_allocated_wgrad"), + pytest.param(True, torch.bfloat16, id="megatron_main_grad_bf16"), + pytest.param(True, torch.float32, id="megatron_main_grad_fp32"), + ], +) +def test_megacpp_grouped_mlp_wgrad_storage_matches_python( + single_grouped_param, + accumulate_into_main_grad, + main_grad_dtype, + monkeypatch, +): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=torch.int64, + split_device="cuda", + single_grouped_param=single_grouped_param, + accumulate_into_main_grad=accumulate_into_main_grad, + main_grad_dtype=main_grad_dtype, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize( + "split_dtype,split_device", + [ + pytest.param(torch.int64, "cuda", id="i64_cuda"), + pytest.param(torch.int32, "cuda", id="i32_cuda"), + pytest.param(torch.int64, "cpu", id="i64_cpu"), + ], +) +def test_megacpp_grouped_mlp_split_source_matches_python( + split_dtype, + split_device, + monkeypatch, +): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=split_dtype, + split_device=split_device, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize( + "activation_kind", + ["scaled_swiglu", "scaled_srelu", "scaled_clamped_qgeglu"], + ids=["swiglu", "srelu", "clamped_qgeglu"], +) +@pytest.mark.parametrize( + "glu_interleave_size", + [None, 32], + ids=["no_interleave", "interleave_32"], +) +def test_megacpp_grouped_mlp_activation_matches_python( + activation_kind, + glu_interleave_size, + monkeypatch, +): + if activation_kind == "scaled_srelu" and glu_interleave_size is not None: + pytest.skip("ScaledSReLU is not a GLU activation.") + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=torch.int64, + split_device="cuda", + activation_kind=activation_kind, + glu_interleave_size=glu_interleave_size, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize("bias", [True, False], ids=["bias", "no_bias"]) +def test_megacpp_grouped_mlp_bias_matches_python(bias, monkeypatch): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=torch.int64, + split_device="cuda", + bias=bias, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize( + "split_sizes_list,physical_tokens", + [ + pytest.param([256, 256, 256, 256], 1024, id="even"), + pytest.param([0, 256, 256, 512], 1024, id="zero_front"), + pytest.param([256, 0, 256, 512], 1024, id="zero_middle"), + pytest.param([256, 256, 512, 0], 1024, id="zero_end"), + pytest.param([256, 256], 1024, id="paged_stashing_even_with_garbage"), + pytest.param([0, 256, 256], 1024, id="paged_stashing_zero_front_with_garbage"), + pytest.param([256, 0, 256], 1024, id="paged_stashing_zero_middle_with_garbage"), + pytest.param([256, 256, 0], 1024, id="paged_stashing_zero_end_with_garbage"), + pytest.param([0, 0, 0, 0], 1024, id="paged_stashing_zero_tokens_all_nonempty_input"), + ], +) +def test_megacpp_grouped_mlp_split_edge_cases( + split_sizes_list, + physical_tokens, + monkeypatch, +): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=split_sizes_list, + physical_tokens=physical_tokens, + split_dtype=torch.int64, + split_device="cuda", + compare_zero_expert_grads=False, + monkeypatch=monkeypatch, + ) + + +def test_megacpp_grouped_mlp_delay_wgrad_raises(monkeypatch): + torch.manual_seed(1234) + num_groups = 3 + split_sizes = torch.tensor([256, 256, 512], dtype=torch.int64, device="cuda") + total_tokens = int(split_sizes.sum().item()) + module = _make_grouped_mlp( + num_groups=num_groups, + hidden_size=_HIDDEN_SIZE, + ffn_hidden_size=_FFN_HIDDEN_SIZE, + activation_kind="scaled_swiglu", + bias=True, + delay_wgrad_compute=True, + accumulate_into_main_grad=False, + glu_interleave_size=None, + single_grouped_param=False, + ) + x = torch.randn( + total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() + act_scales = torch.rand(total_tokens, device="cuda", dtype=torch.bfloat16).requires_grad_() + dy = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16) + + monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "1") + with pytest.raises(ValueError, match="delay_wgrad_compute"): + y = module(x, split_sizes, act_scales, split_sizes) + y.backward(dy) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b4f899e1d..3561254a7c 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -185,6 +185,29 @@ py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, p at::Tensor workspace_cublas, bool use_split_accumulator, int math_sm_count); +/*************************************************************************************************** + * Mega C++ grouped MLP + **************************************************************************************************/ + +std::vector megacpp_grouped_mlp_forward( + const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes, + py::handle fc1_weight, py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, + const std::optional &act_scales, const std::string &activation, + int64_t glu_interleave_size, double activation_limit, double activation_alpha, + double activation_glu_linear_offset, py::handle gemm_scratch); + +py::tuple megacpp_grouped_mlp_backward( + const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes, + const at::Tensor &x_offsets, const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, + const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, const at::Tensor &x, + const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, + const std::optional &act_scales, py::handle fc1_weight, py::handle fc2_weight, + py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, + py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, + bool input_requires_grad, py::handle gemm_scratch); + /*************************************************************************************************** * Transpose **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..34ad560ae0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -357,6 +357,25 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("te_general_grouped_gemm_for_discrete_out", &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out, "Grouped GEMM for discrete output list"); + m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward, + "Mega C++ grouped MLP forward", py::arg("input"), py::arg("act_dtype"), + py::arg("split_sizes"), py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"), + py::arg("fc2_bias"), py::arg("act_scales"), py::arg("activation"), + py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0, + py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0, + py::arg("gemm_scratch") = py::none()); + m.def("megacpp_grouped_mlp_backward", &transformer_engine::pytorch::megacpp_grouped_mlp_backward, + "Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("act_dtype"), + py::arg("split_sizes"), py::arg("x_offsets"), py::arg("fc1_offsets"), + py::arg("fc2_offsets"), py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"), + py::arg("fc1_activation_input"), py::arg("fc2_x"), py::arg("act_scales"), + py::arg("fc1_weight"), py::arg("fc2_weight"), py::arg("fc1_wgrad_output"), + py::arg("fc1_compute_wgrad"), py::arg("fc1_accumulate_wgrad"), py::arg("fc2_wgrad_output"), + py::arg("fc2_compute_wgrad"), py::arg("fc2_accumulate_wgrad"), py::arg("activation"), + py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0, + py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0, + py::arg("act_scales_requires_grad") = true, py::arg("input_requires_grad") = true, + py::arg("gemm_scratch") = py::none()); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp new file mode 100644 index 0000000000..4292adb349 --- /dev/null +++ b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp @@ -0,0 +1,812 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../extensions.h" +#include "../pybind.h" +#include "common/util/cuda_runtime.h" +#include "common/util/system.h" +#include "transformer_engine/activation.h" +#include "transformer_engine/gemm.h" +#include "transformer_engine/transformer_engine.h" + +namespace py = pybind11; + +namespace transformer_engine::pytorch { +namespace { + +constexpr int64_t kGroupedGemmCublasWorkspaceSize = 32 * 1024 * 1024 + 1024; + +bool is_none(py::handle obj) { return obj.is_none(); } + +std::vector tensor_shape_1d(const at::Tensor &tensor) { + return {static_cast(tensor.numel())}; +} + +at::Tensor maybe_cast_dtype(const at::Tensor &tensor, at::ScalarType dtype) { + at::Tensor out = tensor; + if (out.scalar_type() != dtype) { + out = out.to(out.options().dtype(dtype)); + } + return out; +} + +void check_contiguous(const at::Tensor &tensor, const std::string &name) { + NVTE_CHECK(tensor.is_contiguous(), name, " must be contiguous."); +} + +size_t num_groups_from_prepared_split_sizes(const at::Tensor &split_sizes, + const c10::Device &device) { + NVTE_CHECK(split_sizes.dim() == 1, "split_sizes must be a 1D tensor."); + NVTE_CHECK(split_sizes.device() == device, "split_sizes must be on the current CUDA device."); + NVTE_CHECK(split_sizes.scalar_type() == at::kLong, + "split_sizes must be the int64 CUDA tensor returned by splits_to_offsets_multi."); + return static_cast(split_sizes.numel()); +} + +GroupedTensorWrapper make_grouped_tensor(const at::Tensor &data, + const at::Tensor &prepared_split_sizes, + const at::Tensor &tensor_offsets, + int64_t logical_last_dim) { + const auto num_groups = static_cast(prepared_split_sizes.numel()); + NVTE_CHECK(data.numel() % logical_last_dim == 0, + "Grouped tensor storage is not divisible by logical last dimension."); + const auto total_tokens = static_cast(data.numel() / logical_last_dim); + auto grouped = GroupedTensorWrapper( + num_groups, std::vector{total_tokens, static_cast(logical_last_dim)}); + grouped.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), + tensor_shape_1d(data)); + grouped.set_first_dims(prepared_split_sizes.data_ptr(), DType::kInt64, + std::vector{num_groups}); + grouped.set_tensor_offsets(tensor_offsets.data_ptr(), DType::kInt64, + std::vector{num_groups + 1}); + return grouped; +} + +GroupedTensorWrapper make_uniform_grouped_tensor(at::Tensor data, size_t num_groups, + int64_t first_dim, int64_t last_dim) { + auto grouped = GroupedTensorWrapper( + num_groups, std::vector{num_groups * static_cast(first_dim), + static_cast(last_dim)}); + grouped.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), + tensor_shape_1d(data)); + return grouped; +} + +struct GroupedWeightArg { + bool is_grouped = false; + at::Tensor packed; + std::vector discrete; + // Logical per-expert weight shape. For both supported layouts: + // - packed single grouped weight: packed has shape [G, rows, cols] + // - discrete weights: each tensor has shape [rows, cols] + // rows = out_features, cols = in_features. + int64_t rows = 0; + int64_t cols = 0; + + c10::Device device() const { return is_grouped ? packed.device() : discrete[0].device(); } +}; + +GroupedWeightArg weight_arg_from_py(py::handle arg, size_t num_groups, at::ScalarType dtype, + const std::string &name) { + GroupedWeightArg out; + if (py::isinstance(arg) || py::isinstance(arg)) { + auto seq = py::reinterpret_borrow(arg); + NVTE_CHECK(static_cast(seq.size()) == num_groups, name, " must have ", num_groups, + " tensors."); + out.discrete.reserve(num_groups); + for (size_t i = 0; i < num_groups; ++i) { + auto tensor = maybe_cast_dtype(seq[i].cast(), dtype); + check_contiguous(tensor, name); + NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2."); + if (i == 0) { + // Discrete case: each expert owns one [out_features, in_features] + // tensor. Cache the shared logical shape for later GEMM setup. + out.rows = tensor.size(0); + out.cols = tensor.size(1); + } else { + NVTE_CHECK(tensor.size(0) == out.rows && tensor.size(1) == out.cols, name, + " tensors must have a uniform shape."); + } + out.discrete.emplace_back(tensor); + } + return out; + } + + out.packed = maybe_cast_dtype(arg.cast(), dtype); + NVTE_CHECK(out.packed.dim() == 3, name, " must be a tensor with shape [num_groups, rows, cols]."); + NVTE_CHECK(static_cast(out.packed.size(0)) == num_groups, name, + " first dimension must be ", num_groups, "."); + check_contiguous(out.packed, name); + out.is_grouped = true; + // Packed case: a single [G, out_features, in_features] tensor stores all + // experts, so dimensions 1 and 2 are the same per-expert logical shape. + out.rows = out.packed.size(1); + out.cols = out.packed.size(2); + return out; +} + +at::Tensor packed_bias_from_arg(py::handle arg, size_t num_groups, at::ScalarType dtype, + int64_t out_features, const std::string &name) { + if (is_none(arg)) { + return at::Tensor(); + } + + auto packed = maybe_cast_dtype(arg.cast(), dtype); + NVTE_CHECK(packed.dim() == 2, name, " must be a tensor with shape [num_groups, features]."); + NVTE_CHECK(static_cast(packed.size(0)) == num_groups, name, " first dimension must be ", + num_groups, "."); + NVTE_CHECK(packed.size(1) == out_features, name, " second dimension must be ", out_features, "."); + check_contiguous(packed, name); + return packed; +} + +std::vector nvte_tensor_list_from_tensors(const std::vector &tensors, + std::vector *wrappers) { + wrappers->clear(); + wrappers->reserve(tensors.size()); + std::vector out; + out.reserve(tensors.size()); + for (const auto &tensor : tensors) { + wrappers->emplace_back(makeTransformerEngineTensor(tensor)); + out.emplace_back(wrappers->back().data()); + } + return out; +} + +int grouped_gemm_math_sm_count(const c10::Device &device) { + const int device_id = static_cast(device.index()); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + return sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); +} + +std::array grouped_gemm_scratch_from_arg(py::handle scratch, + const c10::Device &device, + size_t num_groups) { + const int64_t num_groups_i64 = static_cast(num_groups); + const int64_t setup_size = + static_cast(nvte_get_grouped_gemm_setup_workspace_size(num_groups)); + + if (is_none(scratch)) { + return { + at::ones({num_groups_i64}, at::device(device).dtype(at::kFloat)), + at::zeros({num_groups_i64}, at::device(device).dtype(at::kFloat)), + at::empty({setup_size}, at::device(device).dtype(at::kByte)), + at::empty({kGroupedGemmCublasWorkspaceSize}, at::device(device).dtype(at::kByte)), + }; + } + + NVTE_CHECK(py::isinstance(scratch) || py::isinstance(scratch), + "megacpp grouped MLP GEMM scratch must be None or a 4-tensor tuple/list."); + auto seq = py::reinterpret_borrow(scratch); + NVTE_CHECK(seq.size() == 4, "megacpp grouped MLP GEMM scratch must have 4 tensors."); + + std::array tensors = { + seq[0].cast(), + seq[1].cast(), + seq[2].cast(), + seq[3].cast(), + }; + return tensors; +} + +struct GroupedGemmResources { + c10::Device device; + size_t num_groups; + at::Tensor alpha; + at::Tensor beta_zero; + at::Tensor beta_one; + at::Tensor setup; + at::Tensor cublas; + TensorWrapper te_alpha; + TensorWrapper te_beta_zero; + TensorWrapper te_beta_one; + TensorWrapper te_setup; + TensorWrapper te_cublas; + std::optional config; + + GroupedGemmResources(const c10::Device &device_, size_t num_groups_, + std::array scratch) + : device(device_), + num_groups(num_groups_), + alpha(std::move(scratch[0])), + beta_zero(std::move(scratch[1])), + beta_one(alpha), + setup(std::move(scratch[2])), + cublas(std::move(scratch[3])), + te_alpha(makeTransformerEngineTensor(alpha)), + te_beta_zero(makeTransformerEngineTensor(beta_zero)), + te_beta_one(makeTransformerEngineTensor(beta_one)), + te_setup(makeTransformerEngineTensor( + setup.data_ptr(), std::vector{static_cast(setup.numel())}, + DType::kByte)), + te_cublas(makeTransformerEngineTensor( + cublas.data_ptr(), std::vector{static_cast(cublas.numel())}, + DType::kByte)) { + // These scratch tensors may be cached by Python per CUDA stream. Every + // current megacpp grouped GEMM below is enqueued on at::cuda::getCurrentCUDAStream(), + // so same-stream ordering protects workspace reuse. If a future backend + // uses auxiliary streams, cache keys or stream recording must be revisited. + const int math_sm_count = grouped_gemm_math_sm_count(device); + if (math_sm_count > 0) { + config.emplace(); + config->set_sm_count(math_sm_count); + } + } + + NVTETensor beta(bool accumulate) { return accumulate ? te_beta_one.data() : te_beta_zero.data(); } + + NVTEGroupedMatmulConfig config_data() { + return config.has_value() ? static_cast(*config) : nullptr; + } +}; + +GroupedGemmResources make_grouped_mlp_backend_resources(const c10::Device &device, + size_t num_groups, py::handle scratch) { + // Keep the backend resource policy private to megacpp. Today this is cuBLAS + // grouped GEMM scratch; future backends can change this helper without + // changing the Python or pybind contract. + return GroupedGemmResources(device, num_groups, + grouped_gemm_scratch_from_arg(scratch, device, num_groups)); +} + +void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B, bool transb, + GroupedTensorWrapper *D, GroupedGemmResources *resources, bool accumulate) { + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm(A->data(), transa, B->data(), transb, D->data(), D->data(), + resources->te_alpha.data(), resources->beta(accumulate), + resources->te_setup.data(), resources->te_cublas.data(), + resources->config_data(), at::cuda::getCurrentCUDAStream()); + }); +} + +std::vector output_tensor_list_from_arg(py::handle arg, size_t num_groups, int64_t rows, + int64_t cols, const std::string &name) { + std::vector out; + if (is_none(arg)) { + return out; + } + out.reserve(num_groups); + // This helper is intentionally only for the discrete-weight external wgrad + // path, where Megatron provides one main_grad tensor per expert. The packed + // [G, rows, cols] external buffer used by single grouped weight is handled in + // wgrad_output_from_arg so it can stay packed and use grouped-tensor GEMM. + NVTE_CHECK(py::isinstance(arg) || py::isinstance(arg), name, + " must be a list or tuple of wgrad output tensors."); + auto seq = py::reinterpret_borrow(arg); + NVTE_CHECK(static_cast(seq.size()) == num_groups, name, " must have ", num_groups, + " tensors."); + for (size_t i = 0; i < num_groups; ++i) { + auto tensor = seq[i].cast(); + NVTE_CHECK(tensor.is_cuda(), name, " tensors must be CUDA tensors."); + // Do not require tensor.scalar_type() == dtype. Caller-owned + // main_grad buffers are allocated by Megatron and may be FP32 even when TE + // grouped MLP compute is BF16. + NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2 wgrad buffers."); + NVTE_CHECK(tensor.size(0) == rows && tensor.size(1) == cols, name, + " tensors must have shape [rows, cols]."); + check_contiguous(tensor, name); + out.emplace_back(tensor); + } + return out; +} + +struct WgradOutput { + std::vector tensors; + at::Tensor packed; + bool is_grouped = false; + bool owns_storage = false; +}; + +WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num_groups, + at::ScalarType dtype, const c10::Device &device, int64_t rows, + int64_t cols, const std::string &name, + bool prefer_grouped_output) { + WgradOutput out; + if (!compute_wgrad) { + return out; + } + if (is_none(arg)) { + // Cases 1 and 2: no external wgrad buffer was provided, so C++ owns the + // allocation. Single grouped weight keeps this packed as [G, N, K]; + // discrete weights split the same packed allocation into per-expert views. + out.packed = + at::empty({static_cast(num_groups), rows, cols}, at::device(device).dtype(dtype)); + out.owns_storage = true; + out.is_grouped = prefer_grouped_output; + if (out.is_grouped) { + return out; + } + out.tensors.reserve(num_groups); + for (size_t i = 0; i < num_groups; ++i) { + out.tensors.emplace_back(out.packed.select(0, static_cast(i))); + } + return out; + } + if (!py::isinstance(arg) && !py::isinstance(arg)) { + // Case 3: single grouped weight with externally-owned storage, e.g. + // Megatron main_grad viewed as [G, N, K]. GEMM writes in-place and Python + // should not receive a newly allocated grad tensor from this helper. + out.packed = arg.cast(); + NVTE_CHECK(out.packed.is_cuda(), name, " must be a CUDA tensor."); + // Do not require out.packed.scalar_type() == dtype. Caller-owned + // main_grad buffers keep the precision chosen by Megatron's grad-buffer config. + NVTE_CHECK(out.packed.dim() == 3, name, " must have shape [num_groups, rows, cols]."); + NVTE_CHECK(static_cast(out.packed.size(0)) == num_groups, name, + " first dimension must be ", num_groups, "."); + NVTE_CHECK(out.packed.size(1) == rows && out.packed.size(2) == cols, name, + " has an unexpected shape."); + check_contiguous(out.packed, name); + out.is_grouped = true; + return out; + } + // Case 4: discrete weights with externally-owned per-expert buffers, e.g. + // Megatron main_grad list. GEMM writes each tensor in-place and returns no + // allocated grad list to Python. + out.tensors = output_tensor_list_from_arg(arg, num_groups, rows, cols, name); + return out; +} + +void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight, + GroupedTensorWrapper *input, bool trans_input, + GroupedTensorWrapper *output, GroupedGemmResources *resources) { + if (weights->is_grouped) { + // Single grouped weight case: weights are packed as [G, N, K]. Wrap the + // packed buffer as a uniform GroupedTensor and use the grouped-tensor GEMM. + auto grouped_weight = make_uniform_grouped_tensor(weights->packed, input->num_tensors(), + weights->rows, weights->cols); + grouped_gemm(&grouped_weight, trans_weight, input, trans_input, output, resources, false); + } else { + // Discrete weight case: weights are a list of per-expert tensors. Use the + // discrete-input grouped GEMM variant. + std::vector weight_wrappers; + auto weight_nvte = nvte_tensor_list_from_tensors(weights->discrete, &weight_wrappers); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm_with_discrete_inputA( + weight_nvte.data(), weights->discrete.size(), trans_weight, input->data(), trans_input, + output->data(), output->data(), resources->te_alpha.data(), resources->beta(false), + resources->te_setup.data(), resources->te_cublas.data(), resources->config_data(), + at::cuda::getCurrentCUDAStream()); + }); + } +} + +std::vector grouped_gemm_wgrad(GroupedTensorWrapper *x, GroupedTensorWrapper *dy, + py::handle output, bool compute_wgrad, bool accumulate, + GroupedGemmResources *resources, at::ScalarType dtype, + int64_t rows, int64_t cols, const std::string &name, + bool prefer_grouped_output) { + auto prepared = wgrad_output_from_arg(output, compute_wgrad, resources->num_groups, dtype, + resources->device, rows, cols, name, prefer_grouped_output); + NVTE_CHECK(!(prepared.owns_storage && accumulate), name, + " cannot accumulate into a newly allocated wgrad buffer."); + std::vector returned_wgrads; + + if (prepared.is_grouped) { + // Cases 1 and 3: single grouped weight layout. + // Case 1: C++ allocated packed [G, N, K] storage; return [packed]. + // Case 3: caller provided packed storage, e.g. main_grad; write in-place + // and return nothing because autograd receives dummy wgrad tensors. + auto grouped_output = + make_uniform_grouped_tensor(prepared.packed, resources->num_groups, rows, cols); + grouped_gemm(x, false, dy, true, &grouped_output, resources, accumulate); + if (prepared.owns_storage) { + returned_wgrads.emplace_back(prepared.packed); + } + } else if (!prepared.tensors.empty()) { + // Cases 2 and 4: discrete per-expert weight layout. + // Case 2: C++ allocated packed backing storage and split it into views; + // return those views in parameter order. + // Case 4: caller provided per-expert buffers, e.g. main_grad list; write + // in-place and return nothing because autograd receives dummy wgrads. + std::vector output_wrappers; + auto output_nvte = nvte_tensor_list_from_tensors(prepared.tensors, &output_wrappers); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm_with_discrete_out( + x->data(), false, dy->data(), true, output_nvte.data(), resources->num_groups, + output_nvte.data(), resources->num_groups, resources->te_alpha.data(), + resources->beta(accumulate), resources->te_setup.data(), resources->te_cublas.data(), + resources->config_data(), at::cuda::getCurrentCUDAStream()); + }); + if (prepared.owns_storage) { + returned_wgrads = prepared.tensors; + } + } + return returned_wgrads; +} + +GroupedTensorWrapper make_grouped_bias(const at::Tensor &bias, size_t num_groups, + at::ScalarType bias_dtype, int64_t out_features) { + NVTE_CHECK(bias.defined(), "Bias tensor must be defined."); + auto grouped = GroupedTensorWrapper( + num_groups, std::vector{num_groups, static_cast(out_features)}); + grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(bias_dtype), + tensor_shape_1d(bias)); + return grouped; +} + +void add_grouped_bias(GroupedTensorWrapper *output, const at::Tensor &bias, size_t num_groups, + at::ScalarType dtype, int64_t out_features, + std::optional bias_scale = std::nullopt) { + if (!bias.defined()) { + return; + } + auto grouped_bias = make_grouped_bias(bias, num_groups, dtype, out_features); + if (bias_scale.has_value()) { + auto scale = maybe_cast_dtype(*bias_scale, at::kFloat); + check_contiguous(scale, "bias_scale"); + scale = scale.view({-1}); + auto te_scale = makeTransformerEngineTensor(scale); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_scaled_bias_add(output->data(), grouped_bias.data(), te_scale.data(), + at::cuda::getCurrentCUDAStream()); + }); + } else { + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_bias_add(output->data(), grouped_bias.data(), at::cuda::getCurrentCUDAStream()); + }); + } +} + +bool is_gated_activation(const std::string &activation) { + return activation == "swiglu" || activation == "clamped_swiglu" || activation == "geglu" || + activation == "reglu" || activation == "qgeglu" || activation == "sreglu"; +} + +at::Tensor maybe_deinterleave_glu(const at::Tensor &input, int64_t glu_interleave_size) { + if (glu_interleave_size <= 0) { + return input; + } + auto shape = input.sizes().vec(); + const int64_t last_dim = shape.back(); + NVTE_CHECK(last_dim % (2 * glu_interleave_size) == 0, + "GLU interleaving requires the last dimension to be divisible by 2*interleave."); + check_contiguous(input, "GLU input"); + // Explicit layout materialization: GLU interleave changes memory order. + return input.view({-1, last_dim / (2 * glu_interleave_size), 2, glu_interleave_size}) + .transpose(1, 2) + .contiguous() + .view(shape); +} + +at::Tensor maybe_reinterleave_glu_grad(const at::Tensor &input, int64_t glu_interleave_size) { + if (glu_interleave_size <= 0) { + return input; + } + auto shape = input.sizes().vec(); + const int64_t last_dim = shape.back(); + check_contiguous(input, "GLU grad input"); + // Explicit layout materialization: reverse GLU interleave changes memory order. + return input.view({-1, 2, last_dim / (2 * glu_interleave_size), glu_interleave_size}) + .transpose(1, 2) + .contiguous() + .view(shape); +} + +at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &activation, + double activation_limit, double activation_alpha, + double activation_glu_linear_offset) { + const int64_t out_features = + is_gated_activation(activation) ? input.size(-1) / 2 : input.size(-1); + auto output = at::empty({input.size(0), out_features}, input.options()); + auto te_input = makeTransformerEngineTensor(input); + auto te_output = makeTransformerEngineTensor(output); + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + if (activation == "swiglu") { + nvte_swiglu(te_input.data(), te_output.data(), stream); + } else if (activation == "glu") { + nvte_glu(te_input.data(), te_output.data(), stream); + } else if (activation == "geglu") { + nvte_geglu(te_input.data(), te_output.data(), stream); + } else if (activation == "qgeglu") { + nvte_qgeglu(te_input.data(), te_output.data(), stream); + } else if (activation == "reglu") { + nvte_reglu(te_input.data(), te_output.data(), stream); + } else if (activation == "sreglu") { + nvte_sreglu(te_input.data(), te_output.data(), stream); + } else if (activation == "clamped_swiglu") { + nvte_clamped_swiglu_v2(te_input.data(), te_output.data(), + static_cast(activation_limit), + static_cast(activation_alpha), + static_cast(activation_glu_linear_offset), stream); + } else if (activation == "srelu") { + nvte_srelu(te_input.data(), te_output.data(), stream); + } else if (activation == "gelu") { + nvte_gelu(te_input.data(), te_output.data(), stream); + } else if (activation == "qgelu") { + nvte_qgelu(te_input.data(), te_output.data(), stream); + } else if (activation == "relu") { + nvte_relu(te_input.data(), te_output.data(), stream); + } else if (activation == "silu") { + nvte_silu(te_input.data(), te_output.data(), stream); + } else { + NVTE_ERROR("Unsupported megacpp grouped MLP activation: ", activation); + } + }); + return output; +} + +at::Tensor activation_backward_impl(const at::Tensor &grad, const at::Tensor &input, + const std::string &activation, double activation_limit, + double activation_alpha, double activation_glu_linear_offset) { + auto output = at::empty_like(input); + auto te_grad = makeTransformerEngineTensor(grad); + auto te_input = makeTransformerEngineTensor(input); + auto te_output = makeTransformerEngineTensor(output); + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + if (activation == "swiglu") { + nvte_dswiglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "glu") { + nvte_dglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "geglu") { + nvte_dgeglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "qgeglu") { + nvte_dqgeglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "reglu") { + nvte_dreglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "sreglu") { + nvte_dsreglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "clamped_swiglu") { + nvte_clamped_dswiglu_v2(te_grad.data(), te_input.data(), te_output.data(), + static_cast(activation_limit), + static_cast(activation_alpha), + static_cast(activation_glu_linear_offset), stream); + } else if (activation == "srelu") { + nvte_dsrelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "gelu") { + nvte_dgelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "qgelu") { + nvte_dqgelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "relu") { + nvte_drelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "silu") { + nvte_dsilu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else { + NVTE_ERROR("Unsupported megacpp grouped MLP activation backward: ", activation); + } + }); + return output; +} + +at::Tensor grouped_mlp_activation_forward( + const at::Tensor &input, const std::optional &act_scales, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, at::ScalarType dtype) { + auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); + auto activation_output = activation_forward_impl(activation_input, activation, activation_limit, + activation_alpha, activation_glu_linear_offset); + if (!act_scales.has_value()) { + return activation_output; + } + auto act_scales_for_fc2 = maybe_cast_dtype(*act_scales, dtype); + check_contiguous(act_scales_for_fc2, "act_scales"); + return activation_output * act_scales_for_fc2.view({-1, 1}); +} + +struct ActivationBackwardResult { + at::Tensor grad_input; + at::Tensor grad_act_scales; +}; + +ActivationBackwardResult grouped_mlp_activation_backward( + const at::Tensor &grad_output, const at::Tensor &input, + const std::optional &act_scales, const std::string &activation, + int64_t glu_interleave_size, double activation_limit, double activation_alpha, + double activation_glu_linear_offset, at::ScalarType dtype, bool act_scales_requires_grad) { + auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); + + at::Tensor grad_activation_output = grad_output; + at::Tensor grad_act_scales; + if (act_scales.has_value()) { + if (act_scales_requires_grad) { + // Scaled activations compute y = activation(x) * act_scales[:, None]. + // Recompute activation(x) for dact_scales to match the Python basic-op + // path without saving another [tokens, hidden] activation tensor. + auto activation_output = + activation_forward_impl(activation_input, activation, activation_limit, activation_alpha, + activation_glu_linear_offset); + grad_act_scales = (activation_output * grad_output).sum(-1); + } + auto act_scales_for_grad = maybe_cast_dtype(*act_scales, dtype); + check_contiguous(act_scales_for_grad, "act_scales"); + grad_activation_output = grad_output * act_scales_for_grad.view({-1, 1}); + } + + auto grad_activation_input = + activation_backward_impl(grad_activation_output, activation_input, activation, + activation_limit, activation_alpha, activation_glu_linear_offset); + return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size), grad_act_scales}; +} + +} // namespace + +std::vector megacpp_grouped_mlp_forward( + const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes, + py::handle fc1_weight, py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, + const std::optional &act_scales, const std::string &activation, + int64_t glu_interleave_size, double activation_limit, double activation_alpha, + double activation_glu_linear_offset, py::handle gemm_scratch) { + NVTE_CHECK(input.is_cuda(), "megacpp_grouped_mlp_forward requires CUDA input."); + at::cuda::CUDAGuard device_guard(input.device()); + + // act_dtype is the requested activation/GEMM input dtype. The incoming + // tensor may have a different dtype, so canonicalize it once at the API + // boundary and use this tensor for all downstream grouped GEMMs. + const auto dtype = act_dtype; + auto x = maybe_cast_dtype(input, dtype); + check_contiguous(x, "input"); + + const auto num_groups = static_cast(split_sizes.numel()); + NVTE_CHECK(num_groups > 0, "megacpp grouped MLP requires at least one group."); + + NVTE_CHECK(dtype == at::kBFloat16 || dtype == at::kHalf, + "megacpp grouped MLP currently supports BF16/FP16 only."); + + auto fc1_weights = weight_arg_from_py(fc1_weight, num_groups, dtype, "fc1_weight"); + auto fc2_weights = weight_arg_from_py(fc2_weight, num_groups, dtype, "fc2_weight"); + const int64_t in_features = fc1_weights.cols; + const int64_t fc1_out_features = fc1_weights.rows; + const int64_t fc2_out_features = fc2_weights.rows; + const int64_t fc2_in_features = fc2_weights.cols; + const int64_t activation_out_features = + is_gated_activation(activation) ? fc1_out_features / 2 : fc1_out_features; + NVTE_CHECK(activation_out_features == fc2_in_features, + "FC1 activation output dimension must match FC2 input dimension."); + auto fc1_bias_tensor = + packed_bias_from_arg(fc1_bias, num_groups, dtype, fc1_out_features, "fc1_bias"); + auto fc2_bias_tensor = + packed_bias_from_arg(fc2_bias, num_groups, dtype, fc2_out_features, "fc2_bias"); + + NVTE_CHECK(x.numel() % in_features == 0, "input last dimension is incompatible with FC1."); + const int64_t total_tokens = x.numel() / in_features; + auto [split_sizes_i64, split_offsets] = splits_to_offsets_multi( + split_sizes, x.device(), + std::vector{1, in_features, fc1_out_features, fc2_in_features, fc2_out_features}, + std::vector{true, true, true, true, true}, + std::vector{at::kLong, at::kLong, at::kLong, at::kLong, at::kLong}, true); + // splits_to_offsets_multi returns the canonical int64 CUDA split sizes and + // offsets in the same order as the stride list above. The CuTe path also asks + // for int32 split_points, but cuBLAS grouped GEMM does not consume them. + NVTE_CHECK(split_offsets.size() == 5, "Expected five grouped split-offset tensors."); + auto base_offsets = split_offsets[0]; + auto x_offsets = split_offsets[1]; + auto fc1_offsets = split_offsets[2]; + auto fc2_offsets = split_offsets[3]; + auto output_offsets = split_offsets[4]; + auto gemm_resources = make_grouped_mlp_backend_resources(x.device(), num_groups, gemm_scratch); + + auto fc1_preact = at::empty({total_tokens, fc1_out_features}, x.options()); + auto grouped_x = make_grouped_tensor(x, split_sizes_i64, x_offsets, in_features); + auto grouped_fc1_preact = + make_grouped_tensor(fc1_preact, split_sizes_i64, fc1_offsets, fc1_out_features); + grouped_gemm_fwd_dgrad(&fc1_weights, true, &grouped_x, false, &grouped_fc1_preact, + &gemm_resources); + add_grouped_bias(&grouped_fc1_preact, fc1_bias_tensor, num_groups, dtype, fc1_out_features); + + auto fc2_x = grouped_mlp_activation_forward( + fc1_preact, act_scales, activation, glu_interleave_size, activation_limit, activation_alpha, + activation_glu_linear_offset, dtype); + + std::vector out_shape = input.sizes().vec(); + out_shape.back() = fc2_out_features; + auto output = at::empty(out_shape, x.options()); + auto grouped_fc2_x = make_grouped_tensor(fc2_x, split_sizes_i64, fc2_offsets, fc2_in_features); + auto grouped_output = + make_grouped_tensor(output, split_sizes_i64, output_offsets, fc2_out_features); + grouped_gemm_fwd_dgrad(&fc2_weights, true, &grouped_fc2_x, false, &grouped_output, + &gemm_resources); + add_grouped_bias(&grouped_output, fc2_bias_tensor, num_groups, dtype, fc2_out_features); + + return {output, x, split_sizes_i64, base_offsets, x_offsets, + fc1_offsets, fc2_offsets, output_offsets, fc1_preact, fc2_x}; +} + +py::tuple megacpp_grouped_mlp_backward( + const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes, + const at::Tensor &x_offsets, const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, + const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, const at::Tensor &x, + const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, + const std::optional &act_scales, py::handle fc1_weight, py::handle fc2_weight, + py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, + py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, + bool input_requires_grad, py::handle gemm_scratch) { + (void)base_offsets; + NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output."); + at::cuda::CUDAGuard device_guard(grad_output.device()); + + // act_dtype is the requested grouped-MLP compute dtype. Backward receives + // autograd's grad_output as-is, so canonicalize it here instead of requiring + // a Python-side aten::to before entering C++. + const auto dtype = act_dtype; + auto dy = maybe_cast_dtype(grad_output, dtype); + check_contiguous(dy, "grad_output"); + + const auto num_groups = num_groups_from_prepared_split_sizes(split_sizes, grad_output.device()); + auto fc1_weights = weight_arg_from_py(fc1_weight, num_groups, dtype, "fc1_weight"); + auto fc2_weights = weight_arg_from_py(fc2_weight, num_groups, dtype, "fc2_weight"); + + const int64_t in_features = fc1_weights.cols; + const int64_t fc1_out_features = fc1_weights.rows; + const int64_t fc2_out_features = fc2_weights.rows; + const int64_t fc2_in_features = fc2_weights.cols; + + NVTE_CHECK(dy.numel() % fc2_out_features == 0, + "grad_output last dimension is incompatible with FC2."); + const int64_t total_tokens = dy.numel() / fc2_out_features; + auto gemm_resources = + make_grouped_mlp_backend_resources(grad_output.device(), num_groups, gemm_scratch); + + auto grouped_dy = make_grouped_tensor(dy, split_sizes, fc2_dy_offsets, fc2_out_features); + std::vector fc2_wgrads; + if (fc2_compute_wgrad) { + auto fc2_x_for_wgrad = maybe_cast_dtype(fc2_x, dtype); + check_contiguous(fc2_x_for_wgrad, "fc2_x"); + auto grouped_fc2_x_for_wgrad = + make_grouped_tensor(fc2_x_for_wgrad, split_sizes, fc2_offsets, fc2_in_features); + fc2_wgrads = grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output, + fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype, + fc2_out_features, fc2_in_features, "fc2_wgrad_output", + fc2_weights.is_grouped); + } + + auto fc2_dx = at::empty({total_tokens, fc2_in_features}, dy.options()); + auto grouped_fc2_dx = make_grouped_tensor(fc2_dx, split_sizes, fc2_offsets, fc2_in_features); + grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, &gemm_resources); + + auto activation_grads = grouped_mlp_activation_backward( + fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, activation_limit, + activation_alpha, activation_glu_linear_offset, dtype, act_scales_requires_grad); + auto fc1_dy = activation_grads.grad_input; + auto grad_act_scales = activation_grads.grad_act_scales; + auto grouped_fc1_dy = make_grouped_tensor(fc1_dy, split_sizes, fc1_offsets, fc1_out_features); + + std::vector fc1_wgrads; + if (fc1_compute_wgrad) { + auto x_for_wgrad = maybe_cast_dtype(x, dtype); + check_contiguous(x_for_wgrad, "x"); + auto grouped_x_for_wgrad = + make_grouped_tensor(x_for_wgrad, split_sizes, x_offsets, in_features); + fc1_wgrads = grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output, + fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype, + fc1_out_features, in_features, "fc1_wgrad_output", + fc1_weights.is_grouped); + } + + at::Tensor grad_input; + if (input_requires_grad) { + std::vector grad_input_shape = grad_output.sizes().vec(); + grad_input_shape.back() = in_features; + grad_input = at::empty(grad_input_shape, dy.options()); + auto grouped_grad_input = make_grouped_tensor(grad_input, split_sizes, x_offsets, in_features); + grouped_gemm_fwd_dgrad(&fc1_weights, false, &grouped_fc1_dy, false, &grouped_grad_input, + &gemm_resources); + } else { + grad_input = at::empty({0}, dy.options()); + } + + auto empty_return = at::empty({0}, dy.options()); + if (!grad_act_scales.defined()) { + grad_act_scales = empty_return; + } + return py::make_tuple(grad_input, fc1_dy, grad_act_scales, fc1_wgrads, fc2_wgrads); +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 78f9d880ba..fd09162ade 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -39,3 +39,9 @@ BackwardGroupedMLP_CuTeGEMMDGLU, BackwardGroupedMLP_CuTeGEMMDUnary, ) +from .forward_grouped_mlp_megacpp import ( # pylint: disable=wrong-import-position + ForwardGroupedMLP_MegaCpp, +) +from .backward_grouped_mlp_megacpp import ( # pylint: disable=wrong-import-position + BackwardGroupedMLP_MegaCpp, +) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py new file mode 100644 index 0000000000..3899b5c6ee --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py @@ -0,0 +1,394 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mega C++ grouped MLP backward fuser.""" + +from __future__ import annotations +import functools +from typing import Optional + +import torch + +import transformer_engine_torch as tex +from ...quantization import Recipe +from ...utils import clear_tensor_data, get_device_compute_capability +from ...triton.grouped_dbias_dscales import compute_grouped_dbias +from ..basic import GroupedLinear +from ..fuser import register_backward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import ( + get_accumulate_flag_in_param, + get_dummy_wgrads_for_params, + get_main_grad_from_param, + view_main_grad_as_grouped_buffer, +) +from .forward_grouped_mlp_megacpp import ( + _grouped_gemm_scratch, + _megacpp_activation_config, + _megacpp_enabled, + _megacpp_supports_recipe, + _resolve_megacpp_grouped_mlp_config, +) + + +def _megacpp_saved_weight_arg( + saved_tensors: tuple[torch.Tensor, ...], + *, + single_weight_arg: bool, + num_groups: int, +) -> tuple[torch.Tensor | list[torch.Tensor], tuple[torch.Tensor, ...]]: + """Unpack saved C++ weight argument in the same shape used by forward.""" + if single_weight_arg: + return saved_tensors[0], saved_tensors[1:] + return list(saved_tensors[:num_groups]), saved_tensors[num_groups:] + + +def _delay_wgrad(fc_op: GroupedLinear, ctx: OperationContext) -> bool: + """Whether this FC op requested unsupported delayed wgrad.""" + return bool( + ctx.weight_requires_grad + and fc_op.wgrad_store is not None + and fc_op.wgrad_store.delay_wgrad_compute() + ) + + +def _compute_bias_grad_params( + fc_op: GroupedLinear, + dy_2d: torch.Tensor, + base_offsets: torch.Tensor, + *, + num_groups: int, + dtype: torch.dtype, +) -> tuple[Optional[list[torch.Tensor]], Optional[torch.Tensor]]: + """Compute bias grads in GroupedLinear parameter layout.""" + if not fc_op.has_bias: + return None, None + dbias_packed = compute_grouped_dbias(dy_2d, base_offsets, num_groups).to(dtype=dtype) + if fc_op.single_grouped_bias: + return None, dbias_packed + return [dbias_packed[idx] for idx in range(num_groups)], None + + +def _prepare_cpp_wgrad_output( + fc_op: GroupedLinear, + ctx: OperationContext, + *, + num_groups: int, + weight_shape: tuple[int, int], + label: str, +) -> tuple[Optional[torch.Tensor | list[torch.Tensor]], bool, bool, list[Optional[torch.Tensor]]]: + """Return an optional externally-owned wgrad buffer for C++. + + If Megatron has already installed ``main_grad`` buffers, C++ writes into + them. Otherwise this returns ``None`` and C++ allocates/returns a packed + ``[num_groups, out_features, in_features]`` wgrad tensor. + """ + weights = fc_op._get_weight_tensors() + weight_grads: list[Optional[torch.Tensor]] = ( + [None] if fc_op.single_grouped_weight else [None] * num_groups + ) + if _delay_wgrad(fc_op, ctx): + raise ValueError("megacpp grouped MLP does not support delay_wgrad_compute=True.") + if not ctx.weight_requires_grad: + return None, False, False, weight_grads + + accumulate_into_main_grad = False + if fc_op.single_grouped_weight: + if fc_op._accumulate_into_main_grad: + main_grad = get_main_grad_from_param(weights[0], op_label=label) + wgrad_output = view_main_grad_as_grouped_buffer( + main_grad, + num_groups, + weight_shape, + label=f"{label} weight", + ) + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + weight_grads = get_dummy_wgrads_for_params(weights) + else: + wgrad_output = None + else: + if fc_op._accumulate_into_main_grad: + wgrad_output = [get_main_grad_from_param(w, op_label=label) for w in weights] + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + weight_grads = get_dummy_wgrads_for_params(weights) + else: + wgrad_output = None + + return wgrad_output, True, accumulate_into_main_grad, weight_grads + + +def _assemble_grad_params( + fc_op: GroupedLinear, + weight_grads: list[Optional[torch.Tensor]], + bias_grads: Optional[list[torch.Tensor]], + bias_grad_packed: Optional[torch.Tensor], + *, + num_groups: int, +) -> list[Optional[torch.Tensor]]: + """Assemble parameter grads in GroupedLinear registration order.""" + if not fc_op.has_bias: + return weight_grads + if fc_op.single_grouped_bias: + return weight_grads + [bias_grad_packed] + bias_list = bias_grads if bias_grads is not None else [None] * num_groups + if fc_op.single_grouped_weight: + return bias_list + weight_grads + return weight_grads + bias_list + + +class BackwardGroupedMLP_MegaCpp(FusedOperation): + """Experimental C++ grouped MLP backward for BF16/FP16. + + Weight gradients are computed in C++. Delayed wgrad is intentionally not + supported in this first implementation to keep ownership and lifetime rules + simple. + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + if not torch.cuda.is_available(): + return False + if get_device_compute_capability()[0] < 10: + return False + return hasattr(tex, "megacpp_grouped_mlp_backward") + + def __init__( + self, + *, + fc1: GroupedLinear, + activation: Optional[FusibleOperation], + fc2: GroupedLinear, + ) -> None: + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) + _resolve_megacpp_grouped_mlp_config(fc1, activation, fc2) + if fc1._scale_bias or fc2._scale_bias: + raise RuntimeError("megacpp grouped MLP does not support scale_bias yet.") + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + **unused, # pylint: disable=unused-argument + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + fc1_op, activation_op, fc2_op = self.basic_ops + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs + num_groups = fc1_op.num_groups + dtype = fc1_ctx.dtype + + fc1_saved = fc1_ctx.saved_tensors + split_sizes, base_offsets, x_offsets, fc1_offsets = fc1_saved[:4] + x, fc1_activation_input = fc1_saved[4:6] + fc1_weight_arg, _ = _megacpp_saved_weight_arg( + fc1_saved[6:], + single_weight_arg=bool(getattr(fc1_ctx, "single_weight_arg", False)), + num_groups=num_groups, + ) + + activation_config = _megacpp_activation_config(activation_op) + _, act_scales = activation_ctx.saved_tensors + + fc2_saved = fc2_ctx.saved_tensors + fc2_offsets = fc2_saved[2] + fc2_dy_offsets = fc2_saved[3] + fc2_x = fc2_saved[4] + fc2_weight_arg, _ = _megacpp_saved_weight_arg( + fc2_saved[5:], + single_weight_arg=bool(getattr(fc2_ctx, "single_weight_arg", False)), + num_groups=num_groups, + ) + + ( + fc1_wgrad_output, + fc1_compute_wgrad, + fc1_accumulate_wgrad, + fc1_weight_grads, + ) = _prepare_cpp_wgrad_output( + fc1_op, + fc1_ctx, + num_groups=num_groups, + weight_shape=(fc1_op.out_features, fc1_op.in_features), + label="Grouped MLP megacpp backward (FC1)", + ) + ( + fc2_wgrad_output, + fc2_compute_wgrad, + fc2_accumulate_wgrad, + fc2_weight_grads, + ) = _prepare_cpp_wgrad_output( + fc2_op, + fc2_ctx, + num_groups=num_groups, + weight_shape=(fc2_op.out_features, fc2_op.in_features), + label="Grouped MLP megacpp backward (FC2)", + ) + ( + grad_input, + fc1_dy, + grad_act_scales, + fc1_owned_weight_grads, + fc2_owned_weight_grads, + ) = tex.megacpp_grouped_mlp_backward( + grad_output, + dtype, + split_sizes, + x_offsets, + fc1_offsets, + fc2_offsets, + fc2_dy_offsets, + base_offsets, + x, + fc1_activation_input, + fc2_x, + act_scales, + fc1_weight_arg, + fc2_weight_arg, + fc1_wgrad_output, + fc1_compute_wgrad, + fc1_accumulate_wgrad, + fc2_wgrad_output, + fc2_compute_wgrad, + fc2_accumulate_wgrad, + activation_config.name, + activation_config.glu_interleave_size, + activation_config.limit, + activation_config.alpha, + activation_config.glu_linear_offset, + bool(activation_ctx.extra_input_requires_grad), + bool(fc1_ctx.input_requires_grad), + _grouped_gemm_scratch(num_groups, grad_output.device), + ) + if not fc1_ctx.input_requires_grad: + grad_input = None + + grad_output_2d = grad_output.reshape(-1, fc2_op.out_features).to(dtype=dtype) + fc2_bias_grads, fc2_bias_grad_packed = _compute_bias_grad_params( + fc2_op, + grad_output_2d, + base_offsets, + num_groups=num_groups, + dtype=dtype, + ) + fc1_bias_grads, fc1_bias_grad_packed = _compute_bias_grad_params( + fc1_op, + fc1_dy, + base_offsets, + num_groups=num_groups, + dtype=dtype, + ) + + # Wgrad ownership cases: + # 1. No weight grad: keep [None] placeholders prepared above. + # 2. Megatron-owned main_grad: C++ wrote into the provided buffer; + # keep dummy wgrads prepared above for autograd. + # 3. C++-owned allocation: replace the placeholder list with returned + # wgrads. Single grouped weight returns [packed], discrete weights + # return one tensor per expert. + if fc2_ctx.weight_requires_grad and not fc2_op._accumulate_into_main_grad: + expected_wgrads = 1 if fc2_op.single_grouped_weight else num_groups + if len(fc2_owned_weight_grads) != expected_wgrads: + raise RuntimeError(f"FC2 expected {expected_wgrads} owned wgrad tensors.") + fc2_weight_grads = fc2_owned_weight_grads + fc2_grad_params = _assemble_grad_params( + fc2_op, + fc2_weight_grads, + fc2_bias_grads, + fc2_bias_grad_packed, + num_groups=num_groups, + ) + clear_tensor_data(fc2_x) + + # Same ownership policy as FC2. Megatron-owned main_grad keeps the + # prepared dummy grads; C++-owned allocation uses the returned wgrads. + if fc1_ctx.weight_requires_grad and not fc1_op._accumulate_into_main_grad: + expected_wgrads = 1 if fc1_op.single_grouped_weight else num_groups + if len(fc1_owned_weight_grads) != expected_wgrads: + raise RuntimeError(f"FC1 expected {expected_wgrads} owned wgrad tensors.") + fc1_weight_grads = fc1_owned_weight_grads + fc1_grad_params = _assemble_grad_params( + fc1_op, + fc1_weight_grads, + fc1_bias_grads, + fc1_bias_grad_packed, + num_groups=num_groups, + ) + clear_tensor_data(x) + + # d(act_scales) belongs to the extra input, so match act_scales.dtype + activation_grad_extra = ( + (grad_act_scales.to(dtype=act_scales.dtype),) + if activation_ctx.extra_input_requires_grad + else (None,) + ) + + return ( + grad_input, + [fc1_grad_params, (), fc2_grad_params], + [ + (None,), + activation_grad_extra, + (None,), + ], + ) + + +def fuse_backward_megacpp_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply opt-in C++ grouped MLP backward fusion for BF16/FP16.""" + if not _megacpp_enabled(): + return ops + if not _megacpp_supports_recipe(recipe): + return ops + if not BackwardGroupedMLP_MegaCpp.is_supported(): + return ops + + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + matches_pattern = True + if not (isinstance(window[0], GroupedLinear) and isinstance(window[2], GroupedLinear)): + matches_pattern = False + elif window[0]._scale_bias or window[2]._scale_bias: + matches_pattern = False + else: + try: + _resolve_megacpp_grouped_mlp_config(window[0], window[1], window[2]) + except (TypeError, ValueError, RuntimeError): + matches_pattern = False + + if matches_pattern: + window = [ + BackwardGroupedMLP_MegaCpp( + fc1=window[0], + activation=window[1], + fc2=window[2], + ) + ] + else: + out.extend(window[:-2]) + window = window[-2:] + + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + out.extend(window) + return out + + +# Use the same opt-in and recipe gate as forward. Unsupported recipes fall +# through unchanged so the matching recipe-specific backward fuser can run. +register_backward_fusion(fuse_backward_megacpp_ops, prepend=True) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py new file mode 100644 index 0000000000..0b086aa56b --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py @@ -0,0 +1,421 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mega C++ grouped MLP forward fuser.""" + +from __future__ import annotations +from collections.abc import Iterable +import functools +import os +from typing import Any, NamedTuple, Optional + +import torch + +import transformer_engine_torch as tex +from ...cpp_extensions.gemm import ( + get_cublas_workspace_size_bytes, + get_grouped_gemm_setup_workspace_size, +) +from ...quantization import Recipe +from ...tensor import Quantizer +from ...utils import get_device_compute_capability +from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU, ScaledSwiGLU +from ..fuser import register_forward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext + + +def _megacpp_enabled() -> bool: + """Whether the experimental grouped MLP C++ path is explicitly enabled.""" + return int(os.getenv("NVTE_MEGACPP_GROUPED_LINEAR", "0")) > 0 + + +def _megacpp_supports_recipe(recipe: Optional[Recipe]) -> bool: + """Whether megacpp is a valid candidate for the active quantization recipe. + + Today the C++ implementation is BF16/FP16-only, so only the no-recipe path + is supported. Returning False for FP8 recipes is intentional: it leaves the + op list unchanged so the existing MXFP8/NVFP4 CuTe DSL fusers can match. + Future MXFP8/NVFP4 support should be enabled by changing this predicate, + not by reordering fusion registrations. + """ + if recipe is None: + return True + if recipe.mxfp8() or recipe.nvfp4(): + return False + return False + + +@functools.lru_cache(maxsize=None) +def _cached_grouped_gemm_scratch( + num_groups: int, + device_index: int, + _stream_handle: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Cached cuBLAS grouped GEMM scratch for one CUDA stream. + + ``_stream_handle`` is intentionally part of the cache key. The workspace is + reused without recording extra streams, so it must not be shared by + concurrent streams. + """ + device = torch.device("cuda", device_index) + with torch.cuda.device(device): + setup_size = get_grouped_gemm_setup_workspace_size(num_groups) + cublas_size = get_cublas_workspace_size_bytes() + return ( + torch.ones(num_groups, dtype=torch.float32, device=device), + torch.zeros(num_groups, dtype=torch.float32, device=device), + torch.empty(setup_size, dtype=torch.uint8, device=device), + torch.empty(cublas_size, dtype=torch.uint8, device=device), + ) + + +def _grouped_gemm_scratch( + num_groups: int, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Return cached GEMM resources for the current stream on ``device``.""" + device_index = torch.cuda.current_device() if device.index is None else device.index + stream_handle = int(torch.cuda.current_stream(device_index).cuda_stream) + return _cached_grouped_gemm_scratch(num_groups, device_index, stream_handle) + + +class _MegaCppActivationConfig(NamedTuple): + """Activation semantics consumed by the C++ grouped MLP path.""" + + name: str + is_scaled: bool + is_gated: bool + glu_interleave_size: int + limit: float = 0.0 + alpha: float = 0.0 + glu_linear_offset: float = 0.0 + + +def _megacpp_activation_config(activation) -> _MegaCppActivationConfig: + """Return activation parameters consumed by the C++ grouped MLP path.""" + glu_interleave_size = int(getattr(activation, "glu_interleave_size", None) or 0) + if isinstance(activation, ScaledSwiGLU): + return _MegaCppActivationConfig("swiglu", True, True, glu_interleave_size) + if isinstance(activation, ScaledClampedQGeGLU): + return _MegaCppActivationConfig( + "clamped_swiglu", + True, + True, + glu_interleave_size, + float(activation._clamped.limit), + float(activation._clamped.alpha), + float(activation._clamped.glu_linear_offset), + ) + if isinstance(activation, ScaledSReLU): + return _MegaCppActivationConfig("srelu", True, False, 0) + if getattr(activation, "num_extra_inputs", 0) == 0: + return _MegaCppActivationConfig("plain_unsupported", False, False, 0) + raise TypeError( + "megacpp grouped MLP currently supports only ScaledSwiGLU, " + "ScaledClampedQGeGLU, and ScaledSReLU." + ) + + +def _resolve_megacpp_grouped_mlp_config( + fc1: GroupedLinear, + activation, + fc2: GroupedLinear, +) -> _MegaCppActivationConfig: + """Resolve megacpp activation config and validate grouped MLP support.""" + config = _megacpp_activation_config(activation) + if not config.is_scaled: + raise RuntimeError( + "megacpp grouped MLP keeps an optional-scale activation API, but plain " + f"{activation.__class__.__name__} is not supported yet." + ) + if fc1.in_features % 64 != 0 or fc1.out_features % 64 != 0: + raise ValueError( + f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, " + f"in_features={fc1.in_features}, out_features={fc1.out_features})." + ) + if fc2.in_features % 64 != 0 or fc2.out_features % 64 != 0: + raise ValueError( + f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " + f"in_features={fc2.in_features}, out_features={fc2.out_features})." + ) + expected_fc1_out_features = 2 * fc2.in_features if config.is_gated else fc2.in_features + if fc1.out_features != expected_fc1_out_features or fc1.num_groups != fc2.num_groups: + raise ValueError( + f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " + f"out_features={fc1.out_features}) " + f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " + f"out_features={fc2.out_features}) do not match." + ) + if config.glu_interleave_size and fc1.out_features % (2 * config.glu_interleave_size) != 0: + raise ValueError( + "GLU interleaving requires FC1 out_features to be divisible by " + f"2*glu_interleave_size, got out_features={fc1.out_features}, " + f"glu_interleave_size={config.glu_interleave_size}." + ) + return config + + +def _megacpp_weight_arg( + linear_op: GroupedLinear, + dtype: torch.dtype, + *, + input_requires_grad: bool, +) -> torch.Tensor | list[torch.Tensor]: + """Return GEMM-ready high-precision weights for the current C++ path. + + Keep the layout policy in GroupedLinear. This handles quantized weights the + same way as the Python grouped GEMM path: BF16/FP16 compute dequantizes when + needed, while a future quantized-compute path can preserve quantized weights + by switching ``with_quantized_compute``. + """ + with_quantized_compute = False + if linear_op.single_grouped_weight: + grouped_weight = linear_op._get_grouped_weight_for_gemm( + linear_op.weight, + [linear_op.get_quantizer("forward", 1)], + columnwise_usage=input_requires_grad, + with_quantized_compute=with_quantized_compute, + dtype=dtype, + ) + if grouped_weight.rowwise_data is None: + raise RuntimeError("megacpp grouped MLP expected dense grouped weight rowwise_data.") + # Keep single grouped weight packed. The C++ path wraps this as a + # uniform GroupedTensor and dispatches nvte_grouped_gemm instead of + # expanding it into per-expert discrete tensors. + return grouped_weight.rowwise_data.view( + linear_op.num_groups, + linear_op.out_features, + linear_op.in_features, + ) + return linear_op._get_discrete_weights_for_gemm( + [getattr(linear_op, f"weight{idx}") for idx in range(linear_op.num_groups)], + [linear_op.get_quantizer("forward", 2 * idx + 1) for idx in range(linear_op.num_groups)], + columnwise_usage=input_requires_grad, + with_quantized_compute=with_quantized_compute, + dtype=dtype, + ) + + +def _megacpp_bias_arg(linear_op: GroupedLinear, dtype: torch.dtype) -> Optional[torch.Tensor]: + """Return a packed [G, N] high-precision bias tensor or None.""" + grouped_bias = linear_op._get_grouped_bias_for_gemm(dtype) + if grouped_bias is None: + return None + return grouped_bias.rowwise_data.view(linear_op.num_groups, linear_op.out_features) + + +class ForwardGroupedMLP_MegaCpp(FusedOperation): + """Experimental BF16/FP16 grouped MLP forward implemented in C++. + + The C++ function returns plain tensors only. Python still owns autograd + context layout; delayed wgrad is rejected by the matching backward op. + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether the C++ grouped MLP path can be dispatched.""" + if not torch.cuda.is_available(): + return False + if get_device_compute_capability()[0] < 10: + return False + return hasattr(tex, "megacpp_grouped_mlp_forward") + + def __init__( + self, + *, + fc1: GroupedLinear, + activation: Optional[FusibleOperation], + fc2: GroupedLinear, + ) -> None: + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) + _resolve_megacpp_grouped_mlp_config(fc1, activation, fc2) + if fc1._scale_bias or fc2._scale_bias: + raise RuntimeError("megacpp grouped MLP does not support scale_bias yet.") + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + del prev_op_grad_output_quantizer, next_op_input_quantizer, basic_op_kwargs + fc1_op, activation_op, fc2_op = self.basic_ops + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs + num_groups = fc1_op.num_groups + + split_sizes = basic_op_extra_inputs[0][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] + if ( + split_sizes.size() != fc2_split_sizes.size() + or split_sizes.data_ptr() != fc2_split_sizes.data_ptr() + ): + raise RuntimeError(f"{self.__class__.__name__} got different split sizes for FC1/FC2.") + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, got {int(split_sizes.numel())}.") + + activation_config = _megacpp_activation_config(activation_op) + act_scales = basic_op_extra_inputs[1][0] + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 + fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 + dtype = ( + torch.get_autocast_dtype("cuda") + if torch.is_autocast_enabled() + else fc1_weight_param.dtype + ) + if dtype not in (torch.bfloat16, torch.float16): + raise RuntimeError(f"megacpp grouped MLP supports BF16/FP16 only, got {dtype}.") + + requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) + input_requires_grad = requires_grad + fc1_weight_requires_grad = requires_grad and fc1_weight_param.requires_grad + fc2_weight_requires_grad = requires_grad and fc2_weight_param.requires_grad + + fc1_weights = _megacpp_weight_arg( + fc1_op, + dtype, + input_requires_grad=input_requires_grad, + ) + fc2_weights = _megacpp_weight_arg( + fc2_op, + dtype, + input_requires_grad=input_requires_grad, + ) + gemm_scratch = _grouped_gemm_scratch(num_groups, input_.device) + ( + fc2_out, + x, + split_sizes_i64, + base_split_offsets, + x_offsets, + fc1_offsets, + fc2_offsets, + fc2_dy_offsets, + fc1_activation_input, + fc2_x, + ) = tex.megacpp_grouped_mlp_forward( + input_, + dtype, + split_sizes, + fc1_weights, + _megacpp_bias_arg(fc1_op, dtype), + fc2_weights, + _megacpp_bias_arg(fc2_op, dtype), + act_scales, + activation_config.name, + activation_config.glu_interleave_size, + activation_config.limit, + activation_config.alpha, + activation_config.glu_linear_offset, + gemm_scratch, + ) + + if x.data_ptr() == input_.data_ptr(): + x._do_not_clear = True + + if requires_grad: + fc1_saved_weights = ( + [fc1_weights] if isinstance(fc1_weights, torch.Tensor) else fc1_weights + ) + fc2_saved_weights = ( + [fc2_weights] if isinstance(fc2_weights, torch.Tensor) else fc2_weights + ) + + fc1_ctx.save_for_backward( + split_sizes_i64, + base_split_offsets, + x_offsets, + fc1_offsets, + x, + fc1_activation_input, + *fc1_saved_weights, + ) + fc1_ctx.use_megacpp_grouped_mlp = True + fc1_ctx.dtype = dtype + fc1_ctx.input_requires_grad = input_requires_grad + fc1_ctx.weight_requires_grad = fc1_weight_requires_grad + fc1_ctx.single_weight_arg = isinstance(fc1_weights, torch.Tensor) + + activation_ctx.save_for_backward(fc1_activation_input, act_scales) + activation_ctx.extra_input_requires_grad = act_scales.requires_grad + activation_ctx.input_requires_grad = True + activation_ctx.dtype = dtype + + fc2_ctx.save_for_backward( + split_sizes_i64, + base_split_offsets, + fc2_offsets, + fc2_dy_offsets, + fc2_x, + *fc2_saved_weights, + ) + fc2_ctx.use_megacpp_grouped_mlp = True + fc2_ctx.dtype = dtype + fc2_ctx.input_requires_grad = input_requires_grad + fc2_ctx.weight_requires_grad = fc2_weight_requires_grad + fc2_ctx.single_weight_arg = isinstance(fc2_weights, torch.Tensor) + + return fc2_out, [(), (), ()] + + +def fuse_forward_megacpp_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply opt-in C++ grouped MLP fusion for BF16/FP16.""" + if not _megacpp_enabled(): + return ops + if not _megacpp_supports_recipe(recipe): + return ops + if not ForwardGroupedMLP_MegaCpp.is_supported(): + return ops + + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + matches_pattern = True + if not (isinstance(window[0], GroupedLinear) and isinstance(window[2], GroupedLinear)): + matches_pattern = False + elif window[0]._scale_bias or window[2]._scale_bias: + matches_pattern = False + else: + try: + _resolve_megacpp_grouped_mlp_config(window[0], window[1], window[2]) + except (TypeError, ValueError, RuntimeError): + matches_pattern = False + + if matches_pattern: + window = [ + ForwardGroupedMLP_MegaCpp( + fc1=window[0], + activation=window[1], + fc2=window[2], + ) + ] + else: + out.extend(window[:-2]) + window = window[-2:] + + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + out.extend(window) + return out + + +# Explicit env opt-in gives megacpp first chance. Unsupported recipes intentionally +# return the ops unchanged so lower-priority recipe-specific fusers remain the +# fallback path. +register_forward_fusion(fuse_forward_megacpp_ops, prepend=True) From 07b28363d19ada73c4476967bc502c17ad409bfe Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 16 Jun 2026 00:31:51 -0700 Subject: [PATCH 8/8] integrate fused scaled swiglu and srelu Signed-off-by: Zhongbo Zhu --- tests/pytorch/megacpp/test_grouped_mlp.py | 76 +++++- .../pytorch/csrc/megacpp/grouped_mlp.cpp | 249 ++++++++---------- 2 files changed, 176 insertions(+), 149 deletions(-) diff --git a/tests/pytorch/megacpp/test_grouped_mlp.py b/tests/pytorch/megacpp/test_grouped_mlp.py index b056af1978..267f926f09 100644 --- a/tests/pytorch/megacpp/test_grouped_mlp.py +++ b/tests/pytorch/megacpp/test_grouped_mlp.py @@ -31,6 +31,72 @@ def _megacpp_available() -> tuple[bool, str]: pytestmark = pytest.mark.skipif(not _AVAILABLE, reason=_SKIP_REASON) +class _Fp32ScaledActivationMixin: + """Run the wrapped scaled-activation op in fp32 (test-only). + + The scaled-activation ops choose their compute dtype from the input dtype, + so upcasting the activation input (and the per-row scales) to fp32 makes the + GLU/SReLU and the scaling run in fp32 with a single cast back to the output + dtype -- matching the fused megacpp activation kernel. Outputs and gradients + are cast back to the original dtype so the surrounding grouped GEMMs see the + same dtype as the unwrapped op. This keeps the production ops unchanged and + only exercises their existing input-dtype-driven fp32 path in the reference. + """ + + def fuser_forward( + self, + basic_op_ctxs, + input_, + *, + basic_op_extra_inputs, + prev_op_grad_output_quantizer, + next_op_input_quantizer, + basic_op_kwargs, + ): + # The op picks its compute dtype from the main input, then upcasts the + # scales internally via maybe_dequantize. So only ``input_`` needs to be + # fp32; the extra inputs (act_scales) are passed through untouched to + # preserve their requires_grad (casting them here, inside the fuser's + # no-grad forward, would drop the grad and skip the scale gradient). + out, extra_outputs = super().fuser_forward( + basic_op_ctxs, + input_.float(), + basic_op_extra_inputs=basic_op_extra_inputs, + prev_op_grad_output_quantizer=prev_op_grad_output_quantizer, + next_op_input_quantizer=next_op_input_quantizer, + basic_op_kwargs=basic_op_kwargs, + ) + return out.to(input_.dtype), extra_outputs + + def fuser_backward(self, basic_op_ctxs, grad_output, *, basic_op_grad_extra_outputs): + # Returns: (grad_input, grad_params, grad_extra_inputs). The act_scales + # gradient lives in grad_extra_inputs (the third element). + grad_input, grad_params, grad_extra_inputs = super().fuser_backward( + basic_op_ctxs, + grad_output, + basic_op_grad_extra_outputs=basic_op_grad_extra_outputs, + ) + if grad_input is not None: + grad_input = grad_input.to(grad_output.dtype) + grad_extra_inputs = [ + tuple(None if g is None else g.to(grad_output.dtype) for g in group) + for group in grad_extra_inputs + ] + return grad_input, grad_params, grad_extra_inputs + + +class _Fp32ScaledSwiGLU(_Fp32ScaledActivationMixin, te_ops.ScaledSwiGLU): + pass + + +class _Fp32ScaledClampedQGeGLU(_Fp32ScaledActivationMixin, te_ops.ScaledClampedQGeGLU): + pass + + +class _Fp32ScaledSReLU(_Fp32ScaledActivationMixin, te_ops.ScaledSReLU): + pass + + def _make_grouped_mlp( *, num_groups: int, @@ -57,12 +123,16 @@ def _make_grouped_mlp( single_grouped_weight=single_grouped_param, single_grouped_bias=single_grouped_param and bias, ) + # Use the fp32-compute wrappers so the (non-fused) reference path computes + # the activation in fp32, matching the fused megacpp kernel. The wrappers + # subclass the real ops, so the megacpp fusion still recognizes them via + # isinstance and the fused path ignores the wrapper entirely. if activation_kind == "scaled_swiglu": - act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + act = _Fp32ScaledSwiGLU(glu_interleave_size=glu_interleave_size) elif activation_kind == "scaled_clamped_qgeglu": - act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + act = _Fp32ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) elif activation_kind == "scaled_srelu": - act = te_ops.ScaledSReLU() + act = _Fp32ScaledSReLU() else: raise ValueError(f"Unsupported test activation_kind={activation_kind}.") fc2 = te_ops.GroupedLinear( diff --git a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp index 4292adb349..63709cc92f 100644 --- a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp +++ b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp @@ -462,143 +462,76 @@ void add_grouped_bias(GroupedTensorWrapper *output, const at::Tensor &bias, size } } -bool is_gated_activation(const std::string &activation) { - return activation == "swiglu" || activation == "clamped_swiglu" || activation == "geglu" || - activation == "reglu" || activation == "qgeglu" || activation == "sreglu"; -} +enum class MegacppActivation { + kSwiGLU, + kClampedSwiGLU, + kSReLU, +}; -at::Tensor maybe_deinterleave_glu(const at::Tensor &input, int64_t glu_interleave_size) { - if (glu_interleave_size <= 0) { - return input; +MegacppActivation activation_from_string(const std::string &activation) { + if (activation == "swiglu") { + return MegacppActivation::kSwiGLU; + } + if (activation == "clamped_swiglu") { + return MegacppActivation::kClampedSwiGLU; + } + if (activation == "srelu") { + return MegacppActivation::kSReLU; } - auto shape = input.sizes().vec(); - const int64_t last_dim = shape.back(); - NVTE_CHECK(last_dim % (2 * glu_interleave_size) == 0, - "GLU interleaving requires the last dimension to be divisible by 2*interleave."); - check_contiguous(input, "GLU input"); - // Explicit layout materialization: GLU interleave changes memory order. - return input.view({-1, last_dim / (2 * glu_interleave_size), 2, glu_interleave_size}) - .transpose(1, 2) - .contiguous() - .view(shape); + NVTE_ERROR("Unsupported megacpp grouped MLP scaled activation: ", activation); + return MegacppActivation::kSwiGLU; } -at::Tensor maybe_reinterleave_glu_grad(const at::Tensor &input, int64_t glu_interleave_size) { - if (glu_interleave_size <= 0) { - return input; - } - auto shape = input.sizes().vec(); - const int64_t last_dim = shape.back(); - check_contiguous(input, "GLU grad input"); - // Explicit layout materialization: reverse GLU interleave changes memory order. - return input.view({-1, 2, last_dim / (2 * glu_interleave_size), glu_interleave_size}) - .transpose(1, 2) - .contiguous() - .view(shape); +bool is_gated_activation(MegacppActivation activation) { + return activation == MegacppActivation::kSwiGLU || + activation == MegacppActivation::kClampedSwiGLU; } -at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &activation, - double activation_limit, double activation_alpha, - double activation_glu_linear_offset) { +int64_t activation_output_features(const at::Tensor &input, MegacppActivation activation) { const int64_t out_features = is_gated_activation(activation) ? input.size(-1) / 2 : input.size(-1); - auto output = at::empty({input.size(0), out_features}, input.options()); - auto te_input = makeTransformerEngineTensor(input); - auto te_output = makeTransformerEngineTensor(output); - auto stream = at::cuda::getCurrentCUDAStream(); - NVTE_SCOPED_GIL_RELEASE({ - if (activation == "swiglu") { - nvte_swiglu(te_input.data(), te_output.data(), stream); - } else if (activation == "glu") { - nvte_glu(te_input.data(), te_output.data(), stream); - } else if (activation == "geglu") { - nvte_geglu(te_input.data(), te_output.data(), stream); - } else if (activation == "qgeglu") { - nvte_qgeglu(te_input.data(), te_output.data(), stream); - } else if (activation == "reglu") { - nvte_reglu(te_input.data(), te_output.data(), stream); - } else if (activation == "sreglu") { - nvte_sreglu(te_input.data(), te_output.data(), stream); - } else if (activation == "clamped_swiglu") { - nvte_clamped_swiglu_v2(te_input.data(), te_output.data(), - static_cast(activation_limit), - static_cast(activation_alpha), - static_cast(activation_glu_linear_offset), stream); - } else if (activation == "srelu") { - nvte_srelu(te_input.data(), te_output.data(), stream); - } else if (activation == "gelu") { - nvte_gelu(te_input.data(), te_output.data(), stream); - } else if (activation == "qgelu") { - nvte_qgelu(te_input.data(), te_output.data(), stream); - } else if (activation == "relu") { - nvte_relu(te_input.data(), te_output.data(), stream); - } else if (activation == "silu") { - nvte_silu(te_input.data(), te_output.data(), stream); - } else { - NVTE_ERROR("Unsupported megacpp grouped MLP activation: ", activation); - } - }); - return output; + NVTE_CHECK(out_features > 0, "Activation output dimension must be non-zero."); + return out_features; } -at::Tensor activation_backward_impl(const at::Tensor &grad, const at::Tensor &input, - const std::string &activation, double activation_limit, - double activation_alpha, double activation_glu_linear_offset) { - auto output = at::empty_like(input); - auto te_grad = makeTransformerEngineTensor(grad); +at::Tensor grouped_mlp_activation_forward(const at::Tensor &input, + const std::optional &act_scales, + MegacppActivation activation, int64_t glu_interleave_size, + double activation_limit, double activation_alpha, + double activation_glu_linear_offset) { + const int64_t out_features = activation_output_features(input, activation); + NVTE_CHECK(act_scales.has_value(), "megacpp grouped MLP scaled activation requires act_scales."); + const at::Tensor &scales = *act_scales; + NVTE_CHECK(scales.is_cuda(), "act_scales must be a CUDA tensor."); + NVTE_CHECK(scales.device() == input.device(), + "act_scales must be on the same device as activation."); + NVTE_CHECK(scales.numel() == input.size(0), "act_scales must have one value per activation row."); + check_contiguous(scales, "act_scales"); + auto output = at::empty({input.size(0), out_features}, input.options()); auto te_input = makeTransformerEngineTensor(input); + auto te_scales = makeTransformerEngineTensor(scales); auto te_output = makeTransformerEngineTensor(output); auto stream = at::cuda::getCurrentCUDAStream(); NVTE_SCOPED_GIL_RELEASE({ - if (activation == "swiglu") { - nvte_dswiglu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "glu") { - nvte_dglu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "geglu") { - nvte_dgeglu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "qgeglu") { - nvte_dqgeglu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "reglu") { - nvte_dreglu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "sreglu") { - nvte_dsreglu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "clamped_swiglu") { - nvte_clamped_dswiglu_v2(te_grad.data(), te_input.data(), te_output.data(), - static_cast(activation_limit), - static_cast(activation_alpha), - static_cast(activation_glu_linear_offset), stream); - } else if (activation == "srelu") { - nvte_dsrelu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "gelu") { - nvte_dgelu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "qgelu") { - nvte_dqgelu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "relu") { - nvte_drelu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else if (activation == "silu") { - nvte_dsilu(te_grad.data(), te_input.data(), te_output.data(), stream); - } else { - NVTE_ERROR("Unsupported megacpp grouped MLP activation backward: ", activation); + switch (activation) { + case MegacppActivation::kSwiGLU: + nvte_scaled_swiglu(te_input.data(), te_scales.data(), te_output.data(), glu_interleave_size, + stream); + break; + case MegacppActivation::kClampedSwiGLU: + nvte_scaled_clamped_swiglu( + te_input.data(), te_scales.data(), te_output.data(), + static_cast(activation_limit), static_cast(activation_alpha), + static_cast(activation_glu_linear_offset), glu_interleave_size, stream); + break; + case MegacppActivation::kSReLU: + nvte_scaled_srelu(te_input.data(), te_scales.data(), te_output.data(), stream); + break; } }); return output; } -at::Tensor grouped_mlp_activation_forward( - const at::Tensor &input, const std::optional &act_scales, - const std::string &activation, int64_t glu_interleave_size, double activation_limit, - double activation_alpha, double activation_glu_linear_offset, at::ScalarType dtype) { - auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); - auto activation_output = activation_forward_impl(activation_input, activation, activation_limit, - activation_alpha, activation_glu_linear_offset); - if (!act_scales.has_value()) { - return activation_output; - } - auto act_scales_for_fc2 = maybe_cast_dtype(*act_scales, dtype); - check_contiguous(act_scales_for_fc2, "act_scales"); - return activation_output * act_scales_for_fc2.view({-1, 1}); -} - struct ActivationBackwardResult { at::Tensor grad_input; at::Tensor grad_act_scales; @@ -606,32 +539,54 @@ struct ActivationBackwardResult { ActivationBackwardResult grouped_mlp_activation_backward( const at::Tensor &grad_output, const at::Tensor &input, - const std::optional &act_scales, const std::string &activation, + const std::optional &act_scales, MegacppActivation activation, int64_t glu_interleave_size, double activation_limit, double activation_alpha, - double activation_glu_linear_offset, at::ScalarType dtype, bool act_scales_requires_grad) { - auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); - - at::Tensor grad_activation_output = grad_output; + double activation_glu_linear_offset, bool act_scales_requires_grad) { + NVTE_CHECK(act_scales.has_value(), "megacpp grouped MLP scaled activation requires act_scales."); + const at::Tensor &scales = *act_scales; + NVTE_CHECK(scales.is_cuda(), "act_scales must be a CUDA tensor."); + NVTE_CHECK(scales.device() == input.device(), + "act_scales must be on the same device as activation."); + NVTE_CHECK(scales.numel() == input.size(0), "act_scales must have one value per activation row."); + check_contiguous(scales, "act_scales"); + auto grad_input = at::empty_like(input); at::Tensor grad_act_scales; - if (act_scales.has_value()) { - if (act_scales_requires_grad) { - // Scaled activations compute y = activation(x) * act_scales[:, None]. - // Recompute activation(x) for dact_scales to match the Python basic-op - // path without saving another [tokens, hidden] activation tensor. - auto activation_output = - activation_forward_impl(activation_input, activation, activation_limit, activation_alpha, - activation_glu_linear_offset); - grad_act_scales = (activation_output * grad_output).sum(-1); - } - auto act_scales_for_grad = maybe_cast_dtype(*act_scales, dtype); - check_contiguous(act_scales_for_grad, "act_scales"); - grad_activation_output = grad_output * act_scales_for_grad.view({-1, 1}); + if (act_scales_requires_grad) { + grad_act_scales = at::empty({input.size(0)}, grad_output.options()); } - auto grad_activation_input = - activation_backward_impl(grad_activation_output, activation_input, activation, - activation_limit, activation_alpha, activation_glu_linear_offset); - return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size), grad_act_scales}; + auto te_grad_output = makeTransformerEngineTensor(grad_output); + auto te_input = makeTransformerEngineTensor(input); + auto te_scales = makeTransformerEngineTensor(scales); + auto te_grad_input = makeTransformerEngineTensor(grad_input); + std::optional te_grad_act_scales; + if (grad_act_scales.defined()) { + te_grad_act_scales.emplace(makeTransformerEngineTensor(grad_act_scales)); + } + NVTETensor te_grad_act_scales_ptr = + grad_act_scales.defined() ? te_grad_act_scales->data() : nullptr; + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + switch (activation) { + case MegacppActivation::kSwiGLU: + nvte_scaled_dswiglu(te_grad_output.data(), te_input.data(), te_scales.data(), + te_grad_input.data(), te_grad_act_scales_ptr, glu_interleave_size, + stream); + break; + case MegacppActivation::kClampedSwiGLU: + nvte_scaled_clamped_dswiglu( + te_grad_output.data(), te_input.data(), te_scales.data(), te_grad_input.data(), + te_grad_act_scales_ptr, static_cast(activation_limit), + static_cast(activation_alpha), static_cast(activation_glu_linear_offset), + glu_interleave_size, stream); + break; + case MegacppActivation::kSReLU: + nvte_scaled_dsrelu(te_grad_output.data(), te_input.data(), te_scales.data(), + te_grad_input.data(), te_grad_act_scales_ptr, stream); + break; + } + }); + return {grad_input, grad_act_scales}; } } // namespace @@ -644,6 +599,7 @@ std::vector megacpp_grouped_mlp_forward( double activation_glu_linear_offset, py::handle gemm_scratch) { NVTE_CHECK(input.is_cuda(), "megacpp_grouped_mlp_forward requires CUDA input."); at::cuda::CUDAGuard device_guard(input.device()); + const auto activation_kind = activation_from_string(activation); // act_dtype is the requested activation/GEMM input dtype. The incoming // tensor may have a different dtype, so canonicalize it once at the API @@ -665,7 +621,7 @@ std::vector megacpp_grouped_mlp_forward( const int64_t fc2_out_features = fc2_weights.rows; const int64_t fc2_in_features = fc2_weights.cols; const int64_t activation_out_features = - is_gated_activation(activation) ? fc1_out_features / 2 : fc1_out_features; + is_gated_activation(activation_kind) ? fc1_out_features / 2 : fc1_out_features; NVTE_CHECK(activation_out_features == fc2_in_features, "FC1 activation output dimension must match FC2 input dimension."); auto fc1_bias_tensor = @@ -699,9 +655,9 @@ std::vector megacpp_grouped_mlp_forward( &gemm_resources); add_grouped_bias(&grouped_fc1_preact, fc1_bias_tensor, num_groups, dtype, fc1_out_features); - auto fc2_x = grouped_mlp_activation_forward( - fc1_preact, act_scales, activation, glu_interleave_size, activation_limit, activation_alpha, - activation_glu_linear_offset, dtype); + auto fc2_x = grouped_mlp_activation_forward(fc1_preact, act_scales, activation_kind, + glu_interleave_size, activation_limit, + activation_alpha, activation_glu_linear_offset); std::vector out_shape = input.sizes().vec(); out_shape.back() = fc2_out_features; @@ -731,6 +687,7 @@ py::tuple megacpp_grouped_mlp_backward( (void)base_offsets; NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output."); at::cuda::CUDAGuard device_guard(grad_output.device()); + const auto activation_kind = activation_from_string(activation); // act_dtype is the requested grouped-MLP compute dtype. Backward receives // autograd's grad_output as-is, so canonicalize it here instead of requiring @@ -772,8 +729,8 @@ py::tuple megacpp_grouped_mlp_backward( grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, &gemm_resources); auto activation_grads = grouped_mlp_activation_backward( - fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, activation_limit, - activation_alpha, activation_glu_linear_offset, dtype, act_scales_requires_grad); + fc2_dx, fc1_activation_input, act_scales, activation_kind, glu_interleave_size, + activation_limit, activation_alpha, activation_glu_linear_offset, act_scales_requires_grad); auto fc1_dy = activation_grads.grad_input; auto grad_act_scales = activation_grads.grad_act_scales; auto grouped_fc1_dy = make_grouped_tensor(fc1_dy, split_sizes, fc1_offsets, fc1_out_features);