Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions apps/nccl/src/allreduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,17 @@ template <Op OpType, typename T>
__global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode,
int worldSize, size_t nelems, uint32_t flag) {
int worldSize, size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff) {
// This version of allreduce only works for single nodes
if (worldSize != nRanksPerNode) return;
if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
const int nPeers = nRanksPerNode - 1;

uint32_t flag = deviceFlag[blockIdx.x];

size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE/numScratchBuff : 0;
channelScratchOffset = scratchBaseOffset;

const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
Expand All @@ -198,13 +204,17 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
}
dst[idx] = data;
}
__syncthreads();
Comment thread
nusislam marked this conversation as resolved.
if (threadIdx.x == 0) {
deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1;
}
}

template <Op OpType, typename T>
__global__ void __launch_bounds__(1024, 1)
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, uint32_t flag
size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff
#if defined(ENABLE_NPKIT)
,
NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
Expand Down Expand Up @@ -247,6 +257,11 @@ __global__ void __launch_bounds__(1024, 1)
const int nPeers = nRanksPerNode - 1;
const size_t nPkts = nelems / 2;

uint32_t flag = (uint32_t) deviceFlag[blockIdx.x];

size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE/numScratchBuff : 0;
channelScratchOffset = scratchBaseOffset;

int nelemsPerRank = nelems / worldSize;
if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T);

Expand Down Expand Up @@ -309,6 +324,8 @@ __global__ void __launch_bounds__(1024, 1)
result[idx].x = data.x;
result[idx].y = data.y;
}

__syncthreads();
Comment thread
nusislam marked this conversation as resolved.
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \
defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT)
NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer,
Expand All @@ -319,6 +336,9 @@ __global__ void __launch_bounds__(1024, 1)
#if defined(ENABLE_NPKIT)
NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head);
#endif
if (threadIdx.x == 0) {
deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1;
}
}

template <Op OpType, typename T>
Expand Down Expand Up @@ -462,37 +482,41 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, cudaStream_t stream) {
static uint32_t flag = 1;
size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7, uint32_t* deviceFlag28,
uint32_t* deviceFlag56, uint32_t numScratchBuff) {

uint32_t* deviceFlag;
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
int nBlocks = 7;
int nThreadsPerBlock = 32;
allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, flag++);
nRanksPerNode, worldSize, nelems, deviceFlag7, numScratchBuff);
} else if (sizeof(T) * nelems <= (1 << 14)) {
int nBlocks = 28;
int nThreadsPerBlock = 512;
allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, flag++);
nRanksPerNode, worldSize, nelems, deviceFlag28, numScratchBuff);
} else if (sizeof(T) * nelems <= (1 << 20)) {
int nBlocks = 28;
int nThreadsPerBlock = 1024;
deviceFlag = deviceFlag28;
if (nelems >= 8192) {
nBlocks = 56;
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
deviceFlag = deviceFlag56;
}
#if defined(ENABLE_NPKIT)
size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, NpkitSharedMemSize, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff, NpKit::GetGpuEventCollectContexts(),
NpKit::GetCpuTimestamp());
#else
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels,
channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, flag++);
nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff);
#endif
} else {
int nBlocks = 35;
Expand Down
22 changes: 20 additions & 2 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ struct ncclComm {
uint32_t numScratchBuff;
uint32_t buffFlag;

std::shared_ptr<uint32_t> deviceFlag7;
std::shared_ptr<uint32_t> deviceFlag28;
std::shared_ptr<uint32_t> deviceFlag56;

void* mscclppNcclComm;
};

Expand Down Expand Up @@ -383,7 +387,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
Op reduceOp = getReduceOp(op);
std::function<cudaError_t(const void*, void*, void*, mscclpp::DeviceHandle<mscclpp::MemoryChannel>*,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>*, size_t, size_t, size_t, int, int, int,
size_t, cudaStream_t)>
size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)>
allreduceFunc;
if (reduceOp == SUM) {
if (datatype == ncclFloat16) {
Expand Down Expand Up @@ -414,7 +418,8 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
}
CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, offsetIn,
offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
comm->comm->bootstrap()->getNranks(), count, stream, (uint32_t*)comm->deviceFlag7.get(),
(uint32_t*)comm->deviceFlag28.get(), (uint32_t*)comm->deviceFlag56.get(), comm->numScratchBuff));
return ncclSuccess;
}

Expand Down Expand Up @@ -533,6 +538,19 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
commPtr->scratchBuff = mscclpp::GpuBuffer<char>(SCRATCH_SIZE).memory();
commPtr->remoteScratchRegMemories =
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);

commPtr->deviceFlag7 = mscclpp::detail::gpuCallocShared<uint32_t>(7);
commPtr->deviceFlag28 = mscclpp::detail::gpuCallocShared<uint32_t>(28);
commPtr->deviceFlag56 = mscclpp::detail::gpuCallocShared<uint32_t>(56);

std::vector<uint32_t> initFlag(56);
for (int i = 0; i < 56; ++i) {
initFlag[i] = 1;
}

mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag7.get(), initFlag.data(), 7, cudaMemcpyHostToDevice);
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag28.get(), initFlag.data(), 28, cudaMemcpyHostToDevice);
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag56.get(), initFlag.data(), 56, cudaMemcpyHostToDevice);
}

NCCL_API ncclResult_t ncclGetVersion(int* version) {
Expand Down
Loading