Merged
Commits
24 commits
c601de4
add pp stage checkers to p2p communicator
yashaswikarnati Jan 27, 2026
84ae4f0
add process group collection wrapper
yashaswikarnati Jan 27, 2026
0fa3dd8
support multimodule pipelining in 1f1b schedule
yashaswikarnati Jan 27, 2026
b22f638
fix dim mapping in torch cat bridge comm
yashaswikarnati Jan 28, 2026
3badf57
handle 3d 2d tensor conversion in multimodule comm
yashaswikarnati Jan 28, 2026
20d03f5
add unit tests for multimodule pipeline schedules
yashaswikarnati Jan 28, 2026
a6606d8
refactor multimodule pg collection and backward step
yashaswikarnati Jan 28, 2026
b102eb7
rename module_collections to module_pgs for clarity
yashaswikarnati Jan 28, 2026
ebbb509
rename tensor conversion functions for clarity
yashaswikarnati Jan 28, 2026
2d7c176
Merge branch 'main' into yash/1f1b_changes
dimapihtar Jan 29, 2026
0b6cefd
Fix linting issues: format code and remove unused imports
yashaswikarnati Feb 3, 2026
597862e
Merge branch 'main' into yash/1f1b_changes
shifangx Feb 14, 2026
5f941d1
test: fix isort formatting in multimodule schedule test
yashaswikarnati Feb 17, 2026
b1db431
handle encoder only ranks
yashaswikarnati Feb 18, 2026
5846567
cache PGs across bridge communicators
yashaswikarnati Feb 18, 2026
908ea5f
Merge branch 'main' into yash/1f1b_changes
shifangx Feb 23, 2026
ee189df
Guard ambiguous multimodule comm tensor shape
yashaswikarnati Feb 24, 2026
81cf623
move backward_step_dict to schedules.py
yashaswikarnati Mar 6, 2026
6542743
Merge branch 'main' into yash/1f1b_changes
shifangx Mar 13, 2026
4f01712
Refactor: expose total_stages/current_stage on communicators
yashaswikarnati Mar 14, 2026
738db94
Merge remote-tracking branch 'upstream/main' into yash/1f1b_changes
yashaswikarnati Mar 14, 2026
edc8159
Fix test isolation: destroy leaked NCCL process groups in multimodule…
yashaswikarnati Mar 16, 2026
92b65d1
Remove redundant pg_collection asserts from schedules.py
yashaswikarnati Mar 16, 2026
78ee58c
Add missing copyright header to test_bridge_communicator.py
yashaswikarnati Mar 17, 2026
8 changes: 4 additions & 4 deletions megatron/core/pipeline_parallel/bridge_communicator.py
@@ -494,7 +494,7 @@ def recv_backward(self) -> torch.Tensor:
received_gradients_list.append(grad_tensor)

# Concatenate received gradients
aggregated_gradient = torch.cat(received_gradients_list, dim=0)
aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b'])
logging.debug(
f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} "
f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}"
@@ -615,7 +615,7 @@ def send_forward_recv_backward(
req.wait()

# Concatenate received gradients
aggregated_gradient = torch.cat(received_gradients_list, dim=0)
aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b'])
logging.debug(
f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} "
f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}"
@@ -737,9 +737,9 @@ def send_backward_recv_forward(
req.wait()

# Concatenate received activations
aggregated_activation = torch.cat(received_activations_list, dim=0)
aggregated_activation = torch.cat(received_activations_list, dim=self.dim_mapping['b'])
logging.debug(
f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} "
f"[Bridge Communicator] [send_backward_recv_forward] Rank {self.current_rank} "
f"agg act shape {aggregated_activation.shape} sum {aggregated_activation.sum()}"
)
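
Note on the `dim=self.dim_mapping['b']` change in the hunks above: the bridge communicator splits a batch across several peer ranks and re-assembles the received chunks, so the concatenation must follow the batch axis of the activation layout rather than a hard-coded `dim=0`. A minimal sketch, assuming the usual `[seq, batch, hidden]` layout and a hypothetical `dim_mapping` dict (not the PR's actual object):

```python
import torch

# Hypothetical mapping for illustration: 's' = sequence, 'b' = batch, 'h' = hidden.
dim_mapping = {'s': 0, 'b': 1, 'h': 2}

# Two gradient chunks received from two peer ranks, each covering half the batch.
chunks = [torch.randn(128, 2, 1024), torch.randn(128, 2, 1024)]

wrong = torch.cat(chunks, dim=0)                 # [256, 2, 1024]: sequence dim doubled
right = torch.cat(chunks, dim=dim_mapping['b'])  # [128, 4, 1024]: batch dim restored

assert right.shape == (128, 4, 1024)
```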

118 changes: 76 additions & 42 deletions megatron/core/pipeline_parallel/multimodule_communicator.py
@@ -45,6 +45,52 @@ class RankModuleInfo:
is_terminal_stage: Optional[bool] = True


def _prepare_tensor_for_comm(
tensor: Union[torch.Tensor, List[torch.Tensor], None]
) -> Union[torch.Tensor, List[torch.Tensor], None]:
"""Prepare tensor for P2P/bridge communication by expanding to 3D if needed.

P2P and bridge communicators expect 3D tensors.
Handles both single tensors and lists of tensors (for VPP).

Args:
tensor: Input tensor (2D or 3D), list of tensors, or None.

Returns:
3D tensor (with singleton last dim if input was 2D), list of 3D tensors, or None.
"""
if tensor is None:
return None
if isinstance(tensor, list):
return [_prepare_tensor_for_comm(t) for t in tensor]
if isinstance(tensor, torch.Tensor) and tensor.ndim == 2:
return tensor.unsqueeze(-1)
return tensor


def _restore_tensor_from_comm(
tensor: Union[torch.Tensor, List[torch.Tensor], None]
) -> Union[torch.Tensor, List[torch.Tensor], None]:
"""Restore tensor shape after P2P/bridge communication by squeezing singleton dim.

Removes the extra dimension added by _prepare_tensor_for_comm if it was singleton.
Handles both single tensors and lists of tensors (for VPP).

Args:
tensor: Input tensor (3D with singleton last dim), list of tensors, or None.

Returns:
2D tensor (if last dim was singleton), list of tensors, or None.
"""
if tensor is None:
return None
if isinstance(tensor, list):
return [_restore_tensor_from_comm(t) for t in tensor]
if isinstance(tensor, torch.Tensor) and tensor.ndim == 3 and tensor.shape[-1] == 1:
return tensor.squeeze(-1)
return tensor
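
A quick roundtrip sketch of the two helpers above, restated standalone for illustration (the real functions also pass lists and `None` through): a 2D tensor gains a singleton last dim for transport and loses it again on the receiving side, while genuinely 3D tensors pass through untouched.

```python
import torch

def prepare(t: torch.Tensor) -> torch.Tensor:
    # 2D [s, b] -> 3D [s, b, 1]; already-3D tensors are returned as-is.
    return t.unsqueeze(-1) if t.ndim == 2 else t

def restore(t: torch.Tensor) -> torch.Tensor:
    # Undo the singleton expansion; leave real 3D activations alone.
    return t.squeeze(-1) if t.ndim == 3 and t.shape[-1] == 1 else t

x2d = torch.randn(128, 4)                 # e.g. a 2D per-token tensor
assert restore(prepare(x2d)).shape == x2d.shape

x3d = torch.randn(128, 4, 1024)           # normal [seq, batch, hidden] activation
assert prepare(x3d) is x3d
assert restore(x3d) is x3d                # last dim != 1, so nothing is squeezed
```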


class MultiModulePipelineCommunicator:
"""Communicator for a multi-module pipeline."""

@@ -266,12 +312,14 @@ def recv_forward(
# If first stage, and has incoming modules, receive forward activation
# from incoming modules.
for bridge_comm in rank_module_info.bridge_comms_as_dest_module:
input_dict[bridge_comm.src_module_name] = bridge_comm.recv_forward()
received_tensor = bridge_comm.recv_forward()
input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(received_tensor)
else:
# If not first stage, receive forward activation tensor from P2P communicator.
input_dict[module_name] = rank_module_info.p2p_communicator.recv_forward(
received_tensor = rank_module_info.p2p_communicator.recv_forward(
tensor_shapes=tensor_shape, is_first_stage=False
)
input_dict[module_name] = _restore_tensor_from_comm(received_tensor)
return input_dict
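
The routing rule in `recv_forward`, mirrored by the send/recv methods below, is: the first pipeline stage of a module receives across module boundaries via its bridge communicators, while every other stage receives from the previous stage of its own module via P2P. A toy sketch of that decision, using hypothetical names (`RankInfo`, `route_recv_forward`) that are not part of the PR:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class RankInfo:
    pp_rank: int
    pp_size: int
    incoming_modules: List[str] = field(default_factory=list)

def route_recv_forward(info: RankInfo) -> str:
    if info.pp_rank == 0:
        # First stage of this module: activations arrive from upstream modules, if any.
        return f"bridge recv from {info.incoming_modules}" if info.incoming_modules else "external input"
    # Interior or last stage: activations arrive from the previous stage of the same module.
    return "p2p recv from previous stage"

print(route_recv_forward(RankInfo(pp_rank=0, pp_size=2, incoming_modules=["vision_encoder"])))
print(route_recv_forward(RankInfo(pp_rank=1, pp_size=2)))
```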

def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool = False):
@@ -280,20 +328,18 @@ def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool
Args:
output_dict: A dictionary mapping module names to tensors.
"""
logging.debug(
f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] "
f"[send_forward] output_dict keys: {output_dict.keys()}, is_last_stage: {is_last_stage}"
)
for module_name, rank_module_info in self.rank_module_map.items():
if rank_module_info.pp_rank == rank_module_info.pp_size - 1:
# If last stage, and has outgoing modules, send forward activation
# by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_src_module:
bridge_comm.send_forward(output_dict[module_name])
tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name])
bridge_comm.send_forward(tensor_to_send)
else:
# If not last stage, send forward activation by using P2P communicator.
tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name])
rank_module_info.p2p_communicator.send_forward(
output_dict[module_name], is_last_stage=False
tensor_to_send, is_last_stage=False
)

def send_forward_recv_backward(
@@ -311,28 +357,23 @@ def send_forward_recv_backward(
Returns:
A dictionary mapping module names to tensors.
"""
logging.debug(
f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] "
f"[send_forward_recv_backward] output_dict keys: {output_dict.keys()}, "
f"tensor_shape: {tensor_shape}, is_last_stage: {is_last_stage}"
)
grad_dict = {}
for module_name, rank_module_info in self.rank_module_map.items():
if rank_module_info.pp_rank == rank_module_info.pp_size - 1:
# If last stage, and has outgoing modules, send forward activation and
# receive backward gradient by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_src_module:
grad_dict[bridge_comm.src_module_name] = bridge_comm.send_forward_recv_backward(
output_dict[module_name]
)
tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name])
grad = bridge_comm.send_forward_recv_backward(tensor_to_send)
grad_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(grad)
else:
# If not last stage, send forward activation and receive backward gradient
# by using P2P communicator.
grad_dict[module_name] = (
rank_module_info.p2p_communicator.send_forward_recv_backward(
output_dict[module_name], tensor_shapes=tensor_shape, is_last_stage=False
)
tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name])
grad = rank_module_info.p2p_communicator.send_forward_recv_backward(
tensor_to_send, tensor_shapes=tensor_shape, is_last_stage=False
)
grad_dict[module_name] = _restore_tensor_from_comm(grad)
return grad_dict

def send_backward_recv_forward(
@@ -350,30 +391,23 @@ def send_backward_recv_forward(
Returns:
A dictionary mapping module names to tensors.
"""
logging.debug(
f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] "
f"[send_backward_recv_forward] grad_dict keys: {grad_dict.keys()}, "
f"tensor_shape: {tensor_shape}, is_first_stage: {is_first_stage}"
)
input_dict = {}
for module_name, rank_module_info in self.rank_module_map.items():
if rank_module_info.pp_rank == 0:
for bridge_comm in rank_module_info.bridge_comms_as_dest_module:
# If first stage, and has incoming modules, send backward gradient and
# receive forward activation by using bridge communicator.
input_dict[bridge_comm.src_module_name] = (
bridge_comm.send_backward_recv_forward(
grad_dict[bridge_comm.src_module_name]
)
)
grad_to_send = _prepare_tensor_for_comm(grad_dict[bridge_comm.src_module_name])
received_tensor = bridge_comm.send_backward_recv_forward(grad_to_send)
input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(received_tensor)
else:
# If not first stage, send backward gradient and receive forward activation
# by using P2P communicator.
input_dict[module_name] = (
rank_module_info.p2p_communicator.send_backward_recv_forward(
grad_dict[module_name], tensor_shapes=tensor_shape, is_first_stage=False
)
grad_to_send = _prepare_tensor_for_comm(grad_dict[module_name])
received_tensor = rank_module_info.p2p_communicator.send_backward_recv_forward(
grad_to_send, tensor_shapes=tensor_shape, is_first_stage=False
)
input_dict[module_name] = _restore_tensor_from_comm(received_tensor)
return input_dict

def recv_backward(
@@ -397,12 +431,14 @@ def recv_backward(
# If last stage, and has incoming modules, receive backward gradient
# by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_src_module:
grad_dict[bridge_comm.src_module_name] = bridge_comm.recv_backward()
grad = bridge_comm.recv_backward()
grad_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(grad)
else:
# If not last stage, receive backward gradient by using P2P communicator.
grad_dict[module_name] = rank_module_info.p2p_communicator.recv_backward(
grad = rank_module_info.p2p_communicator.recv_backward(
tensor_shapes=tensor_shape, is_last_stage=False
)
grad_dict[module_name] = _restore_tensor_from_comm(grad)
return grad_dict

def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool = False):
@@ -411,20 +447,18 @@ def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool
Args:
grad_dict: A dictionary mapping module names to tensors.
"""
logging.debug(
f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] "
f"[send_backward] grad_dict keys: {grad_dict.keys()}, is_first_stage: {is_first_stage}"
)
for module_name, rank_module_info in self.rank_module_map.items():
if rank_module_info.pp_rank == 0:
# If first stage, and has incoming modules, send backward gradient
# by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_dest_module:
bridge_comm.send_backward(grad_dict[bridge_comm.src_module_name])
grad_to_send = _prepare_tensor_for_comm(grad_dict[bridge_comm.src_module_name])
bridge_comm.send_backward(grad_to_send)
else:
# If not first stage, send backward gradient by using P2P communicator.
grad_to_send = _prepare_tensor_for_comm(grad_dict[module_name])
rank_module_info.p2p_communicator.send_backward(
grad_dict[module_name], is_first_stage=False
grad_to_send, is_first_stage=False
)

@staticmethod
16 changes: 16 additions & 0 deletions megatron/core/pipeline_parallel/p2p_communication.py
@@ -7,6 +7,7 @@
import torch.distributed as dist

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
from megatron.core.utils import nvtx_decorator

# Types
@@ -162,6 +163,21 @@ def __init__(self, pp_group: dist.ProcessGroup, config: ModelParallelConfig):
else None
)

@property
def is_pp_first_stage(self) -> bool:
"""Return True if pp first stage."""
return is_pp_first_stage(self.pp_group)

@property
def is_pp_last_stage(self) -> bool:
"""Return True if pp last stage."""
return is_pp_last_stage(self.pp_group)

@property
def num_warmup_microbatches(self) -> int:
"""Return number of warmup microbatches."""
return self.pp_group.size() - self.pp_group.rank() - 1
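
The `num_warmup_microbatches` property added here exposes the standard 1F1B warmup count: with a pipeline of P stages, stage r runs P - r - 1 warmup forward passes before the steady 1F1B phase, so the last stage enters 1F1B immediately (in the full schedule this count is typically also capped by the total number of microbatches). A plain arithmetic illustration:

```python
pp_size = 4
for pp_rank in range(pp_size):
    num_warmup = pp_size - pp_rank - 1
    print(f"stage {pp_rank}: {num_warmup} warmup microbatches")
# stage 0: 3, stage 1: 2, stage 2: 1, stage 3: 0 (last stage starts 1F1B right away)
```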

def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, recv_next):
"""Communicate tensor shapes between stages. Used to communicate
tensor shapes before the actual tensor communication happens.