From 6d81876053ba546bad8af605e10d1f3d300407ca Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 11:22:20 +0800 Subject: [PATCH 1/9] Harden GPU MPI staging helpers --- source/source_base/parallel_device.cpp | 13 +++++++++++-- source/source_base/parallel_device.h | 21 +++++++++++---------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index 933064e2486..887dde3bebf 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -104,7 +104,7 @@ template struct object_cpu_point { bool alloc = false; - T* get(const T* object, const int& n, T* tmp_space = nullptr) + T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr) { T* object_cpu = nullptr; alloc = false; @@ -118,6 +118,11 @@ struct object_cpu_point { object_cpu = tmp_space; } + return object_cpu; + } + T* get(const T* object, const int& n, T* tmp_space = nullptr) + { + T* object_cpu = get_buffer(object, n, tmp_space); base_device::memory::synchronize_memory_op()(object_cpu, object, n); @@ -149,6 +154,10 @@ template struct object_cpu_point { bool alloc = false; + T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr) + { + return const_cast(object); + } T* get(const T* object, const int& n, T* tmp_space = nullptr) { return const_cast(object); @@ -175,4 +184,4 @@ template struct object_cpu_point, base_device::DEVICE_GPU>; #endif } // namespace Parallel_Common -#endif \ No newline at end of file +#endif diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h index 7293b375d74..40e325d8ac0 100644 --- a/source/source_base/parallel_device.h +++ b/source/source_base/parallel_device.h @@ -37,6 +37,7 @@ template struct object_cpu_point { bool alloc = false; + T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr); T* get(const T* object, const int& n, T* tmp_space = nullptr); void del(T* object); void sync_d2h(T* object_cpu, const T* object, const int& n); @@ -56,7 +57,6 @@ void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T* #else object_cpu_point o; T* object_cpu = o.get(object, count, tmp_space); - o.sync_d2h(object_cpu, object, count); send_data(object_cpu, count, dest, tag, comm); o.del(object_cpu); #endif @@ -76,7 +76,6 @@ void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MP #else object_cpu_point o; T* object_cpu = o.get(object, count, send_space); - o.sync_d2h(object_cpu, object, count); isend_data(object_cpu, count, dest, tag, comm, request); o.del(object_cpu); #endif @@ -94,7 +93,7 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta recv_data(object, count, source, tag, comm, status); #else object_cpu_point o; - T* object_cpu = o.get(object, count, tmp_space); + T* object_cpu = o.get_buffer(object, count, tmp_space); recv_data(object_cpu, count, source, tag, comm, status); o.sync_h2d(object, object_cpu, count); o.del(object_cpu); @@ -120,10 +119,14 @@ void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nul bcast_data(object, n, comm); #else object_cpu_point o; - T* object_cpu = o.get(object, n, tmp_space); - o.sync_d2h(object_cpu, object, n); + int rank = 0; + MPI_Comm_rank(comm, &rank); + T* object_cpu = rank == 0 ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space); bcast_data(object_cpu, n, comm); - o.sync_h2d(object, object_cpu, n); + if (rank != 0) + { + o.sync_h2d(object, object_cpu, n); + } o.del(object_cpu); #endif return; @@ -137,7 +140,6 @@ void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nu #else object_cpu_point o; T* object_cpu = o.get(object, n, tmp_space); - o.sync_d2h(object_cpu, object, n); reduce_data(object_cpu, n, comm); o.sync_h2d(object, object_cpu, n); o.del(object_cpu); @@ -163,8 +165,7 @@ void gatherv_dev(const T* sendbuf, MPI_Comm_size(comm, &size); int gather_space = displs[size - 1] + recvcounts[size - 1]; T* sendbuf_cpu = o1.get(sendbuf, sendcount, tmp_sspace); - T* recvbuf_cpu = o2.get(recvbuf, gather_space, tmp_rspace); - o1.sync_d2h(sendbuf_cpu, sendbuf, sendcount); + T* recvbuf_cpu = o2.get_buffer(recvbuf, gather_space, tmp_rspace); gatherv_data(sendbuf_cpu, sendcount, recvbuf_cpu, recvcounts, displs, comm); o2.sync_h2d(recvbuf, recvbuf_cpu, gather_space); o1.del(sendbuf_cpu); @@ -177,4 +178,4 @@ void gatherv_dev(const T* sendbuf, #endif -#endif \ No newline at end of file +#endif From 2a07d9b679ce2b15c687ec442200512e4fee22cb Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 11:32:50 +0800 Subject: [PATCH 2/9] Add NCCL collectives for parallel_device --- CMakeLists.txt | 14 ++ cmake/SetupNccl.cmake | 27 +++ source/source_base/parallel_device.cpp | 243 +++++++++++++++++++++++++ source/source_base/parallel_device.h | 37 ++++ 4 files changed, 321 insertions(+) create mode 100644 cmake/SetupNccl.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 6dc0bf4b5f8..02c22b87d99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,7 @@ option(ENABLE_GOOGLEBENCH "Enable GOOGLE-benchmark usage" OFF) option(ENABLE_RAPIDJSON "Enable rapid-json usage" OFF) option(ENABLE_CNPY "Enable cnpy usage" OFF) option(ENABLE_CUSOLVERMP "Enable cusolvermp" OFF) +option(ENABLE_NCCL_PARALLEL_DEVICE "Enable NCCL-backed collectives in parallel_device" OFF) if(NOT DEFINED NVHPC_ROOT_DIR AND DEFINED ENV{NVHPC_ROOT}) set(NVHPC_ROOT_DIR @@ -451,6 +452,19 @@ if(USE_CUDA) if (USE_OPENMP AND OpenMP_CXX_FOUND) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=${OpenMP_CXX_FLAGS}" CACHE STRING "CUDA flags" FORCE) endif() + if (ENABLE_NCCL_PARALLEL_DEVICE) + if (NOT ENABLE_MPI) + message(FATAL_ERROR + "ENABLE_NCCL_PARALLEL_DEVICE requires ENABLE_MPI=ON.") + endif() + if (NOT USE_CUDA_MPI) + message(FATAL_ERROR + "ENABLE_NCCL_PARALLEL_DEVICE requires USE_CUDA_MPI=ON.") + endif() + add_compile_definitions(__NCCL_PARALLEL_DEVICE) + include(cmake/SetupNccl.cmake) + abacus_setup_nccl(${ABACUS_BIN_NAME}) + endif() if (ENABLE_CUSOLVERMP) # Keep cuSOLVERMp discovery/linking logic in a dedicated module. include(cmake/SetupCuSolverMp.cmake) diff --git a/cmake/SetupNccl.cmake b/cmake/SetupNccl.cmake new file mode 100644 index 00000000000..0ecf513b589 --- /dev/null +++ b/cmake/SetupNccl.cmake @@ -0,0 +1,27 @@ +include_guard(GLOBAL) + +function(abacus_setup_nccl target_name) + find_library(NCCL_LIBRARY NAMES nccl + HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR} + PATH_SUFFIXES lib lib64 comm_libs/nccl/lib) + find_path(NCCL_INCLUDE_DIR NAMES nccl.h + HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR} + PATHS ${CUDAToolkit_ROOT} + PATH_SUFFIXES include comm_libs/nccl/include) + + if(NOT NCCL_LIBRARY OR NOT NCCL_INCLUDE_DIR) + message(FATAL_ERROR + "NCCL not found. Set NCCL_PATH or NVHPC_ROOT_DIR.") + endif() + + message(STATUS "Found NCCL for parallel_device: ${NCCL_LIBRARY}") + if(NOT TARGET NCCL::NCCL) + add_library(NCCL::NCCL IMPORTED INTERFACE) + set_target_properties(NCCL::NCCL PROPERTIES + INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}") + endif() + + include_directories(${NCCL_INCLUDE_DIR}) + target_link_libraries(${target_name} NCCL::NCCL) +endfunction() diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index 887dde3bebf..10c5dc489e8 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -1,7 +1,250 @@ #include "parallel_device.h" + +#if defined(__MPI) && defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +#include "source_base/module_device/device_check.h" + +#include +#include +#include +#endif + #ifdef __MPI namespace Parallel_Common { +#if defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +namespace +{ +struct NcclCommContext +{ + ncclComm_t comm = nullptr; + cudaStream_t stream = nullptr; + int size = 0; +}; + +class NcclCommRegistry +{ + public: + ~NcclCommRegistry() + { + for (std::map::iterator it = contexts_.begin(); it != contexts_.end(); ++it) + { + if (it->second.stream != nullptr) + { + cudaStreamDestroy(it->second.stream); + } + if (it->second.comm != nullptr) + { + ncclCommDestroy(it->second.comm); + } + } + } + + NcclCommContext& get(MPI_Comm comm) + { + const MPI_Fint key = MPI_Comm_c2f(comm); + std::lock_guard lock(mutex_); + std::map::iterator found = contexts_.find(key); + if (found != contexts_.end()) + { + return found->second; + } + + int rank = 0; + int size = 0; + MPI_Comm_rank(comm, &rank); + MPI_Comm_size(comm, &size); + + NcclCommContext ctx; + ctx.size = size; + if (size > 1) + { + ncclUniqueId id; + if (rank == 0) + { + CHECK_NCCL(ncclGetUniqueId(&id)); + } + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, comm); + CHECK_NCCL(ncclCommInitRank(&ctx.comm, size, id, rank)); + CHECK_CUDA(cudaStreamCreateWithFlags(&ctx.stream, cudaStreamNonBlocking)); + } + + std::pair::iterator, bool> inserted = contexts_.insert(std::make_pair(key, ctx)); + return inserted.first->second; + } + + private: + std::map contexts_; + std::mutex mutex_; +}; + +NcclCommRegistry& get_nccl_registry() +{ + static NcclCommRegistry registry; + return registry; +} + +template +void nccl_bcast_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, const int count_scale = 1) +{ + NcclCommContext& ctx = get_nccl_registry().get(comm); + if (ctx.size <= 1 || n <= 0) + { + return; + } + CHECK_NCCL(ncclBroadcast(object, object, static_cast(n) * count_scale, datatype, 0, ctx.comm, ctx.stream)); + CHECK_CUDA(cudaStreamSynchronize(ctx.stream)); +} + +template +void nccl_reduce_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, const int count_scale = 1) +{ + NcclCommContext& ctx = get_nccl_registry().get(comm); + if (ctx.size <= 1 || n <= 0) + { + return; + } + CHECK_NCCL(ncclAllReduce(object, object, static_cast(n) * count_scale, datatype, ncclSum, ctx.comm, ctx.stream)); + CHECK_CUDA(cudaStreamSynchronize(ctx.stream)); +} + +template +void nccl_gatherv_impl(const T* sendbuf, + const int sendcount, + T* recvbuf, + const int* recvcounts, + const int* displs, + MPI_Comm& comm) +{ + NcclCommContext& ctx = get_nccl_registry().get(comm); + if (ctx.size <= 1) + { + if (sendbuf != recvbuf && sendcount > 0) + { + CHECK_CUDA(cudaMemcpy(recvbuf, sendbuf, static_cast(sendcount) * sizeof(T), cudaMemcpyDeviceToDevice)); + } + return; + } + + int chunk_count = 0; + for (int i = 0; i < ctx.size; ++i) + { + if (recvcounts[i] > chunk_count) + { + chunk_count = recvcounts[i]; + } + } + if (chunk_count <= 0) + { + return; + } + + const size_t chunk_bytes = static_cast(chunk_count) * sizeof(T); + const size_t recv_bytes = chunk_bytes * ctx.size; + unsigned char* staged_send = nullptr; + unsigned char* staged_recv = nullptr; + + CHECK_CUDA(cudaMalloc(&staged_send, chunk_bytes)); + CHECK_CUDA(cudaMalloc(&staged_recv, recv_bytes)); + if (sendcount > 0) + { + CHECK_CUDA(cudaMemcpyAsync(staged_send, + sendbuf, + static_cast(sendcount) * sizeof(T), + cudaMemcpyDeviceToDevice, + ctx.stream)); + } + + CHECK_NCCL(ncclAllGather(staged_send, staged_recv, chunk_bytes, ncclUint8, ctx.comm, ctx.stream)); + + for (int i = 0; i < ctx.size; ++i) + { + if (recvcounts[i] > 0) + { + CHECK_CUDA(cudaMemcpyAsync(recvbuf + displs[i], + staged_recv + static_cast(i) * chunk_bytes, + static_cast(recvcounts[i]) * sizeof(T), + cudaMemcpyDeviceToDevice, + ctx.stream)); + } + } + + CHECK_CUDA(cudaStreamSynchronize(ctx.stream)); + CHECK_CUDA(cudaFree(staged_send)); + CHECK_CUDA(cudaFree(staged_recv)); +} +} // namespace + +void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm) +{ + nccl_bcast_impl(object, n, comm, ncclDouble); +} + +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm) +{ + nccl_bcast_impl(reinterpret_cast(object), n, comm, ncclDouble, 2); +} + +void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm) +{ + nccl_bcast_impl(object, n, comm, ncclFloat); +} + +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm) +{ + nccl_bcast_impl(reinterpret_cast(object), n, comm, ncclFloat, 2); +} + +void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm) +{ + nccl_reduce_impl(object, n, comm, ncclDouble); +} + +void nccl_reduce_data(std::complex* object, const int& n, MPI_Comm& comm) +{ + nccl_reduce_impl(reinterpret_cast(object), n, comm, ncclDouble, 2); +} + +void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm) +{ + nccl_reduce_impl(object, n, comm, ncclFloat); +} + +void nccl_reduce_data(std::complex* object, const int& n, MPI_Comm& comm) +{ + nccl_reduce_impl(reinterpret_cast(object), n, comm, ncclFloat, 2); +} + +void nccl_gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); +} + +void nccl_gatherv_data(const std::complex* sendbuf, + int sendcount, + std::complex* recvbuf, + const int* recvcounts, + const int* displs, + MPI_Comm& comm) +{ + nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); +} + +void nccl_gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); +} + +void nccl_gatherv_data(const std::complex* sendbuf, + int sendcount, + std::complex* recvbuf, + const int* recvcounts, + const int* displs, + MPI_Comm& comm) +{ + nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); +} +#endif + void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) { MPI_Isend(buf, count, MPI_DOUBLE, dest, tag, comm, request); diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h index 40e325d8ac0..46beb7080b0 100644 --- a/source/source_base/parallel_device.h +++ b/source/source_base/parallel_device.h @@ -5,6 +5,7 @@ #include "source_base/module_device/device.h" #include "source_base/module_device/memory_op.h" #include +#include namespace Parallel_Common { void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); @@ -32,6 +33,21 @@ void gatherv_data(const std::complex* sendbuf, int sendcount, std::compl void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +#if defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm); +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm); +void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm); +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm); +void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm); +void nccl_reduce_data(std::complex* object, const int& n, MPI_Comm& comm); +void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm); +void nccl_reduce_data(std::complex* object, const int& n, MPI_Comm& comm); +void nccl_gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void nccl_gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void nccl_gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void nccl_gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +#endif + #ifndef __CUDA_MPI template struct object_cpu_point @@ -116,6 +132,13 @@ template void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { #ifdef __CUDA_MPI +#if defined(__NCCL_PARALLEL_DEVICE) + if (std::is_same::value) + { + nccl_bcast_data(object, n, const_cast(comm)); + return; + } +#endif bcast_data(object, n, comm); #else object_cpu_point o; @@ -136,6 +159,13 @@ template void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { #ifdef __CUDA_MPI +#if defined(__NCCL_PARALLEL_DEVICE) + if (std::is_same::value) + { + nccl_reduce_data(object, n, const_cast(comm)); + return; + } +#endif reduce_data(object, n, comm); #else object_cpu_point o; @@ -158,6 +188,13 @@ void gatherv_dev(const T* sendbuf, T* tmp_rspace = nullptr) { #ifdef __CUDA_MPI +#if defined(__NCCL_PARALLEL_DEVICE) + if (std::is_same::value) + { + nccl_gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); + return; + } +#endif gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); #else object_cpu_point o1, o2; From 5f1cb8131452bcb1f8efafcfb6c37ed87cc2f694 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 11:43:07 +0800 Subject: [PATCH 3/9] Fix NCCL headers in parallel_device --- source/source_base/parallel_device.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index 10c5dc489e8..6badd60bd12 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -4,8 +4,27 @@ #include "source_base/module_device/device_check.h" #include +#include + #include #include + +#include +#include + +#ifndef CHECK_NCCL +#define CHECK_NCCL(func) \ + do \ + { \ + ncclResult_t status = (func); \ + if (status != ncclSuccess) \ + { \ + fprintf(stderr, "In File %s : NCCL API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + ncclGetErrorString(status), status); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) +#endif #endif #ifdef __MPI From 356fb280373adf39ccd61a950182d22383045430 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 11:50:56 +0800 Subject: [PATCH 4/9] Route PGemm collectives through device wrappers --- source/source_base/para_gemm.cpp | 41 ++++++++++++-------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/source/source_base/para_gemm.cpp b/source/source_base/para_gemm.cpp index edb798554cc..40913994eb4 100644 --- a/source/source_base/para_gemm.cpp +++ b/source/source_base/para_gemm.cpp @@ -277,38 +277,27 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con if (this->gatherC) { -#ifdef __CUDA_MPI - T* Clocal_mpi = C_local; - T* Cglobal_mpi = C; -#else - T* Clocal_mpi = C_tmp_.data(); - T* Cglobal_mpi = nullptr; + T* reduce_tmp = nullptr; + T* gather_tmp = nullptr; +#ifndef __CUDA_MPI if (std::is_same::value) { - syncmem_d2h_op()(Clocal_mpi, C_local, size_C_local); - Cglobal_mpi = C_global_tmp_.data(); - } - else - { - Cglobal_mpi = C; + reduce_tmp = C_tmp_.data(); + gather_tmp = C_global_tmp_.data(); } #endif if (this->row_nproc > 1) { - Parallel_Common::reduce_data(Clocal_mpi, size_C_local, row_world); + Parallel_Common::reduce_dev(C_local, size_C_local, row_world, reduce_tmp); } - Parallel_Common::gatherv_data(Clocal_mpi, - size_C_local, - Cglobal_mpi, - recv_counts.data(), - displs.data(), - col_world); -#ifndef __CUDA_MPI - if (std::is_same::value) - { - syncmem_h2d_op()(C, Cglobal_mpi, size_C_global); - } -#endif + Parallel_Common::gatherv_dev(C_local, + size_C_local, + C, + recv_counts.data(), + displs.data(), + col_world, + reduce_tmp, + gather_tmp); } else { @@ -409,4 +398,4 @@ template class PGemmCN, base_device::DEVICE_GPU>; template class PGemmCN, base_device::DEVICE_GPU>; #endif -} // namespace ModuleBase \ No newline at end of file +} // namespace ModuleBase From d7e8334157a19680aab6214fd5de789ba6ae7409 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 12:01:34 +0800 Subject: [PATCH 5/9] Tighten NCCL collective correctness --- source/source_base/parallel_device.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index 6badd60bd12..ab867ce9318 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -11,6 +11,7 @@ #include #include +#include #ifndef CHECK_NCCL #define CHECK_NCCL(func) \ @@ -47,10 +48,6 @@ class NcclCommRegistry { for (std::map::iterator it = contexts_.begin(); it != contexts_.end(); ++it) { - if (it->second.stream != nullptr) - { - cudaStreamDestroy(it->second.stream); - } if (it->second.comm != nullptr) { ncclCommDestroy(it->second.comm); @@ -84,7 +81,6 @@ class NcclCommRegistry } MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, comm); CHECK_NCCL(ncclCommInitRank(&ctx.comm, size, id, rank)); - CHECK_CUDA(cudaStreamCreateWithFlags(&ctx.stream, cudaStreamNonBlocking)); } std::pair::iterator, bool> inserted = contexts_.insert(std::make_pair(key, ctx)); @@ -145,6 +141,8 @@ void nccl_gatherv_impl(const T* sendbuf, } int chunk_count = 0; + int rank = 0; + MPI_Comm_rank(comm, &rank); for (int i = 0; i < ctx.size; ++i) { if (recvcounts[i] > chunk_count) @@ -152,6 +150,10 @@ void nccl_gatherv_impl(const T* sendbuf, chunk_count = recvcounts[i]; } } + if (recvcounts[rank] != sendcount) + { + throw std::runtime_error("nccl_gatherv_data: sendcount does not match recvcounts[rank]"); + } if (chunk_count <= 0) { return; @@ -164,6 +166,7 @@ void nccl_gatherv_impl(const T* sendbuf, CHECK_CUDA(cudaMalloc(&staged_send, chunk_bytes)); CHECK_CUDA(cudaMalloc(&staged_recv, recv_bytes)); + CHECK_CUDA(cudaMemsetAsync(staged_send, 0, chunk_bytes, ctx.stream)); if (sendcount > 0) { CHECK_CUDA(cudaMemcpyAsync(staged_send, From d1c59795691b61714aafd23b96e85a242b30acb0 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 12:18:03 +0800 Subject: [PATCH 6/9] Relax NCCL discovery for existing environments --- cmake/SetupNccl.cmake | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/cmake/SetupNccl.cmake b/cmake/SetupNccl.cmake index 0ecf513b589..31f2cb75c6c 100644 --- a/cmake/SetupNccl.cmake +++ b/cmake/SetupNccl.cmake @@ -1,5 +1,7 @@ include_guard(GLOBAL) +include(CheckIncludeFileCXX) + function(abacus_setup_nccl target_name) find_library(NCCL_LIBRARY NAMES nccl HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR} @@ -9,19 +11,36 @@ function(abacus_setup_nccl target_name) PATHS ${CUDAToolkit_ROOT} PATH_SUFFIXES include comm_libs/nccl/include) - if(NOT NCCL_LIBRARY OR NOT NCCL_INCLUDE_DIR) + check_include_file_cxx("nccl.h" HAVE_NCCL_HEADER) + + if(NOT NCCL_LIBRARY) + set(NCCL_LIBRARY nccl) + endif() + + if(NOT NCCL_INCLUDE_DIR AND NOT HAVE_NCCL_HEADER) message(FATAL_ERROR "NCCL not found. Set NCCL_PATH or NVHPC_ROOT_DIR.") endif() - message(STATUS "Found NCCL for parallel_device: ${NCCL_LIBRARY}") + if(NCCL_INCLUDE_DIR) + message(STATUS "Found NCCL for parallel_device: ${NCCL_LIBRARY}") + else() + message(STATUS "Using default compiler/linker search paths for NCCL: ${NCCL_LIBRARY}") + endif() if(NOT TARGET NCCL::NCCL) add_library(NCCL::NCCL IMPORTED INTERFACE) - set_target_properties(NCCL::NCCL PROPERTIES - INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}" - INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}") + if(NCCL_INCLUDE_DIR) + set_target_properties(NCCL::NCCL PROPERTIES + INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}") + else() + set_target_properties(NCCL::NCCL PROPERTIES + INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}") + endif() endif() - include_directories(${NCCL_INCLUDE_DIR}) + if(NCCL_INCLUDE_DIR) + target_include_directories(${target_name} PRIVATE ${NCCL_INCLUDE_DIR}) + endif() target_link_libraries(${target_name} NCCL::NCCL) endfunction() From bb01cefc26e66ce796e3ed2f52fa7ac6d1e23f33 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 17:46:49 +0800 Subject: [PATCH 7/9] Decouple NCCL parallel_device from CUDA-aware MPI --- CMakeLists.txt | 4 ---- source/source_base/para_gemm.cpp | 6 +++--- source/source_base/parallel_device.cpp | 4 ++-- source/source_base/parallel_device.h | 8 ++++---- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 02c22b87d99..707f8caaca3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -457,10 +457,6 @@ if(USE_CUDA) message(FATAL_ERROR "ENABLE_NCCL_PARALLEL_DEVICE requires ENABLE_MPI=ON.") endif() - if (NOT USE_CUDA_MPI) - message(FATAL_ERROR - "ENABLE_NCCL_PARALLEL_DEVICE requires USE_CUDA_MPI=ON.") - endif() add_compile_definitions(__NCCL_PARALLEL_DEVICE) include(cmake/SetupNccl.cmake) abacus_setup_nccl(${ABACUS_BIN_NAME}) diff --git a/source/source_base/para_gemm.cpp b/source/source_base/para_gemm.cpp index 40913994eb4..36d99d53b80 100644 --- a/source/source_base/para_gemm.cpp +++ b/source/source_base/para_gemm.cpp @@ -105,7 +105,7 @@ void PGemmCN::set_dimension( if (std::is_same::value) { resmem_dev_op()(A_tmp_device_, max_colA * LDA); -#ifndef __CUDA_MPI +#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) isend_tmp_.resize(max_colA * LDA); #endif } @@ -133,7 +133,7 @@ void PGemmCN::set_dimension( if (std::is_same::value) { resmem_dev_op()(C_local_tmp_, size_C_local); -#ifndef __CUDA_MPI +#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) C_global_tmp_.resize(size_C_global); #endif } @@ -279,7 +279,7 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con { T* reduce_tmp = nullptr; T* gather_tmp = nullptr; -#ifndef __CUDA_MPI +#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { reduce_tmp = C_tmp_.data(); diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index ab867ce9318..d8e9b690903 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -1,6 +1,6 @@ #include "parallel_device.h" -#if defined(__MPI) && defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +#if defined(__MPI) && defined(__NCCL_PARALLEL_DEVICE) #include "source_base/module_device/device_check.h" #include @@ -31,7 +31,7 @@ #ifdef __MPI namespace Parallel_Common { -#if defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +#if defined(__NCCL_PARALLEL_DEVICE) namespace { struct NcclCommContext diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h index 46beb7080b0..a148226e9fd 100644 --- a/source/source_base/parallel_device.h +++ b/source/source_base/parallel_device.h @@ -33,7 +33,7 @@ void gatherv_data(const std::complex* sendbuf, int sendcount, std::compl void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); -#if defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +#if defined(__NCCL_PARALLEL_DEVICE) void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm); void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm); void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm); @@ -131,7 +131,6 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta template void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { -#ifdef __CUDA_MPI #if defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { @@ -139,6 +138,7 @@ void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nul return; } #endif +#ifdef __CUDA_MPI bcast_data(object, n, comm); #else object_cpu_point o; @@ -158,7 +158,6 @@ void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nul template void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { -#ifdef __CUDA_MPI #if defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { @@ -166,6 +165,7 @@ void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nu return; } #endif +#ifdef __CUDA_MPI reduce_data(object, n, comm); #else object_cpu_point o; @@ -187,7 +187,6 @@ void gatherv_dev(const T* sendbuf, T* tmp_sspace = nullptr, T* tmp_rspace = nullptr) { -#ifdef __CUDA_MPI #if defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { @@ -195,6 +194,7 @@ void gatherv_dev(const T* sendbuf, return; } #endif +#ifdef __CUDA_MPI gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); #else object_cpu_point o1, o2; From aaaaad9e60b8eb41a6ac76966563199474bf85b7 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 18:19:07 +0800 Subject: [PATCH 8/9] Propagate NCCL headers to subdirectory targets --- cmake/SetupNccl.cmake | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmake/SetupNccl.cmake b/cmake/SetupNccl.cmake index 31f2cb75c6c..56e8e10e7b0 100644 --- a/cmake/SetupNccl.cmake +++ b/cmake/SetupNccl.cmake @@ -40,6 +40,9 @@ function(abacus_setup_nccl target_name) endif() if(NCCL_INCLUDE_DIR) + # `parallel_device.cpp` is compiled inside the later `base` OBJECT library, + # so the header path must also be visible to targets created in subdirs. + include_directories(${NCCL_INCLUDE_DIR}) target_include_directories(${target_name} PRIVATE ${NCCL_INCLUDE_DIR}) endif() target_link_libraries(${target_name} NCCL::NCCL) From 7c34f158f3bd998c9f05e4a76923ad8a80e40cde Mon Sep 17 00:00:00 2001 From: someone Date: Mon, 4 May 2026 12:19:53 +0800 Subject: [PATCH 9/9] Fix: narrow CPU staging guards in para_gemm to respect NCCL collectives MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit isend_dev has no NCCL path — keep guard as #ifndef __CUDA_MPI. reduce_dev / gatherv_dev have NCCL early-returns — exclude CPU staging when __NCCL_PARALLEL_DEVICE is defined (&& !defined). --- source/source_base/para_gemm.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/source_base/para_gemm.cpp b/source/source_base/para_gemm.cpp index 36d99d53b80..3e56aa83ac2 100644 --- a/source/source_base/para_gemm.cpp +++ b/source/source_base/para_gemm.cpp @@ -105,7 +105,7 @@ void PGemmCN::set_dimension( if (std::is_same::value) { resmem_dev_op()(A_tmp_device_, max_colA * LDA); -#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) +#ifndef __CUDA_MPI isend_tmp_.resize(max_colA * LDA); #endif } @@ -133,7 +133,7 @@ void PGemmCN::set_dimension( if (std::is_same::value) { resmem_dev_op()(C_local_tmp_, size_C_local); -#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) +#if !defined(__CUDA_MPI) && !defined(__NCCL_PARALLEL_DEVICE) C_global_tmp_.resize(size_C_global); #endif } @@ -279,7 +279,7 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con { T* reduce_tmp = nullptr; T* gather_tmp = nullptr; -#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) +#if !defined(__CUDA_MPI) && !defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { reduce_tmp = C_tmp_.data();