Skip to content

Commit 942825f

Browse files
rootAutumn1998
authored andcommitted
format
1 parent 6ceac69 commit 942825f

1 file changed

Lines changed: 4 additions & 13 deletions

File tree

megatron/core/transformer/moe/fused_a2a.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -399,10 +399,7 @@ def backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_exper
399399
'''
400400
handle = ctx.handle
401401
combined_hidden, combined_probs = _hybrid_ep_buffer.combine_with_unpermute(
402-
hidden=grad_x,
403-
probs=grad_probs,
404-
handle=handle,
405-
pad_multiple=ctx.pad_multiple,
402+
hidden=grad_x, probs=grad_probs, handle=handle, pad_multiple=ctx.pad_multiple
406403
)
407404
return combined_hidden, None, combined_probs, None, None, None, None, None, None, None
408405

@@ -413,16 +410,12 @@ class HybridEPCombine(torch.autograd.Function):
413410
'''
414411

415412
@staticmethod
416-
def forward(
417-
ctx, x, handle, num_permuted_tokens=None, pad_multiple=None
418-
):
413+
def forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None):
419414
'''
420415
Forward pass of fused combine of the HybridEP backend
421416
'''
422417
combined_hidden, _ = _hybrid_ep_buffer.combine_with_unpermute(
423-
hidden=x,
424-
handle=handle,
425-
pad_multiple=pad_multiple,
418+
hidden=x, handle=handle, pad_multiple=pad_multiple
426419
)
427420
ctx.handle = handle
428421
ctx.pad_multiple = pad_multiple
@@ -514,9 +507,7 @@ def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple):
514507
The alignment multiple required for FP8 GEMM. If not provided, no padding
515508
is performed.
516509
'''
517-
return HybridEPCombine.apply(
518-
x, handle, num_permuted_tokens, pad_multiple
519-
)
510+
return HybridEPCombine.apply(x, handle, num_permuted_tokens, pad_multiple)
520511

521512
else:
522513
hybrid_ep_dispatch = None

0 commit comments

Comments
 (0)