diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6dc0bf4b5f8..707f8caaca3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -51,6 +51,7 @@ option(ENABLE_GOOGLEBENCH "Enable GOOGLE-benchmark usage" OFF)
 option(ENABLE_RAPIDJSON "Enable rapid-json usage" OFF)
 option(ENABLE_CNPY "Enable cnpy usage" OFF)
 option(ENABLE_CUSOLVERMP "Enable cusolvermp" OFF)
+option(ENABLE_NCCL_PARALLEL_DEVICE "Enable NCCL-backed collectives in parallel_device" OFF)
 
 if(NOT DEFINED NVHPC_ROOT_DIR AND DEFINED ENV{NVHPC_ROOT})
   set(NVHPC_ROOT_DIR
@@ -451,6 +452,15 @@ if(USE_CUDA)
   if (USE_OPENMP AND OpenMP_CXX_FOUND)
     set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=${OpenMP_CXX_FLAGS}" CACHE STRING "CUDA flags" FORCE)
   endif()
+  if (ENABLE_NCCL_PARALLEL_DEVICE)
+    if (NOT ENABLE_MPI)
+      message(FATAL_ERROR
+        "ENABLE_NCCL_PARALLEL_DEVICE requires ENABLE_MPI=ON.")
+    endif()
+    add_compile_definitions(__NCCL_PARALLEL_DEVICE)
+    include(cmake/SetupNccl.cmake)
+    abacus_setup_nccl(${ABACUS_BIN_NAME})
+  endif()
   if (ENABLE_CUSOLVERMP)
     # Keep cuSOLVERMp discovery/linking logic in a dedicated module.
     include(cmake/SetupCuSolverMp.cmake)
diff --git a/cmake/SetupNccl.cmake b/cmake/SetupNccl.cmake
new file mode 100644
index 00000000000..56e8e10e7b0
--- /dev/null
+++ b/cmake/SetupNccl.cmake
@@ -0,0 +1,49 @@
+include_guard(GLOBAL)
+
+include(CheckIncludeFileCXX)
+
+function(abacus_setup_nccl target_name)
+  find_library(NCCL_LIBRARY NAMES nccl
+    HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR}
+    PATH_SUFFIXES lib lib64 comm_libs/nccl/lib)
+  find_path(NCCL_INCLUDE_DIR NAMES nccl.h
+    HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR}
+    PATHS ${CUDAToolkit_ROOT}
+    PATH_SUFFIXES include comm_libs/nccl/include)
+
+  check_include_file_cxx("nccl.h" HAVE_NCCL_HEADER)
+
+  if(NOT NCCL_LIBRARY)
+    set(NCCL_LIBRARY nccl)
+  endif()
+
+  if(NOT NCCL_INCLUDE_DIR AND NOT HAVE_NCCL_HEADER)
+    message(FATAL_ERROR
+      "NCCL not found. Set NCCL_PATH or NVHPC_ROOT_DIR.")
+  endif()
+
+  if(NCCL_INCLUDE_DIR)
+    message(STATUS "Found NCCL for parallel_device: ${NCCL_LIBRARY}")
+  else()
+    message(STATUS "Using default compiler/linker search paths for NCCL: ${NCCL_LIBRARY}")
+  endif()
+  if(NOT TARGET NCCL::NCCL)
+    add_library(NCCL::NCCL INTERFACE IMPORTED)
+    if(NCCL_INCLUDE_DIR)
+      set_target_properties(NCCL::NCCL PROPERTIES
+        INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}"
+        INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}")
+    else()
+      set_target_properties(NCCL::NCCL PROPERTIES
+        INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}")
+    endif()
+  endif()
+
+  if(NCCL_INCLUDE_DIR)
+    # `parallel_device.cpp` is compiled inside the later `base` OBJECT library,
+    # so the header path must also be visible to targets created in subdirectories.
+    include_directories(${NCCL_INCLUDE_DIR})
+    target_include_directories(${target_name} PRIVATE ${NCCL_INCLUDE_DIR})
+  endif()
+  target_link_libraries(${target_name} NCCL::NCCL)
+endfunction()
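A quick way to validate the discovery logic above, especially on the fallback path where NCCL is taken from the default compiler/linker search paths, is to compare the header the build compiled against with the library it actually linked. The sketch below is an illustration, not part of the patch; it uses only the public NCCL API (ncclGetVersion and the NCCL_VERSION_CODE macro from nccl.h):

    #include <nccl.h>
    #include <cstdio>

    // Sanity check: the NCCL header seen at compile time (NCCL_VERSION_CODE)
    // should match the library resolved at link/run time (ncclGetVersion).
    // A mismatch usually means NCCL_PATH or NVHPC_ROOT_DIR points at one
    // NCCL while the default linker paths resolve another.
    int main()
    {
        int runtime_version = 0;
        if (ncclGetVersion(&runtime_version) != ncclSuccess)
        {
            std::fprintf(stderr, "ncclGetVersion failed\n");
            return 1;
        }
        std::printf("compiled %d, linked %d\n", NCCL_VERSION_CODE, runtime_version);
        return runtime_version == NCCL_VERSION_CODE ? 0 : 1;
    }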
diff --git a/source/source_base/para_gemm.cpp b/source/source_base/para_gemm.cpp
index edb798554cc..36d99d53b80 100644
--- a/source/source_base/para_gemm.cpp
+++ b/source/source_base/para_gemm.cpp
@@ -105,7 +105,7 @@ void PGemmCN<T, Device>::set_dimension(
     if (std::is_same<Device, base_device::DEVICE_GPU>::value)
     {
         resmem_dev_op()(A_tmp_device_, max_colA * LDA);
-#ifndef __CUDA_MPI
+#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE)
         isend_tmp_.resize(max_colA * LDA);
 #endif
     }
@@ -133,7 +133,7 @@ void PGemmCN<T, Device>::set_dimension(
     if (std::is_same<Device, base_device::DEVICE_GPU>::value)
     {
         resmem_dev_op()(C_local_tmp_, size_C_local);
-#ifndef __CUDA_MPI
+#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE)
         C_global_tmp_.resize(size_C_global);
 #endif
     }
@@ -277,38 +277,27 @@ void PGemmCN<T, Device>::multiply_col(const T alpha, const T* A, const T* B, con
     if (this->gatherC)
     {
-#ifdef __CUDA_MPI
-        T* Clocal_mpi = C_local;
-        T* Cglobal_mpi = C;
-#else
-        T* Clocal_mpi = C_tmp_.data();
-        T* Cglobal_mpi = nullptr;
+        T* reduce_tmp = nullptr;
+        T* gather_tmp = nullptr;
+#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE)
         if (std::is_same<Device, base_device::DEVICE_GPU>::value)
         {
-            syncmem_d2h_op()(Clocal_mpi, C_local, size_C_local);
-            Cglobal_mpi = C_global_tmp_.data();
-        }
-        else
-        {
-            Cglobal_mpi = C;
+            reduce_tmp = C_tmp_.data();
+            gather_tmp = C_global_tmp_.data();
         }
 #endif
         if (this->row_nproc > 1)
         {
-            Parallel_Common::reduce_data(Clocal_mpi, size_C_local, row_world);
+            Parallel_Common::reduce_dev<T, Device>(C_local, size_C_local, row_world, reduce_tmp);
         }
-        Parallel_Common::gatherv_data(Clocal_mpi,
-                                      size_C_local,
-                                      Cglobal_mpi,
-                                      recv_counts.data(),
-                                      displs.data(),
-                                      col_world);
-#ifndef __CUDA_MPI
-        if (std::is_same<Device, base_device::DEVICE_GPU>::value)
-        {
-            syncmem_h2d_op()(C, Cglobal_mpi, size_C_global);
-        }
-#endif
+        Parallel_Common::gatherv_dev<T, Device>(C_local,
+                                                size_C_local,
+                                                C,
+                                                recv_counts.data(),
+                                                displs.data(),
+                                                col_world,
+                                                reduce_tmp,
+                                                gather_tmp);
     }
     else
     {
@@ -409,4 +398,4 @@
 template class PGemmCN<std::complex<float>, base_device::DEVICE_GPU>;
 template class PGemmCN<std::complex<double>, base_device::DEVICE_GPU>;
 #endif
-} // namespace ModuleBase
\ No newline at end of file
+} // namespace ModuleBase
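A note on the para_gemm.cpp hunks above: multiply_col no longer stages C through the host itself. It passes raw device pointers to reduce_dev/gatherv_dev and forwards its preallocated host buffers as optional scratch, so one call site serves the CUDA-aware MPI, NCCL, and host-staged builds. A minimal sketch of the resulting call pattern, with a hypothetical wrapper and illustrative buffer names:

    #include <vector>
    #include <mpi.h>
    #include "source_base/parallel_device.h"

    // Hypothetical wrapper around the new API: reduce a device-resident block
    // in place across a communicator, reusing host scratch for the staged
    // fallback. The scratch is ignored on the __CUDA_MPI and NCCL code paths.
    void reduce_gpu_block(double* d_block, int n, MPI_Comm comm)
    {
        static std::vector<double> scratch; // reused; only touched when staging
        scratch.resize(n);
        Parallel_Common::reduce_dev<double, base_device::DEVICE_GPU>(d_block, n, comm, scratch.data());
    }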
diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp
index 933064e2486..d8e9b690903 100644
--- a/source/source_base/parallel_device.cpp
+++ b/source/source_base/parallel_device.cpp
@@ -1,7 +1,272 @@
 #include "parallel_device.h"
+
+#if defined(__MPI) && defined(__NCCL_PARALLEL_DEVICE)
+#include "source_base/module_device/device_check.h"
+
+#include <cuda_runtime.h>
+#include <nccl.h>
+
+#include <cstdio>
+#include <cstdlib>
+
+#include <map>
+#include <mutex>
+#include <stdexcept>
+
+#ifndef CHECK_NCCL
+#define CHECK_NCCL(func)                                                                                         \
+    do                                                                                                           \
+    {                                                                                                            \
+        ncclResult_t status = (func);                                                                            \
+        if (status != ncclSuccess)                                                                               \
+        {                                                                                                        \
+            fprintf(stderr, "In File %s : NCCL API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \
+                    ncclGetErrorString(status), status);                                                         \
+            exit(EXIT_FAILURE);                                                                                  \
+        }                                                                                                        \
+    } while (0)
+#endif
+#endif
+
 #ifdef __MPI
 namespace Parallel_Common
 {
+#if defined(__NCCL_PARALLEL_DEVICE)
+namespace
+{
+struct NcclCommContext
+{
+    ncclComm_t comm = nullptr;
+    cudaStream_t stream = nullptr;
+    int size = 0;
+};
+
+class NcclCommRegistry
+{
+  public:
+    ~NcclCommRegistry()
+    {
+        for (std::map<MPI_Fint, NcclCommContext>::iterator it = contexts_.begin(); it != contexts_.end(); ++it)
+        {
+            if (it->second.comm != nullptr)
+            {
+                ncclCommDestroy(it->second.comm);
+            }
+        }
+    }
+
+    NcclCommContext& get(MPI_Comm comm)
+    {
+        const MPI_Fint key = MPI_Comm_c2f(comm);
+        std::lock_guard<std::mutex> lock(mutex_);
+        std::map<MPI_Fint, NcclCommContext>::iterator found = contexts_.find(key);
+        if (found != contexts_.end())
+        {
+            return found->second;
+        }
+
+        int rank = 0;
+        int size = 0;
+        MPI_Comm_rank(comm, &rank);
+        MPI_Comm_size(comm, &size);
+
+        NcclCommContext ctx;
+        ctx.size = size;
+        if (size > 1)
+        {
+            ncclUniqueId id;
+            if (rank == 0)
+            {
+                CHECK_NCCL(ncclGetUniqueId(&id));
+            }
+            MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, comm);
+            CHECK_NCCL(ncclCommInitRank(&ctx.comm, size, id, rank));
+        }
+
+        std::pair<std::map<MPI_Fint, NcclCommContext>::iterator, bool> inserted
+            = contexts_.insert(std::make_pair(key, ctx));
+        return inserted.first->second;
+    }
+
+  private:
+    std::map<MPI_Fint, NcclCommContext> contexts_;
+    std::mutex mutex_;
+};
+
+NcclCommRegistry& get_nccl_registry()
+{
+    static NcclCommRegistry registry;
+    return registry;
+}
+
+template <typename T>
+void nccl_bcast_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, const int count_scale = 1)
+{
+    NcclCommContext& ctx = get_nccl_registry().get(comm);
+    if (ctx.size <= 1 || n <= 0)
+    {
+        return;
+    }
+    CHECK_NCCL(ncclBroadcast(object, object, static_cast<size_t>(n) * count_scale, datatype, 0, ctx.comm, ctx.stream));
+    CHECK_CUDA(cudaStreamSynchronize(ctx.stream));
+}
+
+template <typename T>
+void nccl_reduce_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, const int count_scale = 1)
+{
+    NcclCommContext& ctx = get_nccl_registry().get(comm);
+    if (ctx.size <= 1 || n <= 0)
+    {
+        return;
+    }
+    CHECK_NCCL(
+        ncclAllReduce(object, object, static_cast<size_t>(n) * count_scale, datatype, ncclSum, ctx.comm, ctx.stream));
+    CHECK_CUDA(cudaStreamSynchronize(ctx.stream));
+}
+
+template <typename T>
+void nccl_gatherv_impl(const T* sendbuf,
+                       const int sendcount,
+                       T* recvbuf,
+                       const int* recvcounts,
+                       const int* displs,
+                       MPI_Comm& comm)
+{
+    NcclCommContext& ctx = get_nccl_registry().get(comm);
+    if (ctx.size <= 1)
+    {
+        if (sendbuf != recvbuf && sendcount > 0)
+        {
+            CHECK_CUDA(
+                cudaMemcpy(recvbuf, sendbuf, static_cast<size_t>(sendcount) * sizeof(T), cudaMemcpyDeviceToDevice));
+        }
+        return;
+    }
+
+    int chunk_count = 0;
+    int rank = 0;
+    MPI_Comm_rank(comm, &rank);
+    for (int i = 0; i < ctx.size; ++i)
+    {
+        if (recvcounts[i] > chunk_count)
+        {
+            chunk_count = recvcounts[i];
+        }
+    }
+    if (recvcounts[rank] != sendcount)
+    {
+        throw std::runtime_error("nccl_gatherv_data: sendcount does not match recvcounts[rank]");
+    }
+    if (chunk_count <= 0)
+    {
+        return;
+    }
+
+    const size_t chunk_bytes = static_cast<size_t>(chunk_count) * sizeof(T);
+    const size_t recv_bytes = chunk_bytes * ctx.size;
+    unsigned char* staged_send = nullptr;
+    unsigned char* staged_recv = nullptr;
+
+    CHECK_CUDA(cudaMalloc(&staged_send, chunk_bytes));
+    CHECK_CUDA(cudaMalloc(&staged_recv, recv_bytes));
+    CHECK_CUDA(cudaMemsetAsync(staged_send, 0, chunk_bytes, ctx.stream));
+    if (sendcount > 0)
+    {
+        CHECK_CUDA(cudaMemcpyAsync(staged_send,
+                                   sendbuf,
+                                   static_cast<size_t>(sendcount) * sizeof(T),
+                                   cudaMemcpyDeviceToDevice,
+                                   ctx.stream));
+    }
+
+    CHECK_NCCL(ncclAllGather(staged_send, staged_recv, chunk_bytes, ncclUint8, ctx.comm, ctx.stream));
+
+    for (int i = 0; i < ctx.size; ++i)
+    {
+        if (recvcounts[i] > 0)
+        {
+            CHECK_CUDA(cudaMemcpyAsync(recvbuf + displs[i],
+                                       staged_recv + static_cast<size_t>(i) * chunk_bytes,
+                                       static_cast<size_t>(recvcounts[i]) * sizeof(T),
+                                       cudaMemcpyDeviceToDevice,
+                                       ctx.stream));
+        }
+    }
+
+    CHECK_CUDA(cudaStreamSynchronize(ctx.stream));
+    CHECK_CUDA(cudaFree(staged_send));
+    CHECK_CUDA(cudaFree(staged_recv));
+}
+} // namespace
+
+void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm)
+{
+    nccl_bcast_impl(object, n, comm, ncclDouble);
+}
+
+void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm)
+{
+    nccl_bcast_impl(reinterpret_cast<double*>(object), n, comm, ncclDouble, 2);
+}
+
+void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm)
+{
+    nccl_bcast_impl(object, n, comm, ncclFloat);
+}
+
+void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm)
+{
+    nccl_bcast_impl(reinterpret_cast<float*>(object), n, comm, ncclFloat, 2);
+}
+
+void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm)
+{
+    nccl_reduce_impl(object, n, comm, ncclDouble);
+}
+
+void nccl_reduce_data(std::complex<double>* object, const int& n, MPI_Comm& comm)
+{
+    nccl_reduce_impl(reinterpret_cast<double*>(object), n, comm, ncclDouble, 2);
+}
+
+void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm)
+{
+    nccl_reduce_impl(object, n, comm, ncclFloat);
+}
+
+void nccl_reduce_data(std::complex<float>* object, const int& n, MPI_Comm& comm)
+{
+    nccl_reduce_impl(reinterpret_cast<float*>(object), n, comm, ncclFloat, 2);
+}
+
+void nccl_gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm)
+{
+    nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
+}
+
+void nccl_gatherv_data(const std::complex<double>* sendbuf,
+                       int sendcount,
+                       std::complex<double>* recvbuf,
+                       const int* recvcounts,
+                       const int* displs,
+                       MPI_Comm& comm)
+{
+    nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
+}
+
+void nccl_gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm)
+{
+    nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
+}
+
+void nccl_gatherv_data(const std::complex<float>* sendbuf,
+                       int sendcount,
+                       std::complex<float>* recvbuf,
+                       const int* recvcounts,
+                       const int* displs,
+                       MPI_Comm& comm)
+{
+    nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
+}
+#endif
+
 void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request)
 {
     MPI_Isend(buf, count, MPI_DOUBLE, dest, tag, comm, request);
@@ -104,7 +369,7 @@ template <typename T>
 struct object_cpu_point<T, base_device::DEVICE_GPU>
 {
     bool alloc = false;
-    T* get(const T* object, const int& n, T* tmp_space = nullptr)
+    T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr)
     {
         T* object_cpu = nullptr;
         alloc = false;
@@ -118,6 +383,11 @@ struct object_cpu_point<T, base_device::DEVICE_GPU>
         {
             object_cpu = tmp_space;
         }
+        return object_cpu;
+    }
+    T* get(const T* object, const int& n, T* tmp_space = nullptr)
+    {
+        T* object_cpu = get_buffer(object, n, tmp_space);
         base_device::memory::synchronize_memory_op<T, base_device::DEVICE_CPU, base_device::DEVICE_GPU>()(object_cpu,
                                                                                                           object,
                                                                                                           n);
@@ -149,6 +419,10 @@ template <typename T>
 struct object_cpu_point<T, base_device::DEVICE_CPU>
 {
     bool alloc = false;
+    T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr)
+    {
+        return const_cast<T*>(object);
+    }
     T* get(const T* object, const int& n, T* tmp_space = nullptr)
     {
         return const_cast<T*>(object);
@@ -175,4 +449,4 @@
 template struct object_cpu_point<std::complex<double>, base_device::DEVICE_GPU>;
 #endif
 } // namespace Parallel_Common
-#endif
\ No newline at end of file
+#endif
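The gatherv path above emulates MPI_Gatherv with a single ncclAllGather by padding every rank's contribution to the largest recvcount, then compacting the result with device-to-device copies. The temporary device memory therefore scales with nranks * max(recvcounts) rather than sum(recvcounts), which is worth keeping in mind for badly skewed layouts. A small host-side sketch of the overhead, using a hypothetical helper that is not in the patch:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    // Hypothetical helper: bytes of temporary device memory nccl_gatherv_impl
    // allocates (staged_send + staged_recv) for a given recvcounts layout.
    std::size_t gatherv_staging_bytes(const int* recvcounts, int nranks, std::size_t elem_size)
    {
        const int chunk = *std::max_element(recvcounts, recvcounts + nranks);
        const std::size_t chunk_bytes = static_cast<std::size_t>(chunk) * elem_size;
        return chunk_bytes           // staged_send on every rank
             + chunk_bytes * nranks; // staged_recv on every rank
    }

    int main()
    {
        const int recvcounts[3] = {3, 5, 2}; // skewed layout across 3 ranks
        // 5*8 + 3*5*8 = 160 bytes staged, versus 10*8 = 80 bytes of payload.
        std::printf("%zu bytes\n", gatherv_staging_bytes(recvcounts, 3, sizeof(double)));
        return 0;
    }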
diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h
index 7293b375d74..a148226e9fd 100644
--- a/source/source_base/parallel_device.h
+++ b/source/source_base/parallel_device.h
@@ -5,6 +5,7 @@
 #include "source_base/module_device/device.h"
 #include "source_base/module_device/memory_op.h"
 #include <complex>
+#include <type_traits>
 namespace Parallel_Common
 {
 void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
@@ -32,11 +33,27 @@
 void gatherv_data(const std::complex<double>* sendbuf, int sendcount, std::complex<double>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
 void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
 void gatherv_data(const std::complex<float>* sendbuf, int sendcount, std::complex<float>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
+#if defined(__NCCL_PARALLEL_DEVICE)
+void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm);
+void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm);
+void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm);
+void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm);
+void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm);
+void nccl_reduce_data(std::complex<double>* object, const int& n, MPI_Comm& comm);
+void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm);
+void nccl_reduce_data(std::complex<float>* object, const int& n, MPI_Comm& comm);
+void nccl_gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
+void nccl_gatherv_data(const std::complex<double>* sendbuf, int sendcount, std::complex<double>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
+void nccl_gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
+void nccl_gatherv_data(const std::complex<float>* sendbuf, int sendcount, std::complex<float>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
+#endif
+
 #ifndef __CUDA_MPI
 template <typename T, typename Device>
 struct object_cpu_point
 {
     bool alloc = false;
+    T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr);
     T* get(const T* object, const int& n, T* tmp_space = nullptr);
     void del(T* object);
     void sync_d2h(T* object_cpu, const T* object, const int& n);
@@ -56,7 +73,6 @@ void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T*
 #else
     object_cpu_point<T, Device> o;
     T* object_cpu = o.get(object, count, tmp_space);
-    o.sync_d2h(object_cpu, object, count);
     send_data(object_cpu, count, dest, tag, comm);
     o.del(object_cpu);
 #endif
@@ -76,7 +92,6 @@ void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MP
 #else
     object_cpu_point<T, Device> o;
     T* object_cpu = o.get(object, count, send_space);
-    o.sync_d2h(object_cpu, object, count);
     isend_data(object_cpu, count, dest, tag, comm, request);
     o.del(object_cpu);
 #endif
@@ -94,7 +109,7 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta
     recv_data(object, count, source, tag, comm, status);
 #else
     object_cpu_point<T, Device> o;
-    T* object_cpu = o.get(object, count, tmp_space);
+    T* object_cpu = o.get_buffer(object, count, tmp_space);
     recv_data(object_cpu, count, source, tag, comm, status);
     o.sync_h2d(object, object_cpu, count);
     o.del(object_cpu);
@@ -116,14 +131,25 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta
 template <typename T, typename Device>
 void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
 {
+#if defined(__NCCL_PARALLEL_DEVICE)
+    if (std::is_same<Device, base_device::DEVICE_GPU>::value)
+    {
+        nccl_bcast_data(object, n, const_cast<MPI_Comm&>(comm));
+        return;
+    }
+#endif
 #ifdef __CUDA_MPI
     bcast_data(object, n, comm);
 #else
     object_cpu_point<T, Device> o;
-    T* object_cpu = o.get(object, n, tmp_space);
-    o.sync_d2h(object_cpu, object, n);
+    int rank = 0;
+    MPI_Comm_rank(comm, &rank);
+    T* object_cpu = rank == 0 ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space);
     bcast_data(object_cpu, n, comm);
-    o.sync_h2d(object, object_cpu, n);
+    if (rank != 0)
+    {
+        o.sync_h2d(object, object_cpu, n);
+    }
     o.del(object_cpu);
 #endif
     return;
@@ -132,12 +158,18 @@ void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nul
 template <typename T, typename Device>
 void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
 {
+#if defined(__NCCL_PARALLEL_DEVICE)
+    if (std::is_same<Device, base_device::DEVICE_GPU>::value)
+    {
+        nccl_reduce_data(object, n, const_cast<MPI_Comm&>(comm));
+        return;
+    }
+#endif
 #ifdef __CUDA_MPI
     reduce_data(object, n, comm);
 #else
     object_cpu_point<T, Device> o;
     T* object_cpu = o.get(object, n, tmp_space);
-    o.sync_d2h(object_cpu, object, n);
     reduce_data(object_cpu, n, comm);
     o.sync_h2d(object, object_cpu, n);
     o.del(object_cpu);
@@ -155,6 +187,13 @@ void gatherv_dev(const T* sendbuf,
                  T* tmp_sspace = nullptr,
                  T* tmp_rspace = nullptr)
 {
+#if defined(__NCCL_PARALLEL_DEVICE)
+    if (std::is_same<Device, base_device::DEVICE_GPU>::value)
+    {
+        nccl_gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
+        return;
+    }
+#endif
 #ifdef __CUDA_MPI
     gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
 #else
@@ -163,8 +202,7 @@ void gatherv_dev(const T* sendbuf,
     MPI_Comm_size(comm, &size);
     int gather_space = displs[size - 1] + recvcounts[size - 1];
     T* sendbuf_cpu = o1.get(sendbuf, sendcount, tmp_sspace);
-    T* recvbuf_cpu = o2.get(recvbuf, gather_space, tmp_rspace);
-    o1.sync_d2h(sendbuf_cpu, sendbuf, sendcount);
+    T* recvbuf_cpu = o2.get_buffer(recvbuf, gather_space, tmp_rspace);
     gatherv_data(sendbuf_cpu, sendcount, recvbuf_cpu, recvcounts, displs, comm);
     o2.sync_h2d(recvbuf, recvbuf_cpu, gather_space);
     o1.del(sendbuf_cpu);
@@ -177,4 +215,4 @@

 #endif

-#endif
\ No newline at end of file
+#endif
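With the header dispatch above in place, callers keep using the device-generic wrappers and the backend is chosen at build time. An end-to-end sketch of the NCCL path follows; it is a hypothetical main, and it assumes one GPU per MPI rank with cudaSetDevice called before the first collective, since the cached communicator is created lazily by ncclCommInitRank on whichever device is current at that moment:

    #include <mpi.h>
    #include <cuda_runtime.h>
    #include "source_base/parallel_device.h"

    int main(int argc, char** argv)
    {
        MPI_Init(&argc, &argv);
        int rank = 0, ndev = 0;
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
        cudaGetDeviceCount(&ndev);
        cudaSetDevice(rank % ndev); // bind before the first collective

        const int n = 1024;
        double* d_buf = nullptr;
        cudaMalloc(&d_buf, n * sizeof(double));
        if (rank == 0)
        {
            cudaMemset(d_buf, 0, n * sizeof(double)); // root's payload
        }
        // Routed to ncclBroadcast under __NCCL_PARALLEL_DEVICE, to CUDA-aware
        // MPI under __CUDA_MPI, and staged through the host otherwise.
        Parallel_Common::bcast_dev<double, base_device::DEVICE_GPU>(d_buf, n, MPI_COMM_WORLD);

        cudaFree(d_buf);
        MPI_Finalize();
        return 0;
    }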