Skip to content

Commit 50cd2bb

Browse files
committed
apps/nccl: fix a bug in allreduce kernels for graph mode
1 parent 7a25e51 commit 50cd2bb

File tree

2 files changed

+52
-10
lines changed

2 files changed

+52
-10
lines changed

apps/nccl/src/allreduce.hpp

Lines changed: 32 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -168,11 +168,17 @@ template <Op OpType, typename T>
168168
__global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
169169
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
170170
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode,
171-
int worldSize, size_t nelems, uint32_t flag) {
171+
int worldSize, size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff) {
172172
// This version of allreduce only works for single nodes
173173
if (worldSize != nRanksPerNode) return;
174174
if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
175175
const int nPeers = nRanksPerNode - 1;
176+
177+
uint32_t flag = deviceFlag[blockIdx.x];
178+
179+
size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE/numScratchBuff : 0;
180+
channelScratchOffset = scratchBaseOffset;
181+
176182
const int nBlocksPerPeer = gridDim.x / nPeers;
177183
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
178184
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
@@ -198,13 +204,17 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
198204
}
199205
dst[idx] = data;
200206
}
207+
__syncthreads();
208+
if (threadIdx.x == 0) {
209+
deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1;
210+
}
201211
}
202212

203213
template <Op OpType, typename T>
204214
__global__ void __launch_bounds__(1024, 1)
205215
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
206216
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
207-
size_t nelems, uint32_t flag
217+
size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff
208218
#if defined(ENABLE_NPKIT)
209219
,
210220
NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
@@ -247,6 +257,11 @@ __global__ void __launch_bounds__(1024, 1)
247257
const int nPeers = nRanksPerNode - 1;
248258
const size_t nPkts = nelems / 2;
249259

260+
uint32_t flag = (uint32_t) deviceFlag[blockIdx.x];
261+
262+
size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE/numScratchBuff : 0;
263+
channelScratchOffset = scratchBaseOffset;
264+
250265
int nelemsPerRank = nelems / worldSize;
251266
if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T);
252267

@@ -309,6 +324,8 @@ __global__ void __launch_bounds__(1024, 1)
309324
result[idx].x = data.x;
310325
result[idx].y = data.y;
311326
}
327+
328+
__syncthreads();
312329
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \
313330
defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT)
314331
NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer,
@@ -319,6 +336,9 @@ __global__ void __launch_bounds__(1024, 1)
319336
#if defined(ENABLE_NPKIT)
320337
NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head);
321338
#endif
339+
if (threadIdx.x == 0) {
340+
deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1;
341+
}
322342
}
323343

324344
template <Op OpType, typename T>
@@ -462,37 +482,41 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
462482
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
463483
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
464484
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
465-
size_t nelems, cudaStream_t stream) {
466-
static uint32_t flag = 1;
485+
size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7, uint32_t* deviceFlag28,
486+
uint32_t* deviceFlag56, uint32_t numScratchBuff) {
467487

488+
uint32_t* deviceFlag;
468489
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
469490
int nBlocks = 7;
470491
int nThreadsPerBlock = 32;
471492
allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
472493
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
473-
nRanksPerNode, worldSize, nelems, flag++);
494+
nRanksPerNode, worldSize, nelems, deviceFlag7, numScratchBuff);
474495
} else if (sizeof(T) * nelems <= (1 << 14)) {
475496
int nBlocks = 28;
476497
int nThreadsPerBlock = 512;
477498
allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
478499
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
479-
nRanksPerNode, worldSize, nelems, flag++);
500+
nRanksPerNode, worldSize, nelems, deviceFlag28, numScratchBuff);
480501
} else if (sizeof(T) * nelems <= (1 << 20)) {
481502
int nBlocks = 28;
482503
int nThreadsPerBlock = 1024;
504+
deviceFlag = deviceFlag28;
483505
if (nelems >= 8192) {
484506
nBlocks = 56;
485507
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
508+
deviceFlag = deviceFlag56;
486509
}
487510
#if defined(ENABLE_NPKIT)
488511
size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
489512
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, NpkitSharedMemSize, stream>>>(
490513
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
491-
nRanksPerNode, worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
514+
nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff, NpKit::GetGpuEventCollectContexts(),
515+
NpKit::GetCpuTimestamp());
492516
#else
493517
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels,
494518
channelInOffset, channelScratchOffset, rank,
495-
nRanksPerNode, worldSize, nelems, flag++);
519+
nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff);
496520
#endif
497521
} else {
498522
int nBlocks = 35;

apps/nccl/src/nccl.cu

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,10 @@ struct ncclComm {
195195
uint32_t numScratchBuff;
196196
uint32_t buffFlag;
197197

198+
std::shared_ptr<uint32_t> deviceFlag7;
199+
std::shared_ptr<uint32_t> deviceFlag28;
200+
std::shared_ptr<uint32_t> deviceFlag56;
201+
198202
void* mscclppNcclComm;
199203
};
200204

@@ -383,7 +387,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
383387
Op reduceOp = getReduceOp(op);
384388
std::function<cudaError_t(const void*, void*, void*, mscclpp::DeviceHandle<mscclpp::MemoryChannel>*,
385389
mscclpp::DeviceHandle<mscclpp::MemoryChannel>*, size_t, size_t, size_t, int, int, int,
386-
size_t, cudaStream_t)>
390+
size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)>
387391
allreduceFunc;
388392
if (reduceOp == SUM) {
389393
if (datatype == ncclFloat16) {
@@ -414,7 +418,8 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
414418
}
415419
CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, offsetIn,
416420
offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
417-
comm->comm->bootstrap()->getNranks(), count, stream));
421+
comm->comm->bootstrap()->getNranks(), count, stream, (uint32_t*)comm->deviceFlag7.get(),
422+
(uint32_t*)comm->deviceFlag28.get(), (uint32_t*)comm->deviceFlag56.get(), comm->numScratchBuff));
418423
return ncclSuccess;
419424
}
420425

@@ -533,6 +538,19 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
533538
commPtr->scratchBuff = mscclpp::GpuBuffer<char>(SCRATCH_SIZE).memory();
534539
commPtr->remoteScratchRegMemories =
535540
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
541+
542+
commPtr->deviceFlag7 = mscclpp::detail::gpuCallocShared<uint32_t>(7);
543+
commPtr->deviceFlag28 = mscclpp::detail::gpuCallocShared<uint32_t>(28);
544+
commPtr->deviceFlag56 = mscclpp::detail::gpuCallocShared<uint32_t>(56);
545+
546+
std::vector<uint32_t> initFlag(56);
547+
for (int i = 0; i < 56; ++i) {
548+
initFlag[i] = 1;
549+
}
550+
551+
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag7.get(), initFlag.data(), 7, cudaMemcpyHostToDevice);
552+
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag28.get(), initFlag.data(), 28, cudaMemcpyHostToDevice);
553+
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag56.get(), initFlag.data(), 56, cudaMemcpyHostToDevice);
536554
}
537555

538556
NCCL_API ncclResult_t ncclGetVersion(int* version) {

0 commit comments

Comments (0)