Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile.ci.dev
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ RUN bash -ex <<"EOF"

git clone --branch hybrid-ep https://github.com/deepseek-ai/DeepEP.git
pushd DeepEP
git checkout 1dddd194c26911c35b4f53a148617dd73de0ffc9
git checkout 83e0d156807f31abed4ea55c2fa6eb4b62a11b82
patch -p1 < /workspace/deepep.patch
popd
TORCH_CUDA_ARCH_LIST="9.0 10.0 12.0" uv pip install --no-build-isolation -v DeepEP/.
Expand Down
51 changes: 14 additions & 37 deletions megatron/core/transformer/moe/fused_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE

from megatron.core.utils import internal_api

try:
from deep_ep import Buffer
Expand Down Expand Up @@ -328,6 +329,7 @@ def reset_hybrid_ep_buffer():
_hybrid_ep_buffer = None


@internal_api
class HybridEPDispatch(torch.autograd.Function):
'''
Fused dispatch operation for permute + dispatch a2a + permute using the HybridEP backend
Expand All @@ -343,7 +345,6 @@ def forward(
num_local_experts,
num_sms_dispatch_api=24,
num_sms_combine_api=24,
num_dispatched_tokens=None,
num_permuted_tokens=None,
pad_multiple=None,
):
Expand All @@ -362,11 +363,9 @@ def forward(
num_sms_combine_api,
fp8_dispatch,
)
# By default, the output token_per_expert and num_dispatched_tokens_tensor
# are placed on the CPU to avoid a potential sync in the combine/backward pass,
# but if num_dispatched_tokens and num_permuted_tokens are provided on the CPU,
# we do not need to do the D2H copy here.
use_host_meta = num_dispatched_tokens is None or num_permuted_tokens is None
# If num_permuted_tokens is provided, we do not need to synchronize to
# wait for the data in pinned memory to be ready
non_blocking = num_permuted_tokens is not None
# Process the dispatch
(
dispatched_hidden,
Expand All @@ -381,14 +380,12 @@ def forward(
scaling_factor=None,
num_of_experts_per_rank=num_local_experts,
pad_multiple=pad_multiple,
num_dispatched_tokens=num_dispatched_tokens,
num_permuted_tokens=num_permuted_tokens,
use_host_meta=use_host_meta,
non_blocking=non_blocking,
)

ctx.handle = handle
ctx.pad_multiple = pad_multiple
ctx.num_dispatched_tokens = num_dispatched_tokens
return (
dispatched_hidden,
dispatched_probs,
Expand All @@ -404,36 +401,27 @@ def backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_exper
'''
handle = ctx.handle
combined_hidden, combined_probs = _hybrid_ep_buffer.combine_with_unpermute(
hidden=grad_x,
probs=grad_probs,
handle=handle,
pad_multiple=ctx.pad_multiple,
num_dispatched_tokens=ctx.num_dispatched_tokens,
hidden=grad_x, probs=grad_probs, handle=handle, pad_multiple=ctx.pad_multiple
)
return combined_hidden, None, combined_probs, None, None, None, None, None, None, None


@internal_api
class HybridEPCombine(torch.autograd.Function):
'''
Fused combine operation for permute + combine a2a + permute using the HybridEP backend
'''

@staticmethod
def forward(
ctx, x, handle, num_dispatched_tokens=None, num_permuted_tokens=None, pad_multiple=None
):
def forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None):
'''
Forward pass of fused combine of the HybridEP backend
'''
combined_hidden, _ = _hybrid_ep_buffer.combine_with_unpermute(
hidden=x,
handle=handle,
pad_multiple=pad_multiple,
num_dispatched_tokens=num_dispatched_tokens,
hidden=x, handle=handle, pad_multiple=pad_multiple
)
ctx.handle = handle
ctx.pad_multiple = pad_multiple
ctx.num_dispatched_tokens = num_dispatched_tokens
ctx.num_permuted_tokens = num_permuted_tokens
return combined_hidden

Expand All @@ -448,14 +436,14 @@ def backward(ctx, grad_x):
scaling_factor=None,
handle=handle,
pad_multiple=ctx.pad_multiple,
num_dispatched_tokens=ctx.num_dispatched_tokens,
num_permuted_tokens=ctx.num_permuted_tokens,
)
return dispatched_hidden, None, None, None, None


if HAVE_HYBRIDEP:

@internal_api
def hybrid_ep_dispatch(
x,
routing_map,
Expand All @@ -464,7 +452,6 @@ def hybrid_ep_dispatch(
num_local_experts,
num_sms_dispatch_api=24,
num_sms_combine_api=24,
num_dispatched_tokens=None,
num_permuted_tokens=None,
pad_multiple=None,
):
Expand All @@ -487,10 +474,6 @@ def hybrid_ep_dispatch(
Number of SMs used by the dispatch API.
num_sms_combine_api (int):
Number of SMs used by the combine API.
num_dispatched_tokens (int):
Number of tokens after dispatch but before permute. HybridEP uses this
to allocate buffers. If not provided, HybridEP obtains the size from
a GPU tensor, which causes a D2H synchronization.
num_permuted_tokens (int):
Number of tokens after permute. HybridEP uses this to allocate buffers.
If not provided, HybridEP obtains the size from a GPU tensor,
Expand All @@ -507,12 +490,12 @@ def hybrid_ep_dispatch(
num_local_experts,
num_sms_dispatch_api,
num_sms_combine_api,
num_dispatched_tokens,
num_permuted_tokens,
pad_multiple,
)

def hybrid_ep_combine(x, handle, num_dispatched_tokens, num_permuted_tokens, pad_multiple):
@internal_api
def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple):
'''
Perform fused combine operation for unpermute + combine a2a + unpermute
using the HybridEP backend
Expand All @@ -522,20 +505,14 @@ def hybrid_ep_combine(x, handle, num_dispatched_tokens, num_permuted_tokens, pad
Input hidden states to combine
handle (EventHandle):
Communication handle from dispatch operation
num_dispatched_tokens (int):
The number of tokens after unpermute but before combine. HybridEP uses this
to allocate buffers. If not provided, HybridEP obtains the size from a GPU tensor,
which causes a D2H synchronization.
num_permuted_tokens (int): The number of tokens before unpermute. HybridEP uses this
to allocate buffers. If not provided, HybridEP obtains the size from a GPU tensor,
which causes a D2H synchronization.
pad_multiple (int):
The alignment multiple required for FP8 GEMM. If not provided, no padding
is performed.
'''
return HybridEPCombine.apply(
x, handle, num_dispatched_tokens, num_permuted_tokens, pad_multiple
)
return HybridEPCombine.apply(x, handle, num_permuted_tokens, pad_multiple)

else:
hybrid_ep_dispatch = None
Expand Down
15 changes: 3 additions & 12 deletions megatron/core/transformer/moe/token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,11 +985,8 @@ def __init__(
if self.drop_and_pad:
assert self.capacity_factor is not None
self.capacity = None
# The up-bound for the number of tokens after dispatch op, -1 means no up-bound,
# which will cause a CPU sync
self.num_dispatched_tokens = None
# Actually the sum of tokens_per_expert, the up-bound for the number of tokens
# after permute op, -1 means no up-bound, will cause a CPU sync
# The upper bound for the number of tokens after the permute op;
# None means no upper bound, which will cause a CPU sync
self.num_permuted_tokens = None

# Metadata
Expand Down Expand Up @@ -1018,12 +1015,9 @@ def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor):
num_experts=self.num_experts,
capacity_factor=self.capacity_factor,
)
# We cannot predict the actual number of tokens after the dispatch op,
# so we set it to the worst case in drop_and_pad mode
self.num_dispatched_tokens = self.capacity * self.group.size() * self.num_local_experts
# In drop_and_pad mode, the number of tokens after the permute op
# can be computed on the CPU
self.num_permuted_tokens = self.num_dispatched_tokens
self.num_permuted_tokens = self.capacity * self.group.size() * self.num_local_experts
self.tokens_per_expert = torch.full(
(self.num_local_experts,), self.capacity * self.group.size(), dtype=torch.long
)
Expand Down Expand Up @@ -1052,7 +1046,6 @@ def dispatch(
num_local_experts=self.num_local_experts,
num_sms_dispatch_api=self.config.moe_hybridep_num_sms,
num_sms_combine_api=self.config.moe_hybridep_num_sms,
num_dispatched_tokens=self.num_dispatched_tokens,
num_permuted_tokens=self.num_permuted_tokens,
pad_multiple=self.pad_multiple,
)
Expand All @@ -1074,7 +1067,6 @@ def combine(
hidden_states = hybrid_ep_combine(
x=hidden_states,
handle=self.handle,
num_dispatched_tokens=self.num_dispatched_tokens,
num_permuted_tokens=self.num_permuted_tokens,
pad_multiple=self.pad_multiple,
)
Expand All @@ -1084,7 +1076,6 @@ def combine(
self.handle = None
if not self.drop_and_pad:
self.num_permuted_tokens = None
self.num_dispatched_tokens = None
return hidden_states

def get_permuted_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> torch.Tensor:
Expand Down
Loading