
MoE Dispatch/Combine first implementation for k=1 and Nccl backend #5857

Open

samnordmann wants to merge 9 commits into main from dispatch_combine/stub

Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Jan 21, 2026

  • Add a first working dispatch+combine primitive for k=1 in multidevice execution, including utilities.
  • Extend Host IR evaluator plumbing to drive the new dispatch+combine path.
  • Add a C++ test.

@samnordmann samnordmann changed the title from "MoE Dispatch Combine first implementation for k=1 and Nccl backend" to "MoE Dispatch/Combine first implementation for k=1 and Nccl backend" on Jan 21, 2026
@samnordmann
Collaborator Author

!test

@greptile-apps
Contributor

greptile-apps bot commented Jan 21, 2026

Greptile Overview

Greptile Summary

This PR adds an initial multi-device MoE dispatch+combine path for topk=1 backed by NCCL all-to-all. It introduces new IR nodes (MoeDispatch, MoeCombine) in csrc/multidevice/communication.{h,cpp}, wires them into Host IR evaluation (csrc/host_ir/evaluator.{h,cpp}) by calling new runtime helpers (csrc/multidevice/dispatch_combine.{h,cpp}), and adds a new multidevice C++ test plus CMake integration.

The overall design is: MoeDispatch reorders tokens by destination rank, exchanges per-rank counts and token payloads with ProcessGroup::alltoall_base, and returns metadata (src_idx, src_rank, per-rank token counts) needed by MoeCombine. MoeCombine sends results back to source ranks and scatters by original token indices to restore local order.

Merge blockers are mostly around correctness/contracts at the boundaries (dtype/shape invariants and scatter semantics) rather than the general wiring.
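
To make the dispatch flow above concrete, here is a minimal C++/ATen sketch of a k=1 dispatch built on c10d's alltoall_base. The function and struct names (doDispatchSketch, DispatchSketchResult) and the exact metadata layout are illustrative assumptions, not the PR's API; the real entry points live in csrc/multidevice/dispatch_combine.{h,cpp}.

#include <numeric>
#include <vector>

#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

// Illustrative result bundle: tokens received from all ranks plus the
// metadata a later combine step needs to undo the shuffle.
struct DispatchSketchResult {
  at::Tensor recv_x;                // [R, H] tokens received from all ranks
  at::Tensor recv_src_idx;          // [R] sender-local index of each token
  at::Tensor recv_src_rank;         // [R] sender rank of each token
  std::vector<int64_t> n_from_rank; // tokens received per source rank
  std::vector<int64_t> n_to_rank;   // tokens sent per destination rank
};

DispatchSketchResult doDispatchSketch(
    c10d::ProcessGroup& pg,
    const at::Tensor& x,              // [T, H] local tokens (CUDA for NCCL)
    const at::Tensor& rank_for_token, // [T] integral destination rank (k=1)
    int64_t world_size) {
  // 1. Reorder tokens so destinations are contiguous per rank, and remember
  //    each token's original local index.
  auto order = rank_for_token.argsort();
  auto send_x = x.index_select(0, order);
  auto send_src_idx = order;

  // 2. Per-rank send counts; split metadata must live on the CPU for the
  //    CPU-initiated NCCL alltoall_base call.
  auto n_to_rank_cpu =
      at::bincount(rank_for_token.to(at::kCPU), {}, world_size);
  std::vector<int64_t> input_splits(
      n_to_rank_cpu.data_ptr<int64_t>(),
      n_to_rank_cpu.data_ptr<int64_t>() + world_size);

  // 3. Exchange counts so every rank knows its receive splits.
  std::vector<int64_t> ones(world_size, 1);
  auto send_counts = n_to_rank_cpu.to(x.device());
  auto recv_counts = at::empty_like(send_counts);
  pg.alltoall_base(recv_counts, send_counts, ones, ones)->wait();
  auto recv_counts_cpu = recv_counts.to(at::kCPU);
  std::vector<int64_t> output_splits(
      recv_counts_cpu.data_ptr<int64_t>(),
      recv_counts_cpu.data_ptr<int64_t>() + world_size);

  // 4. Exchange the token payload and the per-token source indices.
  int64_t total_recv =
      std::accumulate(output_splits.begin(), output_splits.end(), int64_t{0});
  auto recv_x = at::empty({total_recv, x.size(1)}, x.options());
  pg.alltoall_base(recv_x, send_x, output_splits, input_splits)->wait();
  auto recv_src_idx = at::empty({total_recv}, send_src_idx.options());
  pg.alltoall_base(recv_src_idx, send_src_idx, output_splits, input_splits)
      ->wait();

  // 5. alltoall_base delivers rank r's chunk contiguously, so the source rank
  //    of each received token follows directly from the receive splits.
  auto recv_src_rank =
      at::repeat_interleave(at::arange(world_size), recv_counts_cpu)
          .to(x.device());

  return {recv_x, recv_src_idx, recv_src_rank, output_splits, input_splits};
}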

Confidence Score: 3/5

  • This PR is close, but it has a couple of correctness/contract issues to address before merging.
  • Core wiring (new IR nodes + HostIrEvaluator handlers + NCCL alltoall usage) is coherent, but there are concrete correctness hazards: (1) toSplitSizes assumes int64 dtype without checking/converting, so split-size extraction can throw at runtime if count tensors aren’t kLong; (2) the new test declares several symbolic tensors as DataType::Int while runtime uses at::kLong, which can break Host IR validation/binding depending on platform/type mapping; (3) combine uses index_copy_ which silently overwrites on duplicate indices unless uniqueness is enforced.
  • csrc/multidevice/dispatch_combine.cpp; tests/cpp/test_multidevice_dispatch_combine.cpp

Important Files Changed

Filename Overview
CMakeLists.txt Adds multidevice dispatch_combine.cpp and new multidevice test source to build; no functional issues spotted in the diff.
csrc/dispatch.h Registers new Expr op types MoeDispatch/MoeCombine in dispatch macros; appears consistent with dispatch usage.
csrc/host_ir/evaluator.h Adds HostIrEvaluator handlers for MoeDispatch/MoeCombine; declarations match new implementations.
csrc/host_ir/evaluator.cpp Implements HostIrEvaluator::handle for MoeDispatch/MoeCombine and binds outputs; relies on dispatch_combine runtime behavior.
csrc/multidevice/communication.h Introduces MoeDispatch/MoeCombine IR node declarations and accessors; types validated but runtime dtype expectations must match.
csrc/multidevice/communication.cpp Implements MoeDispatch/MoeCombine constructors/validate; no immediate issues, but depends on matching tensor dtypes in evaluator.
csrc/multidevice/dispatch_combine.h Adds doMoeDispatch/doMoeCombine API and result structs; contracts documented but some invariants not enforced in implementation.
csrc/multidevice/dispatch_combine.cpp Implements NCCL alltoall-based dispatch/combine for topk=1; issues: split-size dtype assumption in toSplitSizes and potential silent overwrite on duplicate recv_src_idx during combine.
tests/cpp/test_multidevice_dispatch_combine.cpp Adds dispatch+combine host IR test; currently declares several int tensors as DataType::Int while runtime uses at::kLong, risking dtype/validation failures.

Sequence Diagram

sequenceDiagram
  participant A as Evaluator
  participant B as Dispatch
  participant C as Combine
  A->>B: dispatch step
  B-->>A: outputs
  A->>C: combine step
  C-->>A: outputs

Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 1 comment

Comment on lines +117 to +122
// For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we
// sync/copy here. GPU-initiated comms can avoid this extra sync.
auto rank_for_token_cpu = rank_for_token.to(at::kCPU);
auto n_tokens_to_rank_cpu =
    at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong);
auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device());

style: CPU synchronization here adds overhead by transferring rank_for_token to CPU, computing bincount, then moving back to GPU. For large token counts, this synchronization could become a bottleneck.

Suggested change

Original:
// For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we
// sync/copy here. GPU-initiated comms can avoid this extra sync.
auto rank_for_token_cpu = rank_for_token.to(at::kCPU);
auto n_tokens_to_rank_cpu =
    at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong);
auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device());

Suggested:
// For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we
// sync/copy here. GPU-initiated comms can avoid this extra sync.
// TODO: Consider using GPU-based bincount to avoid CPU synchronization overhead.
auto rank_for_token_cpu = rank_for_token.to(at::kCPU);
auto n_tokens_to_rank_cpu =
    at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong);

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@github-actions

github-actions bot commented Jan 21, 2026

Review updated until commit a0de605

Description

  • Add MoE dispatch and combine operations for multidevice execution with k=1 support

  • Implement NCCL backend communication primitives using alltoall for token routing

  • Extend Host IR evaluator with handlers for new dispatch/combine operations

  • Add comprehensive test coverage for dispatch/combine functionality

Changes walkthrough

Relevant files

Enhancement (7 files)
  • dispatch_combine.cpp: Core dispatch/combine implementation with alltoall communication (+284/-0)
  • dispatch_combine.h: Header declarations for dispatch/combine functions and result structs (+125/-0)
  • communication.h: MoeDispatch and MoeCombine IR node class declarations (+177/-0)
  • communication.cpp: MoeDispatch and MoeCombine IR node implementations with validation (+163/-0)
  • evaluator.cpp: Host IR evaluator handlers for MoeDispatch and MoeCombine operations (+61/-0)
  • evaluator.h: Handler declarations for MoeDispatch and MoeCombine in evaluator (+2/-0)
  • dispatch.h: Added MoeDispatch and MoeCombine to dispatch macro (+2/-0)

Tests (1 file)
  • test_multidevice_dispatch_combine.cpp: Comprehensive test for dispatch/combine functionality across ranks (+132/-0)

Configuration changes (1 file)
  • CMakeLists.txt: Build system updates for new dispatch/combine source and test (+2/-0)

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Memory allocation optimization

The implementation currently allocates new buffers for all tensors in both dispatch and combine operations. Consider adding support for preallocated buffers as mentioned in the TODO comment on line 155, which could improve performance and reduce memory fragmentation in production workloads.

// TODO: support preallocated buffers.
Limited k=1 constraint

The implementation explicitly only supports k=1 (top-1 routing) as evidenced by the validation checks on lines 92-105. While this is acceptable for an initial implementation, consider documenting this limitation clearly and planning for k>1 support in future iterations.

const bool topk_is_1d = topk_idx.dim() == 1 && topk_idx.size(0) == num_tokens;
const bool topk_is_2d = topk_idx.dim() == 2 &&
    topk_idx.size(0) == num_tokens && topk_idx.size(1) == 1;
NVF_CHECK(
    topk_is_1d || topk_is_2d,
    "Only topk=1 supported. topk_idx must be shape [T] or [T, 1], got: ",
    topk_idx.sizes());
auto topk_idx_flat = topk_idx.reshape({num_tokens});
const bool weights_is_1d =
    topk_weights.dim() == 1 && topk_weights.size(0) == num_tokens;
const bool weights_is_2d = topk_weights.dim() == 2 &&
    topk_weights.size(0) == num_tokens && topk_weights.size(1) == 1;
NVF_CHECK(
    weights_is_1d || weights_is_2d,
    "Only topk=1 supported. topk_weights must be shape [T] or [T, 1], got: ",
    topk_weights.sizes());
auto topk_weights_flat = topk_weights.reshape({num_tokens});
Backend limitation

The code only supports NCCL backend (lines 133-135, 252-254). While NCCL is the primary backend for multi-GPU training, consider adding support for other backends or at least more flexible backend selection in future iterations.

backend,
CommunicatorBackend::kNccl,
"Only NCCL backend is supported for MoeDispatch.");

Test failures

  • (Medium, 3) Shape mismatch in thunder.fx higher-order inplace alias update tests (test_update_aliases)

    Failed on: A100, GB200, H100
    thunder.tests.test_update_aliases.test_higher_order_inplace_alias_update_nvfuser_cuda_thunder.dtypes.float32
  • (Medium, 1) Thunder nvFuser scalar mismatch in nanoGPT autograd CUDA test_networks

    Failed on: GB200
    thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32

@samnordmann samnordmann requested a review from wujingyue January 21, 2026 15:21
Collaborator

@wujingyue wujingyue left a comment

Thanks!

First round -- I'll review combine later.

TensorView* out_n_tokens_from_rank,
TensorView* in_x,
TensorView* in_topk_idx,
TensorView* in_topk_weights,
Collaborator

IIUC, topk_weights doesn't need to go through dispatch or combine. Llama4 applies the topk weights before dispatch:

hidden_states = hidden_states * router_scores # [s, h]
and DeepSeek V3 does that after combine.
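
For k=1 the two conventions are just elementwise scaling on either side of the round trip. A small illustrative ATen snippet, borrowing variable names from this PR's snippets and assuming shapes [T, H] and [T]:

// Fold the gating weight in before dispatch (Llama4-style)...
auto x_weighted = x * topk_weights_flat.unsqueeze(1);   // [T, H]
// ...or after combine (DeepSeek-V3-style), on the restored token order.
auto y = combined_x * topk_weights_flat.unsqueeze(1);   // [T, H]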

Collaborator Author

Agreed; dropped topk_weights. Callers should apply weights before dispatch or after combine; added comments to document that.

Collaborator Author

add back topk_weights

#5857 (comment)

Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 1 comment

Comment on lines 76 to 84
NVF_CHECK(
    [&]() {
      auto token_counts = is_token_in_rank.to(at::kLong).sum(1);
      auto min_val = token_counts.min().item<int64_t>();
      auto max_val = token_counts.max().item<int64_t>();
      return min_val == 1 && max_val == 1;
    }(),
    "Only topk=1 is supported. Each token must be assigned to exactly one "
    "rank.");

style: Validation check causes unnecessary GPU-CPU synchronization on every dispatch call. The lambda executes .min() and .max() which trigger implicit CPU synchronization via .item<int64_t>(). For topk=1 validation, consider validating this constraint earlier during graph construction or accepting it as a precondition, rather than checking it at runtime on the hot path.

@samnordmann samnordmann force-pushed the dispatch_combine/stub branch from 0f48cd5 to afd948d on January 22, 2026, 12:14
@samnordmann
Collaborator Author

!test

@samnordmann samnordmann requested a review from wujingyue January 22, 2026 17:36
Comment on lines 195 to 197
TensorView* out_topk_idx,
TensorView* out_src_idx,
TensorView* out_src_rank,
Collaborator

I understand n_tokens_to_rank and n_tokens_from_rank, but I'm not sure why these three tensors need to be outputs. At least, I didn't have to use them in the reference implementation below:

# --------------------------------------------------------------------------
# Step 4: Each rank sorts the processed tokens by rank ID.
# --------------------------------------------------------------------------
# GPU 0: tokens_for_expert_0_from_rank_0 || tokens_for_expert_1_from_rank_0 || tokens_for_expert_2_from_rank_0 || tokens_for_expert_0_from_rank_1 || tokens_for_expert_1_from_rank_1 || tokens_for_expert_2_from_rank_1
# GPU 1: tokens_for_expert_3_from_rank_0 || tokens_for_expert_4_from_rank_0 || tokens_for_expert_5_from_rank_0 || tokens_for_expert_3_from_rank_1 || tokens_for_expert_4_from_rank_1 || tokens_for_expert_5_from_rank_1
processed_tokens_by_rank = expert_first_to_rank_first(
    processed_tokens_by_expert, n_tokens_for_expert_from_rank
)
# --------------------------------------------------------------------------
# Step 5: Processed tokens are sent back to the original ranks.
# --------------------------------------------------------------------------
processed_tokens = torch.empty(n_tokens, dtype=torch.complex64, device="cuda")
# GPU 0: tokens_for_expert_0_from_rank_0 || tokens_for_expert_1_from_rank_0 || tokens_for_expert_2_from_rank_0 || tokens_for_expert_3_from_rank_0 || tokens_for_expert_4_from_rank_0 || tokens_for_expert_5_from_rank_0
# GPU 1: tokens_for_expert_0_from_rank_1 || tokens_for_expert_1_from_rank_1 || tokens_for_expert_2_from_rank_1 || tokens_for_expert_3_from_rank_1 || tokens_for_expert_4_from_rank_1 || tokens_for_expert_5_from_rank_1
dist.all_to_all_single(
    processed_tokens,
    processed_tokens_by_rank,
    n_tokens_to_rank.tolist(),
    n_tokens_from_rank.tolist(),
)

Collaborator Author

They’re required for the round‑trip semantics here. out_src_idx/out_src_rank are the metadata needed by combine to route tokens back to their original rank and restore local order (we use them to build the alltoall splits and the final index_copy_).

out_topk_idx is still needed after dispatch so the rank can route tokens to its local experts.

Regarding out_src_idx/out_src_rank, I can go with your suggestion and drop them, but I want to be explicit about the implication: we'd be committing to the same constrained ordering as the reference test you linked, that is, routing must be fully determined by split sizes and a fixed rank<->expert mapping, with per-rank chunk order preserved by all-to-all. In that model you can reconstruct order without per-token metadata. The trade-off is that we lose support for arbitrary routing, non-trivial meshes, or any custom per-token permutation, since we'd have no way to restore the original order. One potential issue I foresee is implementing a padded dispatch/combine on over- and pre-allocated buffers -- but that could be solved later!

Assuming we are ok with those implications, I’ll proceed with the simplification.
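
To spell out the round trip being discussed, here is a minimal k=1 combine sketch in the same style as the dispatch sketch earlier on this page (illustrative names, not the PR's API). It assumes the expert outputs are still in dispatch receive order, so the source rank is implicit in the alltoall split layout and only src_idx is needed to restore local order; keeping src_rank explicit, as the PR does, is what allows relaxing that ordering assumption.

#include <numeric>
#include <vector>

#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

// Sketch of the k=1 combine path: send expert outputs back to their source
// ranks, then scatter into the original local token order using the
// per-token source indices returned by dispatch.
at::Tensor doCombineSketch(
    c10d::ProcessGroup& pg,
    const at::Tensor& expert_out,     // [R, H] outputs for received tokens
    const at::Tensor& recv_src_idx,   // [R] sender-local index per token
    std::vector<int64_t> n_from_rank, // dispatch's receive splits
    std::vector<int64_t> n_to_rank) { // dispatch's send splits
  // Dispatch's receive splits are combine's send splits, and vice versa.
  auto& input_splits = n_from_rank;
  auto& output_splits = n_to_rank;

  int64_t n_local_tokens =
      std::accumulate(output_splits.begin(), output_splits.end(), int64_t{0});

  // Return the processed tokens and their source indices to the origin ranks.
  auto send_out = expert_out.contiguous();
  auto recv_out = at::empty(
      {n_local_tokens, expert_out.size(1)}, expert_out.options());
  pg.alltoall_base(recv_out, send_out, output_splits, input_splits)->wait();

  auto send_idx = recv_src_idx.contiguous();
  auto recv_idx = at::empty({n_local_tokens}, recv_src_idx.options());
  pg.alltoall_base(recv_idx, send_idx, output_splits, input_splits)->wait();

  // Scatter back to the original local order; assumes recv_idx is a
  // permutation of [0, n_local_tokens), i.e. exactly one entry per token.
  auto combined = at::empty_like(recv_out);
  combined.index_copy_(0, recv_idx, recv_out);
  return combined;
}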

Collaborator Author

@samnordmann samnordmann Jan 26, 2026

Coming back to my last comment: note that the DeepSeek API seems to use those tensors, and their Combine specifically reorders back to the original token positions (using src_idx/src_rank).
https://github.com/deepseek-ai/DeepEP/blob/29d31c095796f3c8ece47ee9cdcc167051bbeed9/csrc/kernels/intranode.cu#L1034

Is this an argument to keep those variables?

As a side note, related to #5857 (comment), note that the DeepSeek interface takes topk_weights.

Collaborator Author

Any thoughts on that? IMHO, reducing the number of arguments doesn't necessarily simplify the IR, because it introduces implicit assumptions.

Collaborator

> note that the DeepSeek interface takes topk_weights

Good point. I suspect this is because their combine kernel also covers the topk_weights multiplication. It's suboptimal to materialize s*k/ep tokens in global memory. This PR probably doesn't need that at this very moment because it assumes k=1. I'll double check the code to make sure.

What does your kernel look like? I believe this is just a reference implementation and you are about to add a fused kernel.

Collaborator

> I'll double check the code to make sure.

I got a chance to read the code. Thanks for the pointer! There are three modes:

  • Low-latency (see internode_ll.cu). Inter-node. One-hop.
  • Intra-node (see intranode.cu). One-hop.
  • Inter-node (see internode.cu). Two-hop, first RDMA and then nvlink.

Regardless of which mode, combine reads topk_weights. When k>1, it's better to fuse topk_weights multiplication into the combine kernel to avoid materializing s*k/ep tokens in global memory.

According to deepseek-ai/DeepEP#72, dispatch needs to produce topk_weights for local reduction. Therefore, dispatch does so for inter-node and doesn't for low-latency. It's yet unclear why dispatch produces topk_weights for intra-node in some cases.

> What kernel? I am not sure I understand. I am not writing any CUDA kernel in this PR.

Never mind. I was referring to the kernel in the other PR you showed on Thursday. I wasn't sure whether you were building kernels from scratch or reusing DeepEP's. Since it's the former, we have more motivation to keep the implementation simple until complexity is required.

Collaborator

out_src_idx makes sense to me now. I previously thought a local sort was performed before dispatch, but I now understand that dispatch itself handles that.

Collaborator

I'm still not sure about out_src_rank. Does it equal out_src_idx / ceil(seq_len / EP)?

Collaborator Author

@samnordmann samnordmann Feb 2, 2026

> I'll double check the code to make sure.
>
> I got a chance to read the code. Thanks for the pointer! There are three modes:
>
>   • Low-latency (see internode_ll.cu). Inter-node. One-hop.
>   • Intra-node (see intranode.cu). One-hop.
>   • Inter-node (see internode.cu). Two-hop, first RDMA and then nvlink.
>
> Regardless of which mode, combine reads topk_weights. When k>1, it's better to fuse topk_weights multiplication into the combine kernel to avoid materializing s*k/ep tokens in global memory.

Yes, this is an important optimization. Therefore I'm keeping topk_weights in the interface, so that in the next PR this can be fused with the combine's communication. OK?

> According to deepseek-ai/DeepEP#72, dispatch needs to produce topk_weights for local reduction. Therefore, dispatch does so for inter-node and doesn't for low-latency. It's yet unclear why dispatch produces topk_weights for intra-node in some cases.

IIUC, this comment discusses whether topk_weights should be present in the dispatch interface. The answer is that it is not necessary, because topk_weights is only needed by combine after receiving the data, when the tokens have been reordered to their original order (so each rank keeps its topk_weights unsorted and uses them only after the tokens are back at their original rank and order).

However, I disagree with this answer. The optimal implementation is to fuse combine and the reduction, as you suggested just above, not "combine" followed by "local reduction". We do not need to materialize all the expert activations at the receiver side, only the weighted average. This is equivalent to comparing allreduce against allgather + local-reduce. So we need the weights at the sender side, before reordering the tokens.

My conclusion is that we should keep topk_weights both in combine and dispatch. Do you agree?

> out_src_idx makes sense to me now. I previously thought a local sort was performed before dispatch, but I now understand that dispatch itself handles that.

Yes, it would be too rigid and suboptimal to require the tokens to be sorted before dispatch -- dispatch would merely be an MPI alltoallv (for k=1). We do not need to reorder the tokens, especially when using raw NVLink transport as in the next PR. In the communication kernel, we can account for discontinuity in the source buffer. That's actually crucial for latency and BW, to avoid repacking the data at the sender side.

> I'm still not sure about out_src_rank. Does it equal out_src_idx / ceil(seq_len / EP)?

I don't think so. Tokens are not guaranteed to be evenly dispatched to experts.

Note that out_src_idx is the index local to the rank. The "global" index of the token is deduced from out_src_idx and out_src_rank.

Collaborator

> My conclusion is that we should keep topk_weights both in combine and dispatch. Do you agree?

Yes, I agree now. IIUC, even for one-hop combines, it's still better to apply topk_weights on the sender side because the reduction can be done in fabric.

https://github.com/deepseek-ai/DeepEP/blob/29d31c095796f3c8ece47ee9cdcc167051bbeed9/csrc/kernels/internode_ll.cu#L1078 might just be suboptimal because it requires topk_weight to be in sequence order. Could it be an RDMA-related constraint? Note this kernel is for inter-nvlink-domain.

> Note that out_src_idx is the index local to the rank.

Got it -- I missed that.

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 2 comments

@samnordmann samnordmann requested a review from nsarka January 27, 2026 15:12
@samnordmann samnordmann force-pushed the dispatch_combine/stub branch from 6da9793 to 4693c53 on January 29, 2026, 14:03
@samnordmann samnordmann requested a review from wujingyue January 29, 2026 14:04
Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 3 comments

// Asymmetric example:
// token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens.
auto rank_ids = at::arange(world_size, int_options);
auto token_rank = at::tensor({0, 1, 1, 1}, int_options);

Hardcoded to 4 tokens with routing [0, 1, 1, 1]; this won't work for world_size > 2.

auto topk_idx_flat = topk_idx.reshape({num_tokens});

// Determine destination rank per token (topk=1).
auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong);

Assumes is_token_in_rank is a valid one-hot mask (exactly one 1 per row). If multiple ranks are marked, or none are, argmax still returns an index, but routing will be incorrect. Add validation that each row has exactly one True value.

topk_is_1d || topk_is_2d,
"Only topk=1 supported. topk_idx must be shape [T] or [T, 1], got: ",
topk_idx.sizes());
auto topk_idx_flat = topk_idx.reshape({num_tokens});

No validation that topk_idx values are within the [0, num_experts) range. Invalid expert IDs could cause incorrect rank assignment or out-of-bounds issues in local expert computation.
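
A possible guard, sketched in the NVF_CHECK style the surrounding code already uses (num_experts is assumed to be available in this scope; the extra .item<>() calls add a host sync, like the existing k=1 check):

// Hypothetical range check on the routing indices (sketch only).
NVF_CHECK(
    topk_idx_flat.numel() == 0 ||
        (topk_idx_flat.min().item<int64_t>() >= 0 &&
         topk_idx_flat.max().item<int64_t>() < num_experts),
    "topk_idx values must be in [0, num_experts).");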

@samnordmann
Collaborator Author

!test

@samnordmann
Collaborator Author

!test

Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 4 comments

recv_src_rank, send_src_rank, output_splits, input_splits));

// Locally reorder by expert id so each rank processes contiguous experts.
auto local_expert = recv_topk_idx - my_rank * experts_per_rank;

Potential out-of-bounds risk: local_expert = recv_topk_idx - my_rank * experts_per_rank could produce negative values if recv_topk_idx contains expert IDs that don't belong to this rank. While the subsequent argsort will still work, it may produce unexpected results during expert processing if the assumption is that only local expert IDs should be received.
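
If receiving only locally owned expert IDs is intended to be an invariant, a small guard in the same spirit would make violations explicit (a sketch reusing names from the quoted snippet, not the PR's code):

// Hypothetical check that every received expert id maps to a local expert.
NVF_CHECK(
    local_expert.numel() == 0 ||
        (local_expert.min().item<int64_t>() >= 0 &&
         local_expert.max().item<int64_t>() < experts_per_rank),
    "Received a token routed to an expert not owned by this rank.");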

// Asymmetric example:
// token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens.
auto rank_ids = at::arange(world_size, int_options);
auto token_rank = at::tensor({0, 1, 1, 1}, int_options);

Hardcoded routing pattern [0, 1, 1, 1] only works for world_size=2. For world_size > 2, this creates is_token_in_rank with more columns than the routing pattern accounts for, causing rank 2+ to receive no tokens. Consider making the test routing scale with world_size or adding a GTEST_SKIP() for world_size > 2.
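
A minimal version of the skip option, assuming the test fixture exposes the communicator as communicator_ like other multidevice tests (a sketch, not the PR's test code):

// Hypothetical guard at the top of the test body.
if (communicator_->size() > 2) {
  GTEST_SKIP() << "Routing pattern [0, 1, 1, 1] assumes at most 2 ranks.";
}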

Comment on lines +25 to +30
int64_t sumSplitSizes(const std::vector<int64_t>& splits) {
  int64_t total = 0;
  for (auto value : splits) {
    total += value;
  }
  return total;

Missing overflow check for large tensors. Summing split sizes without overflow checking could cause issues with very large token counts. Consider using checked arithmetic or validating that the sum doesn't exceed reasonable bounds.

auto topk_idx_flat = topk_idx.reshape({num_tokens});

// Determine destination rank per token (topk=1).
auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong);

The .to(at::kLong) conversions are redundant. is_token_in_rank is already cast to at::kLong on the same line, and argmax(1) returns int64 by default.

@samnordmann
Collaborator Author

!test

Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 3 comments

Comment on lines +19 to +23
std::vector<int64_t> toSplitSizes(const at::Tensor& sizes_tensor) {
  auto cpu_sizes = sizes_tensor.to(at::kCPU);
  auto* ptr = cpu_sizes.data_ptr<int64_t>();
  return std::vector<int64_t>(ptr, ptr + cpu_sizes.numel());
}

Unsafe dtype assumption

toSplitSizes unconditionally does cpu_sizes.data_ptr<int64_t>() after .to(at::kCPU), but it never checks that sizes_tensor is actually kLong/int64. If a caller passes an int32 count tensor (e.g., from a different backend or future refactor), this will throw (data_ptr<int64_t>() called on ...) at runtime. Consider asserting cpu_sizes.scalar_type() == at::kLong (or converting) before taking the pointer.
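
One way to harden it, as a sketch rather than the PR's actual fix: normalize the dtype (or assert it) before taking the raw pointer.

// Sketch: convert to int64 on the CPU so data_ptr<int64_t>() is always valid.
std::vector<int64_t> toSplitSizes(const at::Tensor& sizes_tensor) {
  auto cpu_sizes = sizes_tensor.to(at::kCPU).to(at::kLong).contiguous();
  const auto* ptr = cpu_sizes.data_ptr<int64_t>();
  return std::vector<int64_t>(ptr, ptr + cpu_sizes.numel());
}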

Comment on lines +45 to +51
auto* recv_x = makeSymbolicTensor(2);
auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int);
auto* recv_topk_weights = makeSymbolicTensor(1);
auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int);
auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int);
auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int);
auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int);

IR dtype mismatch

This test declares recv_topk_idx, recv_src_idx, recv_src_rank, n_tokens_to_rank, and n_tokens_from_rank as DataType::Int, but the implementation returns/uses at::kLong for these tensors (e.g., send_src_idx = sorted_indices.to(at::kLong) and bincount(...).to(at::kLong)). On platforms where DataType::Int maps to 32-bit, Host IR validation/binding will fail. The symbolic tensor dtypes should match the actual tensors produced (likely DataType::Int64).

Comment on lines +275 to +279
auto combined_x = at::empty({total_recv, hidden}, x.options());
combined_x.index_copy_(0, recv_src_idx, recv_x);
auto combined_topk_weights =
    at::empty({total_recv}, topk_weights_flat.options());
combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights);

Index collision drops data

combined_x.index_copy_(0, recv_src_idx, recv_x) (and same for combined_topk_weights) will silently overwrite when recv_src_idx contains duplicates (possible if upstream routing produces repeated source indices). That yields incorrect combined outputs with no error. If uniqueness is an invariant, it needs to be enforced/validated before index_copy_ so failures are explicit.
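
If uniqueness is meant to be an invariant, one way to surface violations before the index_copy_ is a bincount-based check in the style already used in this file (a sketch; it does add a host sync):

// Sketch: every destination slot must be written exactly once.
auto write_counts = at::bincount(recv_src_idx, {}, total_recv);
NVF_CHECK(
    write_counts.eq(1).all().item<bool>(),
    "recv_src_idx must be a permutation of [0, total_recv).");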


struct CombineResult {
  at::Tensor combined_x; // Combined tokens back in original order.
  at::Tensor combined_topk_weights; // Combined gating weights per token.
Collaborator Author

should be removed
