Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/guide/mscclpp-dsl.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,17 @@ The following figure shows how thread blocks and channels are replicated across
MSCCL++ DSL Instance Replication Overview
```

## Thread Block Group

This feature is currently a prototype. It lets you define a set of thread blocks and use them as a group to execute an operation. Grouping thread blocks makes non-uniform allocation possible: each operation can be given a different number of thread blocks, as needed.

For example:
```python
# Create a Thread Block Group with 4 thread blocks
tbg = ThreadBlockGroup(tb_list=[0, 1, 2, 3])
# Use the Thread Block Group to perform the copy operation
rank.copy(output_buffer[0:1], input_buffer[0:1], tb_group=tbg)
Comment thread
caiomcbr marked this conversation as resolved.
Outdated
```

## Execution plan

Expand Down
232 changes: 153 additions & 79 deletions python/mscclpp/language/channel.py

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions python/mscclpp/language/internal/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def to_dict(self):
return {"buffer_id": self.buffer_id, "index": self.index, "size": self.size}


@dataclass
class ThreadBlockGroupInfo:
    """Describe one thread block's place inside a thread block group.

    Operations that carry an instance serialize it under the ``"tbg_info"``
    key of their own ``to_dict()`` output; operations without one omit the key.
    """

    # Identifier of this thread block within its group (callers pass the
    # group's internal id for the thread block, not necessarily its global id).
    tb_id: int
    # Total number of thread blocks in the group.
    tbg_size: int

    def to_dict(self) -> dict:
        """Return the two fields as a plain dictionary keyed by field name."""
        return {"tb_id": self.tb_id, "tbg_size": self.tbg_size}


class SyncOperation(BaseOperation):
def __init__(self):
super().__init__(Instruction.nop)
Expand Down Expand Up @@ -139,6 +148,7 @@ def __init__(
self,
src_buff: List[LocalChunk],
dst_buff: List[LocalChunk],
tbg_info: ThreadBlockGroupInfo = None,
from_packet: bool = False,
to_packet: bool = False,
):
Expand All @@ -153,6 +163,7 @@ def __init__(

self.src_buff = src_buff
self.dst_buff = dst_buff
self.tbg_info = tbg_info

def local_data_access(self, sync_purpose=True):
data_access = []
Expand Down Expand Up @@ -182,6 +193,8 @@ def to_dict(self):
result["dst_buff"] = []
for chunk in self.dst_buff:
result["dst_buff"].append(chunk.to_dict())
if self.tbg_info is not None:
result["tbg_info"] = self.tbg_info.to_dict()
return result


Expand Down Expand Up @@ -431,12 +444,14 @@ def __init__(
dst_buff: List[LocalChunk],
channel_ids: List[int],
channel_type: ChannelType,
tbg_info: ThreadBlockGroupInfo = None,
):
super().__init__(Instruction.get)
self.src_buff = src_buff
self.dst_buff = dst_buff
self.channel_ids = channel_ids
self.channel_type = channel_type
self.tbg_info = tbg_info

def local_data_access(self, sync_purpose=True):
data_access = []
Expand All @@ -458,12 +473,14 @@ def __add__(self, other):
isinstance(other, GetOperation)
and self.src_buff[0].size == other.src_buff[0].size
and self.channel_type == other.channel_type
and self.tbg_info == other.tbg_info
):
fused_operation = GetOperation(
src_buff=self.src_buff + other.src_buff,
dst_buff=self.dst_buff + other.dst_buff,
channel_ids=self.channel_ids + other.channel_ids,
channel_type=self.channel_type,
tbg_info=self.tbg_info,
)

return fused_operation
Expand All @@ -478,6 +495,8 @@ def to_dict(self):
result["dst_buff"].append(chunk.to_dict())
result["channel_ids"] = self.channel_ids
result["channel_type"] = self.channel_type.value
if self.tbg_info is not None:
result["tbg_info"] = self.tbg_info.to_dict()
return result


Expand All @@ -488,6 +507,7 @@ def __init__(
dst_buff: List[RemoteChunk],
channel_ids: List[int],
channel_type: ChannelType,
tbg_info: ThreadBlockGroupInfo = None,
from_packet: bool = False,
to_packet: bool = False,
with_signal: bool = False,
Expand Down Expand Up @@ -517,6 +537,7 @@ def __init__(
self.to_packet = to_packet
self.with_signal = with_signal
self.with_signal_and_flush = with_signal_and_flush
self.tbg_info = tbg_info

def local_data_access(self, sync_purpose=True):
data_access = []
Expand Down Expand Up @@ -546,12 +567,14 @@ def __add__(self, other):
and self.name == other.name
and self.src_buff[0].size == other.src_buff[0].size
and self.channel_type == other.channel_type
and self.tbg_info == other.tbg_info
):
fused_operation = PutOperation(
src_buff=self.src_buff + other.src_buff,
dst_buff=self.dst_buff + other.dst_buff,
channel_ids=self.channel_ids + other.channel_ids,
channel_type=self.channel_type,
tbg_info=self.tbg_info,
to_packet=self.to_packet,
with_signal=self.with_signal,
with_signal_and_flush=self.with_signal_and_flush,
Expand All @@ -570,6 +593,8 @@ def to_dict(self):
if self.channel_type == ChannelType.port:
result["channel_ids"] = self.channel_ids
result["channel_type"] = self.channel_type.value
if self.tbg_info is not None:
result["tbg_info"] = self.tbg_info.to_dict()
return result


Expand All @@ -585,6 +610,7 @@ def __init__(
put_channel_ids: List[int] = None,
channel_type: ChannelType = ChannelType.none,
reduce_operation: ReduceOperationType = ReduceOperationType.sum,
tbg_info: ThreadBlockGroupInfo = None,
packet: bool = False,
):
remote_src_buff = remote_src_buff if remote_src_buff is not None else []
Expand Down Expand Up @@ -617,6 +643,7 @@ def __init__(
self.put_channel_ids = put_channel_ids
self.channel_type = channel_type
self.reduce_operation = reduce_operation
self.tbg_info = tbg_info
self.packet = packet

def local_data_access(self, sync_purpose=True):
Expand Down Expand Up @@ -657,6 +684,7 @@ def __add__(self, other):
and self.local_dst_buff == other.local_dst_buff
and self.channel_type == other.channel_type
and self.reduce_operation == other.reduce_operation
and self.tbg_info == other.tbg_info
):
fused_operation = ReduceOperation(
self.local_src_buff + other.local_src_buff[1:],
Expand All @@ -665,6 +693,7 @@ def __add__(self, other):
channel_ids=self.channel_ids + other.channel_ids,
channel_type=self.channel_type,
reduce_operation=self.reduce_operation,
tbg_info=self.tbg_info,
packet=self.packet,
)
if (
Expand All @@ -678,6 +707,7 @@ def __add__(self, other):
and other.name == Instruction.put
and self.local_dst_buff[0] == other.src_buff[0]
and other.channel_type == ChannelType.memory
and self.tbg_info == other.tbg_info
):
fused_operation = ReduceOperation(
self.local_src_buff,
Expand All @@ -688,6 +718,7 @@ def __add__(self, other):
put_channel_ids=self.put_channel_ids + other.channel_ids,
channel_type=self.channel_type,
reduce_operation=self.reduce_operation,
tbg_info=self.tbg_info,
packet=self.packet,
)
if (
Expand All @@ -696,6 +727,7 @@ def __add__(self, other):
and other.name == Instruction.put_packet
and self.local_dst_buff[0] == other.src_buff[0]
and other.channel_type == ChannelType.memory
and self.tbg_info == other.tbg_info
):
fused_operation = ReduceOperation(
self.local_src_buff,
Expand All @@ -706,6 +738,7 @@ def __add__(self, other):
put_channel_ids=self.put_channel_ids + other.channel_ids,
channel_type=other.channel_type,
reduce_operation=self.reduce_operation,
tbg_info=self.tbg_info,
packet=self.packet,
)

Expand All @@ -730,6 +763,8 @@ def to_dict(self):
if self.channel_type != ChannelType.none:
result["channel_type"] = self.channel_type.value
result["reduce_op"] = self.reduce_operation.value
if self.tbg_info is not None:
result["tbg_info"] = self.tbg_info.to_dict()
return result


Expand Down
102 changes: 74 additions & 28 deletions python/mscclpp/language/rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

from mscclpp.language.internal.types import BufferType, Chunk
from mscclpp.language.thread_block_group import *
from mscclpp.language.internal.operations import *
from mscclpp.language.internal.globals import get_program
from dataclasses import dataclass
Expand Down Expand Up @@ -61,7 +62,15 @@ def get_output_buffer(self):
"""
return get_program().buffers[self.rank][BufferType.output]

def _copy(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int, from_packet: bool = False, to_packet: bool = False):
def _copy(
self,
dst_chunk: Chunk,
src_chunk: Chunk,
tb: int = None,
tb_group: ThreadBlockGroup = None,
from_packet: bool = False,
to_packet: bool = False,
):
"""Internal copy operation implementation.

Performs a local copy operation between chunks on this rank with optional
Expand All @@ -70,7 +79,8 @@ def _copy(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int, from_packet: bool =
Args:
dst_chunk (Chunk): The destination chunk to copy data to.
src_chunk (Chunk): The source chunk to copy data from.
tb (int): The thread block ID that will execute this operation.
tb (int, optional): The thread block ID that will execute this operation. Defaults to None.
tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this operation. Defaults to None.
from_packet (bool, optional): Whether to unpack from packet format. Defaults to False.
to_packet (bool, optional): Whether to pack to packet format. Defaults to False.

Expand All @@ -91,16 +101,31 @@ def _copy(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int, from_packet: bool =
if to_packet and dst_chunk.buffer != BufferType.scratch:
raise RuntimeError(f"Destination chunk must be of type scratch.")

op = CopyOperation(
[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
from_packet,
to_packet,
)
if tb is not None:
tb_list = [tb]
elif tb_group is not None:
tb_list = tb_group.tb_list
else:
raise RuntimeError(
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

get_program().add_operation(self.rank, tb, op)
for tb_id in tb_list:
op = CopyOperation(
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
if tb_group is not None
else None
),
from_packet=from_packet,
to_packet=to_packet,
)

def copy(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
get_program().add_operation(self.rank, tb_id, op)

def copy(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Copy data from source chunk to destination chunk.

Performs a simple local copy operation between two chunks on this rank
Expand All @@ -109,14 +134,15 @@ def copy(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
Args:
dst_chunk (Chunk): The destination chunk to copy data to.
src_chunk (Chunk): The source chunk to copy data from.
tb (int): The thread block ID that will execute this operation.
tb (int, optional): The thread block ID that will execute this operation. Defaults to None.
tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this operation. Defaults to None.

Example:
>>> rank.copy(dst_chunk, src_chunk, tb=0)
"""
self._copy(dst_chunk=dst_chunk, src_chunk=src_chunk, tb=tb)
self._copy(dst_chunk=dst_chunk, src_chunk=src_chunk, tb=tb, tb_group=tb_group)

def unpack_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
def unpack_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Copy data from packet format to regular format.

Unpacks data from packet format in the source scratch buffer and copies
Expand All @@ -125,14 +151,15 @@ def unpack_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
Args:
dst_chunk (Chunk): The destination chunk to copy unpacked data to.
src_chunk (Chunk): The source scratch chunk containing packed data.
tb (int): The thread block ID that will execute this operation.
tb (int, optional): The thread block ID that will execute this operation. Defaults to None.
tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this operation. Defaults to None.

Example:
>>> rank.unpack_packets(dst_chunk, src_chunk, tb=0)
"""
self._copy(dst_chunk=dst_chunk, src_chunk=src_chunk, tb=tb, from_packet=True)
self._copy(dst_chunk=dst_chunk, src_chunk=src_chunk, tb=tb, tb_group=tb_group, from_packet=True)

def copy_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
def copy_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Copy data from regular format to packet format.

Packs data from the source chunk and copies it to the destination
Expand All @@ -141,18 +168,20 @@ def copy_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
Args:
dst_chunk (Chunk): The destination scratch chunk to store packed data.
src_chunk (Chunk): The source chunk containing data to pack.
tb (int): The thread block ID that will execute this operation.
tb (int, optional): The thread block ID that will execute this operation. Defaults to None.
tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this operation. Defaults to None.

Example:
>>> rank.copy_packets(dst_chunk, src_chunk, tb=0)
"""
self._copy(dst_chunk=dst_chunk, src_chunk=src_chunk, tb=tb, to_packet=True)
self._copy(dst_chunk=dst_chunk, src_chunk=src_chunk, tb=tb, tb_group=tb_group, to_packet=True)

def reduce(
self,
src_chunk: Chunk,
other_chunks: List[Chunk],
tb: int,
tb: int = None,
tb_group: ThreadBlockGroup = None,
dst_chunk: Chunk = None,
reduce_op: ReduceOperationType = ReduceOperationType.sum,
packet: bool = False,
Expand All @@ -165,7 +194,8 @@ def reduce(
Args:
src_chunk (Chunk): The primary source chunk to reduce.
other_chunks (List[Chunk]): Additional chunks to include in the reduction.
tb (int): The thread block ID that will execute this operation.
tb (int, optional): The thread block ID that will execute this operation. Defaults to None.
tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this operation. Defaults to None.
dst_chunk (Chunk, optional): The destination chunk for the result.
If None, uses src_chunk. Defaults to None.
reduce_op (ReduceOperationType, optional): The reduction operation to perform.
Expand Down Expand Up @@ -201,14 +231,30 @@ def reduce(
if packet and chunk.buffer != BufferType.scratch:
raise RuntimeError(f"Other chunk must be of type scratch.")

op = ReduceOperation(
[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)]
+ [LocalChunk(chunk.buffer, chunk.index, chunk.size) for chunk in other_chunks],
[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
reduce_operation=reduce_op,
packet=packet,
)
get_program().add_operation(self.rank, tb, op)
if tb is not None:
tb_list = [tb]
elif tb_group is not None:
tb_list = tb_group.tb_list
else:
raise RuntimeError(
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

for tb_id in tb_list:
op = ReduceOperation(
[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)]
+ [LocalChunk(chunk.buffer, chunk.index, chunk.size) for chunk in other_chunks],
[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
reduce_operation=reduce_op,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
if tb_group is not None
else None
),
packet=packet,
)

get_program().add_operation(self.rank, tb_id, op)

def barrier(self, tb_list: List[int]):
"""Create a synchronization barrier between thread blocks.
Expand Down
Loading
Loading