-
Notifications
You must be signed in to change notification settings - Fork 695
Scaled Bias Add support after CUBLAS GGEMM #2885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ba6f5eb
cb0504d
2a344b2
9f98357
c565367
b64559a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |||||||||
|
|
||||||||||
| #include <algorithm> | ||||||||||
| #include <cstdint> | ||||||||||
| #include <type_traits> | ||||||||||
| #include <vector> | ||||||||||
|
|
||||||||||
| #include "../common.h" | ||||||||||
|
|
@@ -845,43 +846,102 @@ __forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorSha | |||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // Kernel that performs bias addition to the Grouped GEMM output tensors. | ||||||||||
| // Bias itself is a grouped tensor with the collections of same number of tensors | ||||||||||
| // as the output tensors. | ||||||||||
| template <typename T, int kVec> | ||||||||||
| __global__ void grouped_bias_add_kernel(char *d_base, const char *bias_base, TensorShapeInfo d_meta, | ||||||||||
| TensorShapeInfo bias_meta, size_t num_tensors) { | ||||||||||
| const size_t tensor_idx = blockIdx.x; | ||||||||||
| if (tensor_idx >= num_tensors) return; | ||||||||||
| // Kernel that performs (optionally scaled) bias addition to Grouped GEMM output tensors. | ||||||||||
| // 2D grid: blockIdx.x = row chunk, blockIdx.y = column chunk. | ||||||||||
| // Each block loads bias once for its column chunk and sweeps its rows | ||||||||||
| // with direct vectorized load-add-store on d. | ||||||||||
| template <typename T, int kVec, bool UseScale, int kBlockDim, int kRowsPerBlock> | ||||||||||
| __global__ void grouped_bias_add_kernel(char *__restrict__ d_base, | ||||||||||
| const char *__restrict__ bias_base, | ||||||||||
| const float *__restrict__ scale_base, | ||||||||||
| TensorShapeInfo d_meta, int n, int total_rows, | ||||||||||
| int num_tensors) { | ||||||||||
| using VecStorage = transformer_engine::VectorizedStorage<T, kVec>; | ||||||||||
| using VecType = typename VecStorage::LType; | ||||||||||
|
|
||||||||||
| const int64_t m = d_meta.first_dims ? d_meta.first_dims[tensor_idx] : d_meta.uniform_first; | ||||||||||
| const int64_t n = d_meta.last_dims ? d_meta.last_dims[tensor_idx] : d_meta.uniform_last; | ||||||||||
| constexpr int kMaxTensors = 257; | ||||||||||
| __shared__ int cumsum[kMaxTensors]; | ||||||||||
|
Comment on lines
+862
to
+863
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The variable name is wrong.
Suggested change
|
||||||||||
|
|
||||||||||
| const int64_t d_offset = compute_grouped_tensor_offset(d_meta, tensor_idx); | ||||||||||
| const int64_t bias_offset = compute_grouped_tensor_offset(bias_meta, tensor_idx); | ||||||||||
| const int tid = static_cast<int>(threadIdx.x); | ||||||||||
| const int block_dim = static_cast<int>(blockDim.x); | ||||||||||
| const int row_bid = static_cast<int>(blockIdx.x); | ||||||||||
| const int col_bid = static_cast<int>(blockIdx.y); | ||||||||||
|
|
||||||||||
| auto *d_ptr = reinterpret_cast<T *>(d_base + d_offset * sizeof(T)); | ||||||||||
| const auto *bias_ptr = reinterpret_cast<const T *>(bias_base + bias_offset * sizeof(T)); | ||||||||||
| const int row_start = row_bid * kRowsPerBlock; | ||||||||||
| const int row_end = min(row_start + kRowsPerBlock, total_rows); | ||||||||||
| if (row_start >= total_rows) return; | ||||||||||
|
|
||||||||||
| const int block_cols = block_dim * kVec; | ||||||||||
| const int col = col_bid * block_cols + tid * kVec; | ||||||||||
| if (col >= n) return; | ||||||||||
|
|
||||||||||
| // Build cumulative row prefix-sum in shared memory. | ||||||||||
| if (tid == 0) cumsum[0] = 0; | ||||||||||
| for (int i = tid; i < num_tensors; i += block_dim) { | ||||||||||
| cumsum[i + 1] = | ||||||||||
| static_cast<int>(d_meta.first_dims ? d_meta.first_dims[i] : d_meta.uniform_first); | ||||||||||
| } | ||||||||||
| __syncthreads(); | ||||||||||
| if (tid == 0) { | ||||||||||
| for (int t = 1; t <= num_tensors; t++) cumsum[t] += cumsum[t - 1]; | ||||||||||
| } | ||||||||||
| __syncthreads(); | ||||||||||
|
|
||||||||||
| T *__restrict__ d = reinterpret_cast<T *>(d_base); | ||||||||||
| const T *__restrict__ bias = reinterpret_cast<const T *>(bias_base); | ||||||||||
|
|
||||||||||
| // Binary search for the starting row's tensor. | ||||||||||
| int tensor_idx; | ||||||||||
| { | ||||||||||
| int lo = 0, hi = num_tensors; | ||||||||||
| while (lo < hi) { | ||||||||||
| int mid = (lo + hi) >> 1; | ||||||||||
| if (cumsum[mid + 1] <= row_start) | ||||||||||
| lo = mid + 1; | ||||||||||
| else | ||||||||||
| hi = mid; | ||||||||||
| } | ||||||||||
| tensor_idx = lo; | ||||||||||
| } | ||||||||||
| int bias_idx = tensor_idx * n; | ||||||||||
|
Comment on lines
+893
to
+906
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Have we benchmarked whether this binary search is any better than just scanning through the tensors? Computing the cumsums is still O(n), so we're not improving the asymptotics. We're also introducing thread syncs and shared memory accesses. |
||||||||||
|
|
||||||||||
| const int64_t elements = m * n; | ||||||||||
| const int64_t vec_count = elements / kVec; | ||||||||||
| using VecStorage = transformer_engine::VectorizedStorage<T, kVec>; | ||||||||||
| using VecType = typename VecStorage::LType; | ||||||||||
| transformer_engine::VectorizedLoader<T, kVec, true> loader(d_ptr, elements); | ||||||||||
| transformer_engine::VectorizedStorer<T, kVec, true> storer(d_ptr, elements); | ||||||||||
| const int64_t vec_id = static_cast<int64_t>(blockIdx.y) * blockDim.x + threadIdx.x; | ||||||||||
| if (vec_id >= vec_count) return; | ||||||||||
| const int64_t vec_start = vec_id * kVec; | ||||||||||
| const int64_t col = vec_start % n; | ||||||||||
| loader.load(vec_id, elements); | ||||||||||
| const auto *b_vec = reinterpret_cast<const VecType *>(bias_ptr + col); | ||||||||||
| VecStorage b_in; | ||||||||||
| b_in.scratch_.aligned = *b_vec; | ||||||||||
| b_in.scratch_.aligned = *reinterpret_cast<const VecType *>(bias + bias_idx + col); | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This value is immediately wiped out in the loop. I guess the compiler might be smart enough not to do an unnecessary memory access, but it makes the code harder to read. |
||||||||||
|
|
||||||||||
| // Walk tensor segments within this block's row range. | ||||||||||
| int seg_start = row_start; | ||||||||||
| while (seg_start < row_end) { | ||||||||||
| while (tensor_idx < num_tensors - 1 && cumsum[tensor_idx + 1] <= seg_start) { | ||||||||||
| tensor_idx++; | ||||||||||
| bias_idx += n; | ||||||||||
| } | ||||||||||
| b_in.scratch_.aligned = *reinterpret_cast<const VecType *>(bias + bias_idx + col); | ||||||||||
| const int seg_end = min(cumsum[tensor_idx + 1], row_end); | ||||||||||
|
|
||||||||||
| for (int row = seg_start; row < seg_end; row++) { | ||||||||||
| T *d_ptr = d + row * n + col; | ||||||||||
| VecStorage d_in; | ||||||||||
| d_in.scratch_.aligned = *reinterpret_cast<const VecType *>(d_ptr); | ||||||||||
|
|
||||||||||
| [[maybe_unused]] float s_val; | ||||||||||
| if constexpr (UseScale) s_val = scale_base[row]; | ||||||||||
|
|
||||||||||
| #pragma unroll | ||||||||||
| for (int i = 0; i < kVec; ++i) { | ||||||||||
| storer.separate()[i] = loader.separate()[i] + b_in.scratch_.separate[i]; | ||||||||||
| for (int i = 0; i < kVec; ++i) { | ||||||||||
| if constexpr (UseScale) { | ||||||||||
| d_in.scratch_.separate[i] = | ||||||||||
| static_cast<T>(fmaf(static_cast<float>(b_in.scratch_.separate[i]), s_val, | ||||||||||
| static_cast<float>(d_in.scratch_.separate[i]))); | ||||||||||
| } else { | ||||||||||
| d_in.scratch_.separate[i] = static_cast<T>(static_cast<float>(d_in.scratch_.separate[i]) + | ||||||||||
| static_cast<float>(b_in.scratch_.separate[i])); | ||||||||||
| } | ||||||||||
| } | ||||||||||
| *reinterpret_cast<VecType *>(d_ptr) = d_in.scratch_.aligned; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| seg_start = seg_end; | ||||||||||
| } | ||||||||||
| storer.store(vec_id, elements); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // Single kernel that sets up all GEMM parameters. | ||||||||||
|
|
@@ -1308,12 +1368,13 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, | |||||||||
| } | ||||||||||
|
|
||||||||||
| void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, | ||||||||||
| cudaStream_t stream) { | ||||||||||
| const NVTETensor scale, cudaStream_t stream) { | ||||||||||
| NVTE_API_CALL(nvte_grouped_bias_add); | ||||||||||
| using namespace transformer_engine; | ||||||||||
|
|
||||||||||
| const GroupedTensor *outputD = convertNVTEGroupedTensorCheck(output); | ||||||||||
| const GroupedTensor *bias_tensor = convertNVTEGroupedTensorCheck(bias); | ||||||||||
| const Tensor *scale_tensor = convertNVTETensorCheck(scale); | ||||||||||
|
|
||||||||||
| NVTE_CHECK(outputD->num_tensors >= 1, "Grouped bias add: number of tensors must be at least 1"); | ||||||||||
| NVTE_CHECK(outputD->num_tensors == bias_tensor->num_tensors, | ||||||||||
|
|
@@ -1330,27 +1391,67 @@ void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTens | |||||||||
| "Grouped bias add requires uniform last dim for output and bias"); | ||||||||||
| NVTE_CHECK(outputD->get_common_last_dim() == bias_tensor->get_common_last_dim(), | ||||||||||
| "Grouped bias add: output and bias last dims must match"); | ||||||||||
| constexpr int kVec = 4; | ||||||||||
| NVTE_CHECK(outputD->get_common_last_dim() % kVec == 0, | ||||||||||
| "Grouped bias add requires last dim divisible by ", kVec); | ||||||||||
|
|
||||||||||
| const float *scale_ptr = nullptr; | ||||||||||
| if (scale_tensor->data.dptr != nullptr) { | ||||||||||
| NVTE_CHECK(scale_tensor->dtype() == DType::kFloat32, "Grouped bias add: scale must be float32"); | ||||||||||
| NVTE_CHECK(scale_tensor->data.shape.size() == 1, "Grouped bias add: scale must be 1D, got ", | ||||||||||
| scale_tensor->data.shape.size(), "D"); | ||||||||||
| const size_t total_rows = static_cast<size_t>(outputD->logical_shape.data[0]); | ||||||||||
| NVTE_CHECK(scale_tensor->data.shape[0] == total_rows, "Grouped bias add: scale size (", | ||||||||||
| scale_tensor->data.shape[0], ") must equal total rows (", total_rows, ")"); | ||||||||||
| scale_ptr = static_cast<const float *>(scale_tensor->data.dptr); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| const TensorShapeInfo d_meta = TensorShapeInfo::from_tensor(outputD); | ||||||||||
| const TensorShapeInfo bias_meta = TensorShapeInfo::from_tensor(bias_tensor); | ||||||||||
|
|
||||||||||
| const DType dtype = outputD->dtype(); | ||||||||||
| constexpr int kThreads = 256; | ||||||||||
| const size_t total_elements = static_cast<size_t>(outputD->logical_shape.data[0]) * | ||||||||||
| static_cast<size_t>(outputD->logical_shape.data[1]); | ||||||||||
| const size_t total_vec_count = (total_elements + kVec - 1) / kVec; | ||||||||||
| int blocks_per_tensor = static_cast<int>((total_vec_count + kThreads - 1) / kThreads); | ||||||||||
| const dim3 grid(outputD->num_tensors, blocks_per_tensor); | ||||||||||
|
|
||||||||||
| const int num_tensors = static_cast<int>(outputD->num_tensors); | ||||||||||
| NVTE_CHECK(num_tensors <= 256, "Grouped bias add supports at most 256 tensors, got ", | ||||||||||
| num_tensors); | ||||||||||
| const int total_rows = static_cast<int>(outputD->logical_shape.data[0]); | ||||||||||
| const int n = static_cast<int>(outputD->get_common_last_dim()); | ||||||||||
|
|
||||||||||
| // Use 128-bit vector loads: kVec=8 for 2-byte types (bf16/fp16), kVec=4 for fp32. | ||||||||||
| const size_t elem_size = typeToSize(dtype); | ||||||||||
| const int kVec = (elem_size <= 2) ? 8 : 4; | ||||||||||
| NVTE_CHECK(n % kVec == 0, "Grouped bias add requires last dim divisible by ", kVec); | ||||||||||
|
|
||||||||||
| constexpr int kRowsPerBlock = 8; | ||||||||||
| const int block_cols = kThreads * kVec; | ||||||||||
| const int col_blocks = (n + block_cols - 1) / block_cols; | ||||||||||
| const int row_blocks = (total_rows + kRowsPerBlock - 1) / kRowsPerBlock; | ||||||||||
| const dim3 grid(row_blocks, col_blocks); | ||||||||||
| const dim3 block(kThreads); | ||||||||||
|
|
||||||||||
| TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, T, { | ||||||||||
| grouped_bias_add_kernel<T, kVec><<<grid, block, 0, stream>>>( | ||||||||||
| static_cast<char *>(outputD->data.dptr), static_cast<const char *>(bias_tensor->data.dptr), | ||||||||||
| d_meta, bias_meta, outputD->num_tensors); | ||||||||||
| }); | ||||||||||
| auto launch = [&](auto use_scale_tag) { | ||||||||||
| constexpr bool kUseScale = decltype(use_scale_tag)::value; | ||||||||||
| if (elem_size <= 2) { | ||||||||||
| constexpr int kV = 8; | ||||||||||
| TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, T, { | ||||||||||
| grouped_bias_add_kernel<T, kV, kUseScale, kThreads, kRowsPerBlock> | ||||||||||
| <<<grid, block, 0, stream>>>(static_cast<char *>(outputD->data.dptr), | ||||||||||
| static_cast<const char *>(bias_tensor->data.dptr), | ||||||||||
| scale_ptr, d_meta, n, total_rows, num_tensors); | ||||||||||
| }); | ||||||||||
| } else { | ||||||||||
| constexpr int kV = 4; | ||||||||||
| TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, T, { | ||||||||||
| grouped_bias_add_kernel<T, kV, kUseScale, kThreads, kRowsPerBlock> | ||||||||||
| <<<grid, block, 0, stream>>>(static_cast<char *>(outputD->data.dptr), | ||||||||||
| static_cast<const char *>(bias_tensor->data.dptr), | ||||||||||
| scale_ptr, d_meta, n, total_rows, num_tensors); | ||||||||||
| }); | ||||||||||
| } | ||||||||||
| }; | ||||||||||
|
|
||||||||||
| if (scale_ptr != nullptr) { | ||||||||||
| launch(std::true_type{}); | ||||||||||
| } else { | ||||||||||
| launch(std::false_type{}); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||||||||||
| } | ||||||||||
|
|
@@ -1392,7 +1493,7 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, | |||||||||
| } | ||||||||||
|
|
||||||||||
| void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, | ||||||||||
| cudaStream_t stream) { | ||||||||||
| const NVTETensor scale, cudaStream_t stream) { | ||||||||||
| NVTE_ERROR("nvte_grouped_bias_add requires cuBLAS 13.3+, but compile-time cuBLAS version is ", | ||||||||||
| CUBLAS_VERSION, ". Please upgrade to cuBLAS 13.3 (shipped with CUDA 13.2) or newer."); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -429,12 +429,14 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, | |
| NVTETensor workspace_setup, NVTETensor workspace_cublas, | ||
| NVTEGroupedMatmulConfig config, cudaStream_t stream); | ||
|
|
||
| /*! \brief Grouped bias add for grouped GEMM outputs. | ||
| /*! \brief Grouped Bias add for grouped GEMM outputs. | ||
| * | ||
| * When \p scale is a valid tensor: output[row,col] += bias[col] * scale[row], | ||
| * When \p scale is empty/null: output[row,col] += bias[col]. | ||
| * Requires uniform last-dimension across all output tensors and bias tensors. | ||
| */ | ||
| void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, | ||
| cudaStream_t stream); | ||
| const NVTETensor scale, cudaStream_t stream); | ||
|
Comment on lines
438
to
+439
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it makes more sense to create a separate API for |
||
|
|
||
| #ifdef __cplusplus | ||
| } // extern "C" | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -327,6 +327,7 @@ def general_grouped_gemm_for_grouped_tensor( | |||||
| accumulate: bool = False, | ||||||
| use_split_accumulator: bool = False, | ||||||
| bias=None, | ||||||
| bias_scale: Optional[torch.Tensor] = None, | ||||||
| grad: bool = False, | ||||||
| alpha: Optional[torch.Tensor] = None, | ||||||
| beta: Optional[torch.Tensor] = None, | ||||||
|
|
@@ -365,6 +366,9 @@ def general_grouped_gemm_for_grouped_tensor( | |||||
| "Apply bias manually after the GEMM." | ||||||
| ) | ||||||
|
|
||||||
| if bias_scale is not None and bias is None: | ||||||
| raise ValueError("bias_scale requires bias to be provided.") | ||||||
|
|
||||||
| num_tensors = B.num_tensors | ||||||
| rowwise = B.rowwise_data | ||||||
| device = rowwise.device if rowwise is not None else B.columnwise_data.device | ||||||
|
|
@@ -394,13 +398,17 @@ def general_grouped_gemm_for_grouped_tensor( | |||||
| sm_count = get_sm_count() | ||||||
| sm_count = sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))) | ||||||
|
|
||||||
| if bias_scale is None: | ||||||
| bias_scale = torch.empty(0, dtype=torch.float32, device=device) | ||||||
|
|
||||||
|
Comment on lines
+401
to
+403
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can avoid this overhead by making the
Suggested change
|
||||||
| return grouped_gemm_impl( | ||||||
| A, | ||||||
| transa, | ||||||
| B, | ||||||
| transb, | ||||||
| out, | ||||||
| bias, | ||||||
| bias_scale, | ||||||
| alpha, | ||||||
| beta, | ||||||
| workspace_setup, | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Is there a reason we're reordering? If the import order causes problems, then that's a bug we need to fix. Otherwise, this ordering seems strangely unmotivated and haphazard. It's also considered good Python style to put third party imports before local imports (PEP 8).