@@ -168,11 +168,17 @@ template <Op OpType, typename T>
168168__global__ void allreduceAllPairs (T* buff, T* scratch, T* resultBuff,
169169 mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
170170 size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode,
171- int worldSize, size_t nelems, uint32_t flag ) {
171+ int worldSize, size_t nelems, uint32_t * deviceFlag, uint32_t numScratchBuff ) {
172172 // This version of allreduce only works for single nodes
173173 if (worldSize != nRanksPerNode) return ;
174174 if (sizeof (T) == 2 ) nelems = (nelems * sizeof (T) + sizeof (T)) / sizeof (int );
175175 const int nPeers = nRanksPerNode - 1 ;
176+
177+ uint32_t flag = deviceFlag[blockIdx.x ];
178+
179+ size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE/numScratchBuff : 0 ;
180+ channelScratchOffset = scratchBaseOffset;
181+
176182 const int nBlocksPerPeer = gridDim.x / nPeers;
177183 const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
178184 const int tid = threadIdx.x + localBlockIdx * blockDim.x ;
@@ -198,13 +204,17 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
198204 }
199205 dst[idx] = data;
200206 }
207+ __syncthreads ();
208+ if (threadIdx.x == 0 ) {
209+ deviceFlag[blockIdx.x ] = deviceFlag[blockIdx.x ] + 1 ;
210+ }
201211}
202212
203213template <Op OpType, typename T>
204214__global__ void __launch_bounds__ (1024 , 1 )
205215 allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
206216 size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
207- size_t nelems, uint32_t flag
217+ size_t nelems, uint32_t * deviceFlag, uint32_t numScratchBuff
208218#if defined(ENABLE_NPKIT)
209219 ,
210220 NpKitEventCollectContext* npKitEventCollectContexts, uint64_t * cpuTimestamp) {
@@ -247,6 +257,11 @@ __global__ void __launch_bounds__(1024, 1)
247257 const int nPeers = nRanksPerNode - 1 ;
248258 const size_t nPkts = nelems / 2 ;
249259
260+ uint32_t flag = (uint32_t ) deviceFlag[blockIdx.x ];
261+
262+ size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE/numScratchBuff : 0 ;
263+ channelScratchOffset = scratchBaseOffset;
264+
250265 int nelemsPerRank = nelems / worldSize;
251266 if ((nelemsPerRank % 2 )) nelemsPerRank = (nelemsPerRank * sizeof (T) + sizeof (T)) / sizeof (T);
252267
@@ -309,6 +324,8 @@ __global__ void __launch_bounds__(1024, 1)
309324 result[idx].x = data.x ;
310325 result[idx].y = data.y ;
311326 }
327+
328+ __syncthreads ();
312329#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \
313330 defined (ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT)
314331 NpKit::CollectGpuEventShm (NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0 , 0 , npkit_timestamp_entry, event_buffer,
@@ -319,6 +336,9 @@ __global__ void __launch_bounds__(1024, 1)
319336#if defined(ENABLE_NPKIT)
320337 NpKit::StoreGpuEventShm (npKitEventCollectContexts, event_buffer, event_buffer_head);
321338#endif
339+ if (threadIdx.x == 0 ) {
340+ deviceFlag[blockIdx.x ] = deviceFlag[blockIdx.x ] + 1 ;
341+ }
322342}
323343
324344template <Op OpType, typename T>
@@ -462,37 +482,41 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
462482 mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
463483 mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
464484 size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
465- size_t nelems, cudaStream_t stream) {
466- static uint32_t flag = 1 ;
485+ size_t nelems, cudaStream_t stream, uint32_t * deviceFlag7, uint32_t * deviceFlag28,
486+ uint32_t * deviceFlag56, uint32_t numScratchBuff) {
467487
488+ uint32_t * deviceFlag;
468489 if (sizeof (T) * nelems < worldSize * sizeof (int )) {
469490 int nBlocks = 7 ;
470491 int nThreadsPerBlock = 32 ;
471492 allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0 , stream>>>(
472493 (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
473- nRanksPerNode, worldSize, nelems, flag++ );
494+ nRanksPerNode, worldSize, nelems, deviceFlag7, numScratchBuff );
474495 } else if (sizeof (T) * nelems <= (1 << 14 )) {
475496 int nBlocks = 28 ;
476497 int nThreadsPerBlock = 512 ;
477498 allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0 , stream>>>(
478499 (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
479- nRanksPerNode, worldSize, nelems, flag++ );
500+ nRanksPerNode, worldSize, nelems, deviceFlag28, numScratchBuff );
480501 } else if (sizeof (T) * nelems <= (1 << 20 )) {
481502 int nBlocks = 28 ;
482503 int nThreadsPerBlock = 1024 ;
504+ deviceFlag = deviceFlag28;
483505 if (nelems >= 8192 ) {
484506 nBlocks = 56 ;
485507 nThreadsPerBlock = (nelems <= 76800 ) ? 512 : 1024 ;
508+ deviceFlag = deviceFlag56;
486509 }
487510#if defined(ENABLE_NPKIT)
488511 size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof (NpKitEvent);
489512 allreduce7<OpType><<<nBlocks, nThreadsPerBlock, NpkitSharedMemSize, stream>>>(
490513 (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
491- nRanksPerNode, worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
514+ nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff, NpKit::GetGpuEventCollectContexts (),
515+ NpKit::GetCpuTimestamp ());
492516#else
493517 allreduce7<OpType><<<nBlocks, nThreadsPerBlock, 0 , stream>>>((T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels,
494518 channelInOffset, channelScratchOffset, rank,
495- nRanksPerNode, worldSize, nelems, flag++ );
519+ nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff );
496520#endif
497521 } else {
498522 int nBlocks = 35 ;
0 commit comments