From 9efb8650ee109fd891657400f12e60d81ecc1ee3 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 12 Jun 2026 19:35:48 +0800 Subject: [PATCH] issue/1276 - support bf16 in ascend causal softmax --- .../ascend/causal_softmax_ascend.cc | 184 +++++++++++------- 1 file changed, 116 insertions(+), 68 deletions(-) diff --git a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc index 813d16037..3b183ae0f 100644 --- a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc +++ b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc @@ -1,35 +1,58 @@ #include "causal_softmax_ascend.h" #include "../../../devices/ascend/common_ascend.h" +#include #include #include +#include namespace op::causal_softmax::ascend { +namespace { + +bool isCompact(const CausalSoftmaxInfo &info, ptrdiff_t stride_b, ptrdiff_t stride_i, ptrdiff_t stride_j) { + return stride_j == 1 + && stride_i == static_cast(info.total_seq_len) + && (info.batch_size == 1 || stride_b == static_cast(info.seq_len * info.total_seq_len)); +} + +} // namespace + struct Descriptor::Opaque { aclnnTensorDescriptor_t x; - aclnnTensorDescriptor_t temp; aclnnTensorDescriptor_t mask; aclnnTensorDescriptor_t y; aclnnTensorDescriptor_t value; + aclnnTensorDescriptor_t temp_x; + aclnnTensorDescriptor_t temp_y; void *mask_addr; void *value_addr; - void *temp_addr; - uint64_t workspacesize; + void *temp_x_addr; + void *temp_y_addr; + size_t workspacesize; aclOpExecutor *executor; + aclOpExecutor *temp_executor; + aclOpExecutor *copy_in_executor; + aclOpExecutor *copy_out_executor; + bool use_temp; ~Opaque() { delete x; - delete temp; delete mask; delete y; delete value; + delete temp_x; + delete temp_y; aclrtFree(mask_addr); aclrtFree(value_addr); - aclrtFree(temp_addr); + aclrtFree(temp_x_addr); + aclrtFree(temp_y_addr); // Delete useless executor aclDestroyAclOpExecutor(executor); + aclDestroyAclOpExecutor(temp_executor); + aclDestroyAclOpExecutor(copy_in_executor); + aclDestroyAclOpExecutor(copy_out_executor); } }; @@ -48,48 +71,44 @@ infiniStatus_t Descriptor::create( CausalSoftmaxInfo info = result.take(); aclOpExecutor *executor = nullptr; + aclOpExecutor *temp_executor = nullptr; aclOpExecutor *mask_executor = nullptr; + aclOpExecutor *copy_in_executor = nullptr; + aclOpExecutor *copy_out_executor = nullptr; aclnnTensorDescriptor_t y = nullptr; aclnnTensorDescriptor_t mask = nullptr; aclnnTensorDescriptor_t x = nullptr; aclnnTensorDescriptor_t value = nullptr; + aclnnTensorDescriptor_t temp_x = nullptr; + aclnnTensorDescriptor_t temp_y = nullptr; void *mask_addr = nullptr; void *value_addr = nullptr; + void *temp_x_addr = nullptr; + void *temp_y_addr = nullptr; size_t workspacesize_softmax = 0; + size_t workspacesize_temp_softmax = 0; size_t workspacesize_mask = 0; + size_t workspacesize_copy_in = 0; + size_t workspacesize_copy_out = 0; - // Create Aclnn Tensor Descriptors for input , mask and output + // Create Aclnn Tensor Descriptors for input, mask and output std::vector shape = {static_cast(info.batch_size), static_cast(info.seq_len), static_cast(info.total_seq_len)}; std::vector x_strides = {static_cast(info.x_stride_b), static_cast(info.x_stride_i), static_cast(info.x_stride_j)}; std::vector y_strides = {static_cast(info.y_stride_b), static_cast(info.y_stride_i), static_cast(info.y_stride_j)}; + std::vector compact_strides = {static_cast(info.seq_len * info.total_seq_len), static_cast(info.total_seq_len), 1}; y = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, y_strides); x = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, x_strides); - mask = new aclnnTensorDescriptor(aclDataType::ACL_BOOL, {static_cast(info.batch_size), static_cast(info.seq_len), static_cast(info.total_seq_len)}, {static_cast(info.seq_len * info.total_seq_len), static_cast(info.total_seq_len), 1}); - - // Allocate contiguous temp buffer for computation (avoids stride issues) - void *temp_addr = nullptr; - size_t temp_elements = info.batch_size * info.seq_len * info.total_seq_len; - size_t temp_bytes = temp_elements * aclDataTypeSize(toAclDataType(info.dtype)); - CHECK_ACL(aclrtMalloc(&temp_addr, temp_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); - std::vector temp_strides = { - static_cast(info.seq_len * info.total_seq_len), - static_cast(info.total_seq_len), - 1}; - auto temp = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, temp_strides, temp_addr); - - // Initialize the value tensor with -∞ + temp_x = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, compact_strides); + temp_y = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, compact_strides); + mask = new aclnnTensorDescriptor(aclDataType::ACL_BOOL, {static_cast(info.seq_len), static_cast(info.total_seq_len)}, {static_cast(info.total_seq_len), 1}); + + // Initialize the value tensor with -inf if (info.dtype == INFINI_DTYPE_F16) { uint16_t mask_value = 0xfc00; auto size = aclDataTypeSize(aclDataType::ACL_FLOAT16); CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST)); CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE)); value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT16, {}, {}); - } else if (info.dtype == INFINI_DTYPE_BF16) { - uint16_t mask_value = 0xff80; - auto size = aclDataTypeSize(aclDataType::ACL_BF16); - CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST)); - CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE)); - value = new aclnnTensorDescriptor(aclDataType::ACL_BF16, {}, {}); } else { uint32_t mask_value = 0xff800000; auto size = aclDataTypeSize(aclDataType::ACL_FLOAT); @@ -98,40 +117,58 @@ infiniStatus_t Descriptor::create( value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT, {}, {}); } - // Fill Mask Tensor (replicate 2D causal mask to all batches) - size_t mask_data_size = info.batch_size * info.seq_len * info.total_seq_len; - std::vector mask_matrix(mask_data_size, 0); + // Fill Mask Tensor + std::vector mask_matrix(mask->numel(), 0); for (size_t i = 0; i < info.seq_len; ++i) { for (size_t j = info.total_seq_len - info.seq_len + i + 1; j < info.total_seq_len; ++j) { - size_t index_2d = i * info.total_seq_len + j; - for (size_t b = 0; b < info.batch_size; ++b) { - mask_matrix[b * info.seq_len * info.total_seq_len + index_2d] = 1; - } + size_t index = i * info.total_seq_len + j; + mask_matrix[index] = 1; } } - auto size = mask_data_size * aclDataTypeSize(aclDataType::ACL_BOOL); + auto size = mask->numel() * aclDataTypeSize(aclDataType::ACL_BOOL); CHECK_ACL(aclrtMalloc(&mask_addr, size, ACL_MEM_MALLOC_HUGE_FIRST)); CHECK_ACL(aclrtMemcpy(mask_addr, size, mask_matrix.data(), size, ACL_MEMCPY_HOST_TO_DEVICE)); // Get the workspace size for the op - aclTensor *ttemp = temp->tensor; + aclTensor *tx = x->tensor; aclTensor *ty = y->tensor; + aclTensor *ttemp_x = temp_x->tensor; + aclTensor *ttemp_y = temp_y->tensor; aclTensor *tmask = mask->tensor; aclTensor *tvalue = value->tensor; - CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp, tmask, tvalue, &workspacesize_mask, &mask_executor)); - - int64_t dim = 2; - CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(ttemp, dim, ty, &workspacesize_softmax, &executor)); - // set executor reusable - aclSetAclOpExecutorRepeatable(executor); + bool use_temp = !isCompact(info, info.x_stride_b, info.x_stride_i, info.x_stride_j) + || !isCompact(info, info.y_stride_b, info.y_stride_i, info.y_stride_j); + + if (use_temp) { + CHECK_ACL(aclnnInplaceCopyGetWorkspaceSize(ttemp_x, tx, &workspacesize_copy_in, ©_in_executor)); + aclSetAclOpExecutorRepeatable(copy_in_executor); + CHECK_ACL(aclnnInplaceCopyGetWorkspaceSize(ty, ttemp_y, &workspacesize_copy_out, ©_out_executor)); + aclSetAclOpExecutorRepeatable(copy_out_executor); + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp_x, tmask, tvalue, &workspacesize_mask, &mask_executor)); + int64_t dim = 2; + CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(ttemp_x, dim, ttemp_y, &workspacesize_temp_softmax, &temp_executor)); + aclSetAclOpExecutorRepeatable(temp_executor); + } else { + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); + int64_t dim = 2; + CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor)); + // set executor reusable + aclSetAclOpExecutorRepeatable(executor); + } - // Create the descripto - size_t all_workspacesize = std::max(workspacesize_softmax, workspacesize_mask); + size_t op_workspace_size = std::max(std::max(workspacesize_softmax, workspacesize_temp_softmax), + std::max(workspacesize_mask, std::max(workspacesize_copy_in, workspacesize_copy_out))); + size_t all_workspacesize = op_workspace_size; + if (use_temp) { + size_t temp_bytes = temp_x->numel() * infiniSizeOf(info.dtype); + CHECK_ACL(aclrtMalloc(&temp_x_addr, temp_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + CHECK_ACL(aclrtMalloc(&temp_y_addr, temp_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + } - *desc_ptr = new Descriptor(new Opaque{x, temp, mask, y, value, mask_addr, value_addr, temp_addr, - workspacesize_softmax, executor}, + *desc_ptr = new Descriptor(new Opaque{x, mask, y, value, temp_x, temp_y, mask_addr, value_addr, + temp_x_addr, temp_y_addr, op_workspace_size, executor, temp_executor, copy_in_executor, copy_out_executor, use_temp}, std::move(info), all_workspacesize, handle_ascend->device, handle_ascend->device_id); return INFINI_STATUS_SUCCESS; @@ -141,38 +178,49 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, voi if (workspace_size < workspaceSize()) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } - auto ttemp = _opaque->temp->tensor; + auto tx = _opaque->x->tensor; auto ty = _opaque->y->tensor; auto tmask = _opaque->mask->tensor; auto tvalue = _opaque->value->tensor; - aclOpExecutor *mask_executor = nullptr; - size_t workspacesize_mask = 0; - // Copy x to contiguous temp buffer (handles custom stride correctly) - size_t dtype_sz = aclDataTypeSize(_opaque->temp->dataType); - size_t row_bytes = _info.total_seq_len * dtype_sz; - for (size_t b = 0; b < _info.batch_size; b++) { - size_t dst_batch_off = b * _info.seq_len * _info.total_seq_len * dtype_sz; - size_t src_batch_off = b * _info.x_stride_b * dtype_sz; - for (size_t i = 0; i < _info.seq_len; i++) { - aclrtMemcpy( - (char *)_opaque->temp_addr + dst_batch_off + i * _info.total_seq_len * dtype_sz, - row_bytes, - (const char *)x + src_batch_off + i * _info.x_stride_i * dtype_sz, - row_bytes, - ACL_MEMCPY_DEVICE_TO_DEVICE); - } + if (_opaque->use_temp) { + auto ttemp_x = _opaque->temp_x->tensor; + auto ttemp_y = _opaque->temp_y->tensor; + void *temp_x = _opaque->temp_x_addr; + void *temp_y = _opaque->temp_y_addr; + + AclSetTensorAddr(_opaque->copy_in_executor, 0, ttemp_x, temp_x); + AclSetTensorAddr(_opaque->copy_in_executor, 1, tx, (void *)x); + CHECK_ACL(aclnnInplaceCopy(workspace, _opaque->workspacesize, _opaque->copy_in_executor, stream)); + + aclOpExecutor *mask_executor = nullptr; + size_t workspacesize_mask = 0; + AclSetTensorAddr(mask_executor, 0, ttemp_x, temp_x); + AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr); + AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr); + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp_x, tmask, tvalue, &workspacesize_mask, &mask_executor)); + CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, _opaque->workspacesize, mask_executor, stream)); + + AclSetTensorAddr(_opaque->temp_executor, 0, ttemp_x, temp_x); + AclSetTensorAddr(_opaque->temp_executor, 1, ttemp_y, temp_y); + CHECK_ACL(aclnnSoftmax(workspace, _opaque->workspacesize, _opaque->temp_executor, stream)); + + AclSetTensorAddr(_opaque->copy_out_executor, 0, ty, y); + AclSetTensorAddr(_opaque->copy_out_executor, 1, ttemp_y, temp_y); + CHECK_ACL(aclnnInplaceCopy(workspace, _opaque->workspacesize, _opaque->copy_out_executor, stream)); + return INFINI_STATUS_SUCCESS; } - // Masked fill on temp (contiguous, no stride issues) - AclSetTensorAddr(mask_executor, 0, ttemp, _opaque->temp_addr); + aclOpExecutor *mask_executor = nullptr; + size_t workspacesize_mask = 0; + + AclSetTensorAddr(mask_executor, 0, tx, (void *)x); AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr); AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr); - CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp, tmask, tvalue, &workspacesize_mask, &mask_executor)); - CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, workspacesize_mask, mask_executor, stream)); + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); + CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, _opaque->workspacesize, mask_executor, stream)); - // Softmax temp (contiguous) → y - AclSetTensorAddr(_opaque->executor, 0, ttemp, _opaque->temp_addr); + AclSetTensorAddr(_opaque->executor, 0, tx, (void *)x); AclSetTensorAddr(_opaque->executor, 1, ty, y); CHECK_ACL(aclnnSoftmax(workspace, _opaque->workspacesize, _opaque->executor, stream));