@@ -399,10 +399,7 @@ def backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_expert
             '''
             handle = ctx.handle
             combined_hidden, combined_probs = _hybrid_ep_buffer.combine_with_unpermute(
-                hidden=grad_x,
-                probs=grad_probs,
-                handle=handle,
-                pad_multiple=ctx.pad_multiple,
+                hidden=grad_x, probs=grad_probs, handle=handle, pad_multiple=ctx.pad_multiple
             )
             return combined_hidden, None, combined_probs, None, None, None, None, None, None, None

@@ -413,16 +410,12 @@ class HybridEPCombine(torch.autograd.Function):
         '''

         @staticmethod
-        def forward(
-            ctx, x, handle, num_permuted_tokens=None, pad_multiple=None
-        ):
+        def forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None):
             '''
             Forward pass of fused combine of the HybridEP backend
             '''
             combined_hidden, _ = _hybrid_ep_buffer.combine_with_unpermute(
-                hidden=x,
-                handle=handle,
-                pad_multiple=pad_multiple,
+                hidden=x, handle=handle, pad_multiple=pad_multiple
             )
             ctx.handle = handle
             ctx.pad_multiple = pad_multiple
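
Both autograd functions above follow the same torch.autograd.Function conventions: non-tensor state (the communication handle and pad_multiple) is stashed on ctx in forward and read back in backward, and backward returns one gradient slot per forward argument, with None for the non-differentiable ones. A minimal, self-contained sketch of that pattern, with illustrative names that are not part of this PR:

import torch

class StashScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale  # non-tensor state, like `handle`/`pad_multiple` above
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One slot per forward argument (x, scale); None for the non-tensor
        # input, mirroring the trailing Nones returned by the HybridEP functions.
        return grad_out * ctx.scale, None

x = torch.ones(3, requires_grad=True)
StashScale.apply(x, 2.0).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))
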
@@ -514,9 +507,7 @@ def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple):
             The alignment multiple required for FP8 GEMM. If not provided, no padding
             is performed.
         '''
-        return HybridEPCombine.apply(
-            x, handle, num_permuted_tokens, pad_multiple
-        )
+        return HybridEPCombine.apply(x, handle, num_permuted_tokens, pad_multiple)

 else:
     hybrid_ep_dispatch = None
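
The pad_multiple argument exists because FP8 GEMMs require the token dimension to be aligned to a fixed multiple; the permuted token count is presumably rounded up to that multiple before the GEMM. A rough illustration of the rounding this implies, using a hypothetical helper that is not part of this API:

def round_up(n: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= n.
    return ((n + multiple - 1) // multiple) * multiple

# e.g. with pad_multiple=16, 1000 permuted tokens would be padded to 1008 rows
assert round_up(1000, 16) == 1008
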