MoE Dispatch/Combine first implementation for k=1 and Nccl backend#5857
MoE Dispatch/Combine first implementation for k=1 and Nccl backend#5857samnordmann wants to merge 9 commits intomainfrom
k=1 and Nccl backend#5857Conversation
k=1 and Nccl backendk=1 and Nccl backend
|
!test |
Greptile OverviewGreptile SummaryThis PR adds an initial multi-device MoE dispatch+combine path for topk=1 backed by NCCL all-to-all. It introduces new IR nodes ( The overall design is: Merge blockers are mostly around correctness/contracts at the boundaries (dtype/shape invariants and scatter semantics) rather than the general wiring. Confidence Score: 3/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant A as Evaluator
participant B as Dispatch
participant C as Combine
A->>B: dispatch step
B-->>A: outputs
A->>C: combine step
C-->>A: outputs
|
| // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we | ||
| // sync/copy here. GPU-initiated comms can avoid this extra sync. | ||
| auto rank_for_token_cpu = rank_for_token.to(at::kCPU); | ||
| auto n_tokens_to_rank_cpu = | ||
| at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); | ||
| auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); |
There was a problem hiding this comment.
style: CPU synchronization here adds overhead by transferring rank_for_token to CPU, computing bincount, then moving back to GPU. For large token counts, this synchronization could become a bottleneck.
| // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we | |
| // sync/copy here. GPU-initiated comms can avoid this extra sync. | |
| auto rank_for_token_cpu = rank_for_token.to(at::kCPU); | |
| auto n_tokens_to_rank_cpu = | |
| at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); | |
| auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); | |
| // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we | |
| // sync/copy here. GPU-initiated comms can avoid this extra sync. | |
| // TODO: Consider using GPU-based bincount to avoid CPU synchronization overhead. | |
| auto rank_for_token_cpu = rank_for_token.to(at::kCPU); | |
| auto n_tokens_to_rank_cpu = | |
| at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
Review updated until commit a0de605 Description
|
| Relevant files | |||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Enhancement | 7 files
| ||||||||||||||
| Tests | 1 files
| ||||||||||||||
| Configuration changes | 1 files
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Memory allocation optimization
|
Test failures
-
(Medium, 3)
Shape mismatch in thunder.fx higher-order inplace alias update tests (test_update_aliases)Test Name A100 GB200 H100 Source thunder.tests.test_update_aliases.test_higher_order_inplace_alias_update_nvfuser_cuda_thunder.dtypes.float32 ❌ ❌ ❌ -
(Medium, 1)
Thunder nvFuser scalar mismatch in nanoGPT autograd CUDA test_networksTest Name GB200 Source thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 ❌
wujingyue
left a comment
There was a problem hiding this comment.
Thanks!
First round -- I'll review combine later.
| TensorView* out_n_tokens_from_rank, | ||
| TensorView* in_x, | ||
| TensorView* in_topk_idx, | ||
| TensorView* in_topk_weights, |
There was a problem hiding this comment.
IIUC, topk_weights doesn't need to go through dispatch or combine. Llama4 applies the topk weights before dispatch
Fuser/tests/python/test_moe.py
Line 114 in 0c147ae
There was a problem hiding this comment.
Agreed — dropped topk_weights. Callers should apply weights before dispatch or after combine; added comments to document that
There was a problem hiding this comment.
There was a problem hiding this comment.
There was a problem hiding this comment.
add back topk_weights
| NVF_CHECK( | ||
| [&]() { | ||
| auto token_counts = is_token_in_rank.to(at::kLong).sum(1); | ||
| auto min_val = token_counts.min().item<int64_t>(); | ||
| auto max_val = token_counts.max().item<int64_t>(); | ||
| return min_val == 1 && max_val == 1; | ||
| }(), | ||
| "Only topk=1 is supported. Each token must be assigned to exactly one " | ||
| "rank."); |
There was a problem hiding this comment.
style: Validation check causes unnecessary GPU-CPU synchronization on every dispatch call. The lambda executes .min() and .max() which trigger implicit CPU synchronization via .item<int64_t>(). For topk=1 validation, consider validating this constraint earlier during graph construction or accepting it as a precondition, rather than checking it at runtime on the hot path.
0f48cd5 to
afd948d
Compare
|
!test |
| TensorView* out_topk_idx, | ||
| TensorView* out_src_idx, | ||
| TensorView* out_src_rank, |
There was a problem hiding this comment.
I understand n_tokens_to_rank and n_tokens_from_rank, but I'm not sure why these three tensors need to be outputs. At least, I didn't have to use them in
Fuser/tests/python/multidevice/test_expert_parallel.py
Lines 165 to 185 in f8b6785
There was a problem hiding this comment.
They’re required for the round‑trip semantics here. out_src_idx/out_src_rank are the metadata needed by combine to route tokens back to their original rank and restore local order (we use them to build the alltoall splits and the final index_copy_).
out_topk_idx is still needed after dispatch so the rank can route tokens to its local experts.
about out_src_idx/out_src_rank, I can go with your suggestion and drop them, but I want to be explicit about the implication: we’d be committing to the same constrained ordering as the reference test you linked, that is, routing must be fully determined by split sizes and a fixed rank<->expert mapping, with per‑rank chunk order preserved by all‑to‑all. In that model you can reconstruct order without per‑token metadata. The trade‑off is that we lose support for arbitrary routing/non‑trivial meshes or any custom per‑token permutation, since we’d have no way to restore original order. One potential issue I foresee is implementing a padded dispatch/combined on some over- and pre-allocated buffers -- but that could be solved later!
Assuming we are ok with those implications, I’ll proceed with the simplification.
There was a problem hiding this comment.
Coming back on my last comment: note that DeepSeek API seems to use those tensors and their Combine specifically reorders back to original token positions (using src_idx/src_rank).
https://github.com/deepseek-ai/DeepEP/blob/29d31c095796f3c8ece47ee9cdcc167051bbeed9/csrc/kernels/intranode.cu#L1034
Is this an argument to keep those variables?
As a side note, related to #5857 (comment), note that DeepSeek interface takes topk_weights
There was a problem hiding this comment.
Any thoughts on that? Imho reducing the number of arguments doesn't necessarily simplify the IR because it introduces implicit assumptions
There was a problem hiding this comment.
note that DeepSeek interface takes topk_weights
Good point. I suspect this is because their combine kernel also covers the topk_weights multiplication. It's suboptimal to materialize s*k/ep in global memory. This PR probably doesn't need that at this very moment because it assumes k=1. I'll double check the code to make sure.
What does your kernel look like? I believe this is just a reference implementation and you are about to add a fused kernel.
There was a problem hiding this comment.
I'll double check the code to make sure.
I got a chance to read the code. Thanks for the pointer! There are three modes:
- Low-latency (see internode_ll.cu). Inter-node. One-hop.
- Intra-node (see intranode.cu). One-hop.
- Inter-node (see internode.cu). Two-hop, first RDMA and then nvlink.
Regardless of which mode, combine reads topk_weights. When k>1, it's better to fuse topk_weights multiplication into the combine kernel to avoid materializing s*k/ep tokens in global memory.
According to deepseek-ai/DeepEP#72, dispatch needs to produce topk_weights for local reduction. Therefore, dispatch does so for inter-node and doesn't for low-latency. It's yet unclear why dispatch produces topk_weights for intra-node in some cases.
What kernel? I am not sure to understand. I am not writing any cuda kernel in this PR
Never mind. I was referring to the kernel in the other PR you showed on Thursday. I wasn't sure you were building kernels from scratch or reusing DeepEP's. Since the former, we have more motivations to keep implementations simple until complexity is required.
There was a problem hiding this comment.
out_src_idx makes sense to me now. I previously thought a local sort was performed before dispatch, but I now understand that dispatch itself handles that.
There was a problem hiding this comment.
I'm still not sure about out_src_rank. Does it equal out_src_idx / ceil(seq_len / EP)?
There was a problem hiding this comment.
I'll double check the code to make sure.
I got a chance to read the code. Thanks for the pointer! There are three modes:
- Low-latency (see internode_ll.cu). Inter-node. One-hop.
- Intra-node (see intranode.cu). One-hop.
- Inter-node (see internode.cu). Two-hop, first RDMA and then nvlink.
Regardless of which mode, combine reads topk_weights. When k>1, it's better to fuse topk_weights multiplication into the combine kernel to avoid materializing
s*k/eptokens in global memory.
Yes, this is an important optimization. Therefore keeping topk_weights in the interface so in the next PR this will be fused with the combine's communication. Ok?
According to deepseek-ai/DeepEP#72, dispatch needs to produce topk_weights for local reduction. Therefore, dispatch does so for inter-node and doesn't for low-latency. It's yet unclear why dispatch produces
topk_weightsfor intra-node in some cases.
IIUC, this comment discusses whether topk_weights should be present in dispatch interface. The answer is that it is not necessary because topk_weights is only needed by combineafter receiving the data, when the tokens have been reordered to their original order (so each rank keeps its topk_weights unsorted and uses them only after the tokens are back at their original rank and order).
However, I disagree with this answer. The optimal implementation is to fuse combine and the reduction, as you suggested just above, not "combine" followed by "local reduction". We do not need to materialize all the experts activations at the receiver side, only the weighted average. This is equivalent to comparing allreduce and allgather+local-reduce. So we need the weights at the sender side, before reordering the tokens.
My conclusion is that we should keep topk_weights both in combine and dispatch. Do you agree?
out_src_idxmakes sense to me now. I previously thought a local sort was performed before dispatch, but I now understand that dispatch itself handles that.
Yes, it would be too rigid and suboptimal to request the tokens to be sorted before dispatch -- dispatch would merely be an MPI-alltoallv (for k=1). We do not need to reorder the tokens, especially when using raw nvLink transport as in the next PR. In the communication kernel, we can account for discountinuity in the source buffer. That's actually crucial for latency and BW to avoid repacking the data at the sender side.
I'm still not sure about
out_src_rank. Does it equalout_src_idx / ceil(seq_len / EP)?
I don't think so. Tokens are not guaranteed to be evenly dispatched to experts.
Note that out_src_idx is the index local to the rank. The "global" index of the token is deduced from out_src_idx and out_src_rank.
There was a problem hiding this comment.
My conclusion is that we should keep topk_weights both in combine and dispatch. Do you agree?
Yes, I agree now. IIUC, even for one-hop combines, it's still better to apply topk_weights on the sender side because the reduction can be done in fabric.
https://github.com/deepseek-ai/DeepEP/blob/29d31c095796f3c8ece47ee9cdcc167051bbeed9/csrc/kernels/internode_ll.cu#L1078 might just be suboptimal because it requires topk_weight to be in sequence order. Could it be an RDMA-related constraint? Note this kernel is for inter-nvlink-domain.
Note that out_src_idx is the index local to the rank.
Got it -- I missed that.
6da9793 to
4693c53
Compare
| // Asymmetric example: | ||
| // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. | ||
| auto rank_ids = at::arange(world_size, int_options); | ||
| auto token_rank = at::tensor({0, 1, 1, 1}, int_options); |
There was a problem hiding this comment.
hardcoded to 4 tokens with routing [0, 1, 1, 1] - won't work for world_size > 2
| auto topk_idx_flat = topk_idx.reshape({num_tokens}); | ||
|
|
||
| // Determine destination rank per token (topk=1). | ||
| auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); |
There was a problem hiding this comment.
assumes is_token_in_rank is valid one-hot (exactly one 1 per row). If multiple ranks are marked or none are marked, argmax returns first/last index but routing will be incorrect. Add validation that each row has exactly one True value.
| topk_is_1d || topk_is_2d, | ||
| "Only topk=1 supported. topk_idx must be shape [T] or [T, 1], got: ", | ||
| topk_idx.sizes()); | ||
| auto topk_idx_flat = topk_idx.reshape({num_tokens}); |
There was a problem hiding this comment.
no validation that topk_idx values are within [0, num_experts) range. Invalid expert IDs could cause incorrect rank assignment or out-of-bounds issues in local expert computation.
|
!test |
|
!test |
| recv_src_rank, send_src_rank, output_splits, input_splits)); | ||
|
|
||
| // Locally reorder by expert id so each rank processes contiguous experts. | ||
| auto local_expert = recv_topk_idx - my_rank * experts_per_rank; |
There was a problem hiding this comment.
Potential out-of-bounds risk: local_expert = recv_topk_idx - my_rank * experts_per_rank could produce negative values if recv_topk_idx contains expert IDs that don't belong to this rank. While the subsequent argsort will still work, it may produce unexpected results during expert processing if the assumption is that only local expert IDs should be received.
| // Asymmetric example: | ||
| // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. | ||
| auto rank_ids = at::arange(world_size, int_options); | ||
| auto token_rank = at::tensor({0, 1, 1, 1}, int_options); |
There was a problem hiding this comment.
Hardcoded routing pattern [0, 1, 1, 1] only works for world_size=2. For world_size > 2, this creates is_token_in_rank with more columns than the routing pattern accounts for, causing rank 2+ to receive no tokens. Consider making the test routing scale with world_size or adding a GTEST_SKIP() for world_size > 2.
| int64_t sumSplitSizes(const std::vector<int64_t>& splits) { | ||
| int64_t total = 0; | ||
| for (auto value : splits) { | ||
| total += value; | ||
| } | ||
| return total; |
There was a problem hiding this comment.
Missing overflow check for large tensors. Summing split sizes without overflow checking could cause issues with very large token counts. Consider using checked arithmetic or validating that the sum doesn't exceed reasonable bounds.
| auto topk_idx_flat = topk_idx.reshape({num_tokens}); | ||
|
|
||
| // Determine destination rank per token (topk=1). | ||
| auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); |
There was a problem hiding this comment.
The .to(at::kLong) conversions are redundant. is_token_in_rank is already cast to at::kLong on the same line, and argmax(1) returns int64 by default.
|
!test |
| std::vector<int64_t> toSplitSizes(const at::Tensor& sizes_tensor) { | ||
| auto cpu_sizes = sizes_tensor.to(at::kCPU); | ||
| auto* ptr = cpu_sizes.data_ptr<int64_t>(); | ||
| return std::vector<int64_t>(ptr, ptr + cpu_sizes.numel()); | ||
| } |
There was a problem hiding this comment.
Unsafe dtype assumption
toSplitSizes unconditionally does cpu_sizes.data_ptr<int64_t>() after .to(at::kCPU), but it never checks that sizes_tensor is actually kLong/int64. If a caller passes an int32 count tensor (e.g., from a different backend or future refactor), this will throw (data_ptr<int64_t>() called on ...) at runtime. Consider asserting cpu_sizes.scalar_type() == at::kLong (or converting) before taking the pointer.
| auto* recv_x = makeSymbolicTensor(2); | ||
| auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); | ||
| auto* recv_topk_weights = makeSymbolicTensor(1); | ||
| auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); | ||
| auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); | ||
| auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); | ||
| auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); |
There was a problem hiding this comment.
IR dtype mismatch
This test declares recv_topk_idx, recv_src_idx, recv_src_rank, n_tokens_to_rank, and n_tokens_from_rank as DataType::Int, but the implementation returns/uses at::kLong for these tensors (e.g., send_src_idx = sorted_indices.to(at::kLong) and bincount(...).to(at::kLong)). On platforms where DataType::Int maps to 32-bit, Host IR validation/binding will fail. The symbolic tensor dtypes should match the actual tensors produced (likely DataType::Int64).
| auto combined_x = at::empty({total_recv, hidden}, x.options()); | ||
| combined_x.index_copy_(0, recv_src_idx, recv_x); | ||
| auto combined_topk_weights = | ||
| at::empty({total_recv}, topk_weights_flat.options()); | ||
| combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); |
There was a problem hiding this comment.
Index collision drops data
combined_x.index_copy_(0, recv_src_idx, recv_x) (and same for combined_topk_weights) will silently overwrite when recv_src_idx contains duplicates (possible if upstream routing produces repeated source indices). That yields incorrect combined outputs with no error. If uniqueness is an invariant, it needs to be enforced/validated before index_copy_ so failures are explicit.
|
|
||
| struct CombineResult { | ||
| at::Tensor combined_x; // Combined tokens back in original order. | ||
| at::Tensor combined_topk_weights; // Combined gating weights per token. |
There was a problem hiding this comment.
should be removed
Uh oh!
There was an error while loading. Please reload this page.