Skip to content

Commit 3a00896

Browse files
authored
Merge branch 'main' into binyli/nccl_api
2 parents e404b77 + dff3bc7 commit 3a00896

4 files changed

Lines changed: 105 additions & 9 deletions

File tree

python/mscclpp/language/internal/operations.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ def __init__(
534534
self.dst_buff = dst_buff
535535
self.channel_ids = channel_ids
536536
self.channel_type = channel_type
537+
self.from_packet = from_packet
537538
self.to_packet = to_packet
538539
self.with_signal = with_signal
539540
self.with_signal_and_flush = with_signal_and_flush
@@ -579,6 +580,25 @@ def __add__(self, other):
579580
with_signal=self.with_signal,
580581
with_signal_and_flush=self.with_signal_and_flush,
581582
)
583+
elif (
584+
isinstance(other, PutOperation)
585+
and self.name == Instruction.read_put_packet
586+
and self.name == other.name
587+
and self.src_buff == other.src_buff
588+
and self.channel_type == other.channel_type
589+
and self.tbg_info == other.tbg_info
590+
):
591+
fused_operation = PutOperation(
592+
src_buff=self.src_buff,
593+
dst_buff=self.dst_buff + other.dst_buff,
594+
channel_ids=self.channel_ids + other.channel_ids,
595+
channel_type=self.channel_type,
596+
tbg_info=self.tbg_info,
597+
from_packet=self.from_packet,
598+
to_packet=self.to_packet,
599+
with_signal=self.with_signal,
600+
with_signal_and_flush=self.with_signal_and_flush,
601+
)
582602

583603
return fused_operation
584604

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import argparse
5+
from mscclpp.language.channel import *
6+
from mscclpp.language.rank import *
7+
from mscclpp.language.general import *
8+
from mscclpp.language.program import *
9+
from mscclpp.language.collectives import *
10+
11+
12+
def allgather_example(name, gpu_size, num_threads_per_block, min_message_size, max_message_size):
    """Describe a packet-based AllGather program and print its JSON form.

    The program has three phases, each executed for every rank:

    1. Copy chunk 0 of the rank's input buffer, as packets, into slot
       ``gpu_size + rank`` of the rank's own scratch buffer.
    2. For every other rank, push that staged slot through a
       ``MemoryChannel`` into slot ``rank`` of the peer's scratch buffer.
    3. Unpack every slot received from a peer into the matching slot of the
       rank's output buffer.

    The rank's own slot of the output buffer is never written here.
    NOTE(review): this presumably relies on the third ``AllGather`` argument
    (``True``) selecting an in-place collective -- confirm against the
    ``AllGather`` definition.

    Args:
        name: Program name handed to ``CollectiveProgram``.
        gpu_size: Number of ranks; also determines the scratch layout.
        num_threads_per_block: Forwarded to ``CollectiveProgram``.
        min_message_size: Forwarded to ``CollectiveProgram``.
        max_message_size: Forwarded to ``CollectiveProgram``.

    Returns:
        None. The result of ``JSON()`` is printed to stdout.
    """
    chunksperloop = 1
    # Scratch layout per rank (2 * gpu_size slots):
    #   [0, gpu_size)            slot i receives the packets sent by rank i
    #   [gpu_size, 2 * gpu_size) slot gpu_size + rank stages the local chunk
    # Defined once up front so that later phases do not depend on a variable
    # leaked from the body of an earlier loop.
    scratch_offset = gpu_size
    # Every operation in this example is scheduled on thread block 0.
    tb = 0

    collective = AllGather(gpu_size, chunksperloop, True)
    with CollectiveProgram(
        name,
        collective,
        gpu_size,
        protocol="LL",
        num_threads_per_block=num_threads_per_block,
        use_double_scratch_buffer=True,
        min_message_size=min_message_size,
        max_message_size=max_message_size,
    ):
        # Creating Scratch Buffers
        scratch_buffer = [Buffer(gpu, 2 * gpu_size) for gpu in range(gpu_size)]

        # Copying it to scratch buffer
        for gpu in range(gpu_size):
            rank = Rank(gpu)
            input_buffer = rank.get_input_buffer()
            staged = scratch_offset + gpu
            rank.copy_packets(scratch_buffer[gpu][staged : staged + 1], input_buffer[0:1], tb=tb)

        # Putting packets in the remote scratch buffer
        for gpu in range(gpu_size):
            staged = scratch_offset + gpu
            for peer in range(1, gpu_size):
                dst_rank = (gpu + peer) % gpu_size
                ch = MemoryChannel(dst_rank, gpu)
                ch.read_put_packets(
                    scratch_buffer[dst_rank][gpu : gpu + 1],
                    scratch_buffer[gpu][staged : staged + 1],
                    tb,
                )

        # Copying packets from local scratch buffer to local buffer
        for gpu in range(gpu_size):
            rank = Rank(gpu)
            output_buffer = rank.get_output_buffer()
            for peer in range(1, gpu_size):
                # Slot index doubles as the rank whose packets landed there.
                src_rank = (gpu + peer) % gpu_size
                rank.unpack_packets(
                    output_buffer[src_rank : src_rank + 1],
                    scratch_buffer[gpu][src_rank : src_rank + 1],
                    tb=tb,
                )

        print(JSON())
66+
67+
68+
# Command-line entry point: parse the options and emit the program description.
# --name and --num_gpus declare no default, so argparse leaves them as None
# when they are omitted; the remaining options fall back to the defaults below.
parser = argparse.ArgumentParser()

parser.add_argument("--name", help="name of the program", type=str)
parser.add_argument("--num_gpus", help="number of gpus", type=int)
parser.add_argument(
    "--num_threads_per_block",
    help="number of threads per block",
    type=int,
    default=1024,
)
parser.add_argument(
    "--min_message_size",
    help="minimum message size",
    type=int,
    default=0,
)
parser.add_argument(
    "--max_message_size",
    help="maximum message size",
    type=int,
    default=2**64 - 1,
)

args = parser.parse_args()

allgather_example(
    name=args.name,
    gpu_size=args.num_gpus,
    num_threads_per_block=args.num_threads_per_block,
    min_message_size=args.min_message_size,
    max_message_size=args.max_message_size,
)

python/test/executor_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
env,
1212
)
1313
from mscclpp import CommGroup, GpuBuffer
14-
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
14+
from mscclpp.utils import KernelBuilder, pack
1515
import os
1616
import struct
1717

src/core/include/execution_kernel.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -298,11 +298,11 @@ MSCCLPP_DEVICE_INLINE void handleReadPutPackets(const Operation& op, void* scrat
298298
ChannelType chType = op.channelType;
299299
if (chType == ChannelType::MEMORY) {
300300
size_t nPackets = size / sizeof(PacketPayload<PacketType>);
301+
PacketType* pkts = (PacketType*)((char*)scratch + scratchOffset_ + (srcOffsets[0] << 1));
301302
for (size_t pktIdx = threadIdx.x; pktIdx < nPackets; pktIdx += blockDim.x) {
303+
PacketPayload<PacketType> data = pkts[pktIdx].read(flag_);
304+
PacketType pkt(data, flag_);
302305
for (uint32_t idx = 0; idx < nOutput; ++idx) {
303-
PacketType* pkts = (PacketType*)((char*)scratch + scratchOffset_ + (srcOffsets[idx] << 1));
304-
PacketPayload<PacketType> data = pkts[pktIdx].read(flag_);
305-
PacketType pkt(data, flag_);
306306
size_t offset = (scratchOffset_ + (dstOffsets[idx] << 1)) / sizeof(PacketType);
307307
void* remoteMemory = static_cast<char*>(memoryChannelBufferPtrs_[op.outputBufferRefs[idx].id]);
308308
mscclpp::write<PacketType>(remoteMemory, offset + pktIdx, pkt);
@@ -312,10 +312,8 @@ MSCCLPP_DEVICE_INLINE void handleReadPutPackets(const Operation& op, void* scrat
312312
// Ensuring Data Is Ready
313313
size_t nPackets = size / sizeof(PacketPayload<PacketType>);
314314
for (size_t pktIdx = threadIdx.x; pktIdx < nPackets; pktIdx += blockDim.x) {
315-
for (uint32_t idx = 0; idx < nOutput; ++idx) {
316-
PacketType* pkts = (PacketType*)((char*)scratch + scratchOffset_ + (srcOffsets[idx] << 1));
317-
pkts[pktIdx].read(flag_);
318-
}
315+
PacketType* pkts = (PacketType*)((char*)scratch + scratchOffset_ + (srcOffsets[0] << 1));
316+
pkts[pktIdx].read(flag_);
319317
}
320318
__syncthreads();
321319

@@ -325,7 +323,7 @@ MSCCLPP_DEVICE_INLINE void handleReadPutPackets(const Operation& op, void* scrat
325323
return;
326324
}
327325
uint32_t dstOffset = (dstOffsets[chIdx] << 1) + scratchOffset_;
328-
uint32_t srcOffset = (srcOffsets[chIdx] << 1) + scratchOffset_;
326+
uint32_t srcOffset = (srcOffsets[0] << 1) + scratchOffset_;
329327
MemoryId dstMemoryId = portChannelBufferIds_[op.outputBufferRefs[chIdx].id];
330328
portChannels_[channelIndexes[chIdx]].put(
331329
dstMemoryId, dstOffset, static_cast<MemoryId>(BufferType::SCRATCH) + localMemoryIdBegin_, srcOffset, size << 1);

0 commit comments

Comments
 (0)