Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1b700d8
Initial code of dlopen for nccl shared library
seagater Feb 12, 2025
583c8c5
Update NCCL_DLSYM and fix clang-format
seagater Feb 14, 2025
054ac89
Add parameters MSCCLPP_ENABLE_SHARED_LIB and MSCCLPP_NCCL_LIB_PATH to…
seagater Feb 19, 2025
d41d35b
Merge main branch
seagater Feb 24, 2025
ad2d89f
Add nccl_ops.CommInitRank
seagater Feb 25, 2025
3abfd97
Update NCCL_DLSYM for nccl_ops.CommInitRank
seagater Feb 26, 2025
7e0ee3e
Update dlopen code of loading nccl shared library
seagater Mar 2, 2025
73a4ed0
Use nccl_comm in ncclComm for nccl_ops.CommInitRank
seagater Mar 2, 2025
56d9366
Add dlopen for ncclCommUserRank
seagater Mar 3, 2025
ead9fba
Update error message handling for dlopen shared library
seagater Mar 5, 2025
2eaff19
Revert unnecessary changes to keep the original code of main branch
seagater Mar 5, 2025
646a823
Add dlopen for ncclReduceScatter
seagater Mar 7, 2025
b55ce47
Merge branch 'main' into qinghuazhou/nccl-rccl-integration-dlopen
Binyang2014 Mar 10, 2025
abf96c2
Update the fallback support for using specific collective operations …
seagater Mar 12, 2025
04705b8
Create an internal ncclUniqueId array and use mscclpp::TcpBootstrap::a…
seagater Mar 15, 2025
7999956
Update names of variables and functions
seagater Mar 15, 2025
62be254
Merge main branch
seagater Mar 15, 2025
2a1df14
Add ncclRedOp_t op in ncclReduceScatter to support NCCL fallback
seagater Mar 15, 2025
c8f7368
Change type of mscclppNcclComm to 'void*'
seagater Mar 16, 2025
7f946fa
Implement TcpBootstrap::Impl::broadcast for broadcast mscclppNcclUniq…
seagater Mar 19, 2025
9042680
Merge branch 'main' into qinghuazhou/nccl-rccl-integration-dlopen
Binyang2014 Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 164 additions & 1 deletion apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#if defined(ENABLE_NPKIT)
#include <mscclpp/npkit/npkit.hpp>
#endif
#include <dlfcn.h>

#include "allgather.hpp"
#include "allreduce.hpp"
#include "broadcast.hpp"
Expand All @@ -35,6 +37,101 @@

#define NUM_CHANNELS_PER_CONNECTION 64

// Status codes for the NCCL dlopen helpers below (mscclppNcclDlopenInit etc.).
typedef enum mscclppNcclDlopenErr {
  dlopenSuccess = 0,
  dlopenError = 1,
} mscclppNcclDlopenErr_t;

// Function-pointer table for the real NCCL/RCCL entry points resolved via dlsym
// (see NCCL_DLSYM / mscclppNcclDlopenInit). Each member's signature must match
// the corresponding nccl* API exactly; member names are the API names with the
// "nccl" prefix stripped (CommInitRank <-> ncclCommInitRank, ...).
typedef struct _mscclppNcclOps_t {
  ncclResult_t (*CommInitRank)(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
  ncclResult_t (*GetUniqueId)(ncclUniqueId* uniqueId);
  ncclResult_t (*CommDestroy)(ncclComm_t comm);
  ncclResult_t (*CommUserRank)(const ncclComm_t, int* rank);
  ncclResult_t (*AllReduce)(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op,
                            ncclComm_t comm, cudaStream_t stream);
  ncclResult_t (*AllGather)(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
                            ncclComm_t comm, cudaStream_t stream);
  ncclResult_t (*Broadcast)(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
                            ncclComm_t comm, cudaStream_t stream);
  ncclResult_t (*ReduceScatter)(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype,
                                ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
} mscclppNcclOps_t;

// Resolved NCCL entry points (valid only while mscclppNcclDlopenSharedLib is true).
mscclppNcclOps_t mscclppNcclOps;
// Handle returned by dlopen() for the library named by MSCCLPP_NCCL_LIB_PATH.
void* mscclppNcclDlHandle = NULL;
// True once the shared library has been loaded and all symbols resolved.
bool mscclppNcclDlopenSharedLib = false;

// Stringify helper used to build the "nccl"-prefixed symbol names for dlsym.
#define QUOTE(symbol) #symbol

// Resolve symbol `_prefix_##_function_` from `_handle_` into `_struct_._function_`.
// On failure this logs dlerror() and bails out of the ENCLOSING function with
// `dlopenError` rather than calling exit(): terminating the whole host process
// from a library initializer is hostile, and the only caller
// (mscclppNcclDlopenInit) already propagates a nonzero status to
// ncclCommInitRank, which turns it into ncclInternalError.
#define NCCL_DLSYM(_struct_, _handle_, _prefix_, _function_, _type_)                                \
  do {                                                                                              \
    _struct_._function_ = (_type_)dlsym((_handle_), QUOTE(_prefix_##_function_));                   \
    if (_struct_._function_ == NULL) {                                                              \
      printf("Failed: dlsym error: Cannot open %s: %s\n", QUOTE(_prefix_##_function_), dlerror());  \
      return dlopenError;                                                                           \
    }                                                                                               \
  } while (0)

static inline int mscclppNcclDlopenInit() {
const char* ncclLibPath = mscclpp::env()->ncclSharedLibPath.c_str();
if (ncclLibPath != nullptr && ncclLibPath[0] != '\0') {
if (std::filesystem::is_directory(ncclLibPath)) {
WARN("The value of the environment variable %s is a directory", ncclLibPath);
return dlopenError;
}

mscclppNcclDlHandle = dlopen(ncclLibPath, RTLD_LAZY | RTLD_NODELETE);
if (!mscclppNcclDlHandle) {
WARN("Cannot open the shared library specified by MSCCLPP_NCCL_LIB_PATH: %s\n", dlerror());
return dlopenError;
}
} else {
WARN("The value of MSCCLPP_NCCL_LIB_PATH is empty!\n");
return dlopenError;
}

NCCL_DLSYM(mscclppNcclOps, mscclppNcclDlHandle, nccl, CommInitRank,
ncclResult_t(*)(ncclComm_t*, int, ncclUniqueId, int));
NCCL_DLSYM(mscclppNcclOps, mscclppNcclDlHandle, nccl, GetUniqueId, ncclResult_t(*)(ncclUniqueId*));
NCCL_DLSYM(mscclppNcclOps, mscclppNcclDlHandle, nccl, CommDestroy, ncclResult_t(*)(ncclComm_t));
NCCL_DLSYM(mscclppNcclOps, mscclppNcclDlHandle, nccl, CommUserRank, ncclResult_t(*)(ncclComm_t, int*));
NCCL_DLSYM(mscclppNcclOps, mscclppNcclDlHandle, nccl, AllReduce,
ncclResult_t(*)(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t));
NCCL_DLSYM(mscclppNcclOps, mscclppNcclDlHandle, nccl, AllGather,
ncclResult_t(*)(const void*, void*, size_t, ncclDataType_t, ncclComm_t, cudaStream_t));
NCCL_DLSYM(mscclppNcclOps, mscclppNcclDlHandle, nccl, Broadcast,
ncclResult_t(*)(const void*, void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t));
NCCL_DLSYM(mscclppNcclOps, mscclppNcclDlHandle, nccl, ReduceScatter,
ncclResult_t(*)(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t));

return dlopenSuccess;
}

static inline void mscclppNcclDlopenFinalize() {
if (mscclppNcclDlHandle) {
dlclose(mscclppNcclDlHandle);
}
}

// Return 1 if the collective name `collOps` (e.g. "allreduce") appears in the
// comma-separated `fallbackList`, or if the list is NULL / empty / "all"
// (meaning: fall back to NCCL for every operation); return 0 otherwise.
static inline int mscclppNcclInFallbackList(const char* collOps, const char* fallbackList) {
  if (fallbackList == nullptr || fallbackList[0] == '\0' || strcmp(fallbackList, "all") == 0) {
    return 1;
  }

  // Tokenize a private copy: strtok_r mutates the buffer it scans.
  char* fallbackListCopy = strdup(fallbackList);
  if (fallbackListCopy == nullptr) {
    return 0;  // allocation failure: conservatively report "not in list"
  }

  // strtok_r instead of strtok: collectives may be issued from multiple
  // threads, and strtok's hidden static state is not thread-safe.
  int found = 0;
  char* savePtr = nullptr;
  for (char* token = strtok_r(fallbackListCopy, ",", &savePtr); token != nullptr;
       token = strtok_r(nullptr, ",", &savePtr)) {
    if (strcmp(collOps, token) == 0) {
      found = 1;
      break;
    }
  }

  free(fallbackListCopy);
  return found;
}

// static const mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2,
// mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
// mscclpp::Transport::IB6, mscclpp::Transport::IB7};
Expand Down Expand Up @@ -96,6 +193,8 @@ struct ncclComm {

uint32_t numScratchBuff;
uint32_t buffFlag;

void* mscclppNcclComm;
};

static size_t ncclTypeSize(ncclDataType_t type) {
Expand Down Expand Up @@ -501,6 +600,34 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
NpKit::Init(rank);
}
#endif

const bool mscclppEnableNcclFallback = mscclpp::env()->enableNcclFallback;
if (mscclppEnableNcclFallback == true && mscclppNcclDlHandle == NULL) {
int dlopenStatus = mscclppNcclDlopenInit();
if (dlopenStatus == dlopenSuccess) {
mscclppNcclDlopenSharedLib = true;
} else {
return ncclInternalError;
}
}

if (mscclppNcclDlopenSharedLib == true) {
ncclUniqueId mscclppNcclUniqueId;
if (rank == 0) {
mscclppNcclOps.GetUniqueId(&mscclppNcclUniqueId);
}
// After broadcast, mscclppNcclUniqueId on each rank has the same ncclUniqueId
bootstrap->broadcast(&mscclppNcclUniqueId, sizeof(ncclUniqueId), 0);

commPtr->mscclppNcclComm = new ncclComm_t();
if (commPtr->mscclppNcclComm == nullptr) {
WARN("Failed to allocate memory for mscclppNcclComm");
return ncclInternalError;
}
mscclppNcclOps.CommInitRank(reinterpret_cast<ncclComm_t*>(commPtr->mscclppNcclComm), nranks, mscclppNcclUniqueId,
rank);
}

return ncclSuccess;
}

Expand All @@ -527,6 +654,13 @@ NCCL_API ncclResult_t ncclCommDestroy(ncclComm_t comm) {
NpKit::Shutdown();
}
#endif

if (mscclppNcclDlopenSharedLib == true) {
mscclppNcclOps.CommDestroy(*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm));
mscclppNcclDlopenFinalize();
delete static_cast<ncclComm_t*>(comm->mscclppNcclComm);
}

delete comm;
return ncclSuccess;
}
Expand Down Expand Up @@ -628,6 +762,11 @@ NCCL_API ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank) {
WARN("comm is nullptr or rank is nullptr");
return ncclInvalidArgument;
}

if (mscclppNcclDlopenSharedLib == true) {
return mscclppNcclOps.CommUserRank(*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), rank);
}

*rank = comm->comm->bootstrap()->getRank();
return ncclSuccess;
}
Expand Down Expand Up @@ -715,6 +854,12 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
return ncclInvalidArgument;
}

const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("broadcast", fallbackList)) {
return mscclppNcclOps.Broadcast(sendbuff, recvbuff, count, datatype, root,
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
}

int rank = comm->comm->bootstrap()->getRank();

std::vector<executionPlanInstance>& plans = comm->executionPlans["broadcast"];
Expand Down Expand Up @@ -767,6 +912,12 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
return ncclInvalidArgument;
}

const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("allreduce", fallbackList)) {
return mscclppNcclOps.AllReduce(sendbuff, recvbuff, count, datatype, reductionOperation,
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
}

// Declarating variables
size_t bytes = count * ncclTypeSize(datatype);
int rank = comm->comm->bootstrap()->getRank();
Expand Down Expand Up @@ -811,7 +962,7 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
}

NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype,
ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
size_t bytes = recvcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
WARN(
Expand All @@ -820,6 +971,12 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, si
return ncclInvalidArgument;
}

const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("reducescatter", fallbackList)) {
return mscclppNcclOps.ReduceScatter(sendbuff, recvbuff, recvcount, datatype, op,
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
}

int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();

Expand Down Expand Up @@ -876,6 +1033,12 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
return ncclInvalidArgument;
}

const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("allgather", fallbackList)) {
return mscclppNcclOps.AllGather(sendbuff, recvbuff, sendcount, datatype,
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
}

int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();

Expand Down
11 changes: 11 additions & 0 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,17 @@ class TcpBootstrap : public Bootstrap {
/// @param size The size of the data each rank sends.
void allGather(void* allData, int size) override;

/// Broadcast data from the root process to all processes using a ring-based algorithm.
///
/// When called by the root rank, this sends the `size` bytes starting at memory location `data` to all other
/// ranks. Non-root ranks receive these bytes into their own `data` buffer, overwriting its previous contents.
/// The data propagates sequentially through a logical ring of processes until all ranks have received it.
///
/// @param data Pointer to the send buffer (root) or receive buffer (non-root)
/// @param size Number of bytes to broadcast
/// @param root Rank initiating the broadcast
void broadcast(void* data, int size, int root);

/// Synchronize all processes.
void barrier() override;

Expand Down
3 changes: 3 additions & 0 deletions include/mscclpp/env.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class Env {
const std::string executionPlanDir;
const std::string npkitDumpDir;
const bool cudaIpcUseDefaultStream;
const std::string ncclSharedLibPath;
const std::string forceNcclFallbackOperation;
const bool enableNcclFallback;
const bool disableChannelCache;

private:
Expand Down
37 changes: 37 additions & 0 deletions src/bootstrap/bootstrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class TcpBootstrap::Impl {
int getNranks();
int getNranksPerNode();
void allGather(void* allData, int size);
void broadcast(void* data, int size, int root);
void send(void* data, int size, int peer, int tag);
void recv(void* data, int size, int peer, int tag);
void barrier();
Expand Down Expand Up @@ -465,6 +466,40 @@ void TcpBootstrap::Impl::allGather(void* allData, int size) {
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size);
}

// Ring-based broadcast: the root pushes `size` bytes from `data` to its ring
// successor; every other rank receives the bytes from its ring predecessor
// into `data` and forwards them on, until the root's predecessor (the last
// rank in ring order) has received them.
//
// Each rank performs at most one recv and one send, so no per-step loop is
// needed: the original nRanks-1 step loop had every rank idle in all but one
// iteration, and there was no cross-step synchronization to preserve.
void TcpBootstrap::Impl::broadcast(void* data, int size, int root) {
  int rank = rank_;
  int nRanks = nRanks_;

  if (nRanks == 1) return;

  TRACE(MSCCLPP_INIT, "rank %d nranks %d root %d size %d", rank, nRanks, root, size);

  // Position of this rank along the ring, counted from the root (root -> 0).
  int ringPos = (rank - root + nRanks) % nRanks;
  if (ringPos == 0) {
    // Root: inject the data into the ring.
    netSend(ringSendSocket_.get(), data, size);
  } else {
    // Non-root: receive from the previous rank, then forward unless this is
    // the last rank on the ring (the root's predecessor has no one to feed).
    netRecv(ringRecvSocket_.get(), data, size);
    if (ringPos != nRanks - 1) {
      netSend(ringSendSocket_.get(), data, size);
    }
  }

  TRACE(MSCCLPP_INIT, "rank %d nranks %d root %d size %d - DONE", rank, nRanks, root, size);
}

std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerSendSocket(int peer, int tag) {
auto it = peerSendSockets_.find(std::make_pair(peer, tag));
if (it != peerSendSockets_.end()) {
Expand Down Expand Up @@ -555,6 +590,8 @@ MSCCLPP_API_CPP void TcpBootstrap::recv(void* data, int size, int peer, int tag)

MSCCLPP_API_CPP void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }

MSCCLPP_API_CPP void TcpBootstrap::broadcast(void* data, int size, int root) { pimpl_->broadcast(data, size, root); }

MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId, int64_t timeoutSec) {
pimpl_->initialize(uniqueId, timeoutSec);
}
Expand Down
6 changes: 6 additions & 0 deletions src/env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ Env::Env()
executionPlanDir(readEnv<std::string>("MSCCLPP_EXECUTION_PLAN_DIR", "")),
npkitDumpDir(readEnv<std::string>("MSCCLPP_NPKIT_DUMP_DIR", "")),
cudaIpcUseDefaultStream(readEnv<bool>("MSCCLPP_CUDAIPC_USE_DEFAULT_STREAM", false)),
ncclSharedLibPath(readEnv<std::string>("MSCCLPP_NCCL_LIB_PATH", "")),
forceNcclFallbackOperation(readEnv<std::string>("MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION", "")),
enableNcclFallback(readEnv<bool>("MSCCLPP_ENABLE_NCCL_FALLBACK", false)),
disableChannelCache(readEnv<bool>("MSCCLPP_DISABLE_CHANNEL_CACHE", false)) {}

std::shared_ptr<Env> env() {
Expand All @@ -81,6 +84,9 @@ std::shared_ptr<Env> env() {
logEnv("MSCCLPP_EXECUTION_PLAN_DIR", globalEnv->executionPlanDir);
logEnv("MSCCLPP_NPKIT_DUMP_DIR", globalEnv->npkitDumpDir);
logEnv("MSCCLPP_CUDAIPC_USE_DEFAULT_STREAM", globalEnv->cudaIpcUseDefaultStream);
logEnv("MSCCLPP_NCCL_LIB_PATH", globalEnv->ncclSharedLibPath);
logEnv("MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION", globalEnv->forceNcclFallbackOperation);
logEnv("MSCCLPP_ENABLE_NCCL_FALLBACK", globalEnv->enableNcclFallback);
logEnv("MSCCLPP_DISABLE_CHANNEL_CACHE", globalEnv->disableChannelCache);
}
return globalEnv;
Expand Down
Loading