diff --git a/CMakeLists.txt b/CMakeLists.txt
index d21425d2e9e..d007b40c90d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -235,6 +235,7 @@ list(APPEND NVFUSER_SRCS
     ${NVFUSER_SRCS_DIR}/multidevice/communication.cpp
     ${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp
     ${NVFUSER_SRCS_DIR}/multidevice/cuda_p2p.cpp
+    ${NVFUSER_SRCS_DIR}/multidevice/dispatch_combine.cpp
     ${NVFUSER_SRCS_DIR}/multidevice/ipc_handle.cpp
     ${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp
     ${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp
@@ -1143,6 +1144,7 @@ if(BUILD_TEST)
     ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp
     ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
+    ${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
diff --git a/csrc/dispatch.h b/csrc/dispatch.h
index 7c53c86d903..bcf35ff5e55 100644
--- a/csrc/dispatch.h
+++ b/csrc/dispatch.h
@@ -118,6 +118,8 @@ class Val;
   f(Merge);                \
   f(Partition);            \
   f(Combine);              \
+  f(MoeDispatch);          \
+  f(MoeCombine);           \
   f(Swizzle);              \
   f(Resize);               \
   f(MatmulOp);             \
diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp
index 2396767d5b0..7d28c6e0755 100644
--- a/csrc/host_ir/evaluator.cpp
+++ b/csrc/host_ir/evaluator.cpp
@@ -25,6 +25,7 @@
 #include "multidevice/allocation_utils.h"
 #include "multidevice/communication.h"
 #include "multidevice/cuda_p2p.h"
+#include "multidevice/dispatch_combine.h"
 #include "multidevice/execution_utils.h"
 #include "multidevice/symmetric_tensor.h"
 #include "multidevice/utils.h"
@@ -386,6 +387,66 @@ void HostIrEvaluator::handle(P2PCommunication* communication) {
   }
 }
 
+void HostIrEvaluator::handle(MoeDispatch* dispatch) {
+  NVF_ERROR(
+      communicator_ != nullptr && communicator_->is_available(),
+      "A valid communicator must be provided");
+
+  auto x = getKnownConcreteValue(dispatch->inX()).as<at::Tensor>();
+  auto topk_idx = getKnownConcreteValue(dispatch->inTopkIdx()).as<at::Tensor>();
+  auto topk_weights =
+      getKnownConcreteValue(dispatch->inTopkWeights()).as<at::Tensor>();
+  auto is_token_in_rank =
+      getKnownConcreteValue(dispatch->inIsTokenInRank()).as<at::Tensor>();
+
+  auto result = doMoeDispatch(
+      x,
+      topk_idx,
+      topk_weights,
+      is_token_in_rank,
+      dispatch->numExperts(),
+      communicator_,
+      dispatch->backend());
+
+  expr_evaluator_.bind(dispatch->outX(), result.recv_x);
+  expr_evaluator_.bind(dispatch->outTopkIdx(), result.recv_topk_idx);
+  expr_evaluator_.bind(dispatch->outTopkWeights(), result.recv_topk_weights);
+  expr_evaluator_.bind(dispatch->outSrcIdx(), result.recv_src_idx);
+  expr_evaluator_.bind(dispatch->outSrcRank(), result.recv_src_rank);
+  expr_evaluator_.bind(dispatch->outTokensToRank(), result.n_tokens_to_rank);
+  expr_evaluator_.bind(
+      dispatch->outTokensFromRank(), result.n_tokens_from_rank);
+}
+
+void HostIrEvaluator::handle(MoeCombine* combine) {
+  NVF_ERROR(
+      communicator_ != nullptr && communicator_->is_available(),
+      "A valid communicator must be provided");
+
+  auto x = getKnownConcreteValue(combine->inX()).as<at::Tensor>();
+  auto topk_weights =
+      getKnownConcreteValue(combine->inTopkWeights()).as<at::Tensor>();
+  auto src_idx = getKnownConcreteValue(combine->inSrcIdx()).as<at::Tensor>();
+  auto src_rank = getKnownConcreteValue(combine->inSrcRank()).as<at::Tensor>();
+  auto n_tokens_to_rank =
+      getKnownConcreteValue(combine->inTokensToRank()).as<at::Tensor>();
+  auto n_tokens_from_rank =
+      getKnownConcreteValue(combine->inTokensFromRank()).as<at::Tensor>();
+
+  auto result = doMoeCombine(
+      x,
+      topk_weights,
+      src_idx,
+      src_rank,
+      n_tokens_to_rank,
+      n_tokens_from_rank,
+      communicator_,
+      combine->backend());
+
+  expr_evaluator_.bind(combine->outX(), result.combined_x);
+  expr_evaluator_.bind(combine->outTopkWeights(), result.combined_topk_weights);
+}
+
 void HostIrEvaluator::handle(Wait* wait) {
   Expr* expr = wait->communication();
   auto* p2p_comm = dynamic_cast<P2PCommunication*>(expr);
diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h
index 22833156cab..4a1929ba1bd 100644
--- a/csrc/host_ir/evaluator.h
+++ b/csrc/host_ir/evaluator.h
@@ -98,6 +98,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch {
   void handle(LaunchKernel*) override;
   void handle(Communication*) override;
   void handle(P2PCommunication*) override;
+  void handle(MoeDispatch*) override;
+  void handle(MoeCombine*) override;
   void handle(Wait*) override;
   void handle(kir::ForLoop*) override;
   void handle(hir::ForLoop*) override;
diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index 6778da9da71..b790748f957 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -321,6 +321,169 @@ std::string P2PCommunication::toString(int indent_size) const {
   return toInlineString(indent_size) + "\n";
 }
 
+MoeDispatch::MoeDispatch(
+    IrBuilderPasskey passkey,
+    TensorView* out_x,
+    TensorView* out_topk_idx,
+    TensorView* out_topk_weights,
+    TensorView* out_src_idx,
+    TensorView* out_src_rank,
+    TensorView* out_n_tokens_to_rank,
+    TensorView* out_n_tokens_from_rank,
+    TensorView* in_x,
+    TensorView* in_topk_idx,
+    TensorView* in_topk_weights,
+    TensorView* in_is_token_in_rank,
+    int64_t num_experts,
+    CommunicatorBackend backend)
+    : Expr(passkey) {
+  addInput(in_x);
+  addInput(in_topk_idx);
+  addInput(in_topk_weights);
+  addInput(in_is_token_in_rank);
+  addOutput(out_x);
+  addOutput(out_topk_idx);
+  addOutput(out_topk_weights);
+  addOutput(out_src_idx);
+  addOutput(out_src_rank);
+  addOutput(out_n_tokens_to_rank);
+  addOutput(out_n_tokens_from_rank);
+  addDataAttribute(num_experts);
+  addDataAttribute(backend);
+  validate();
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(MoeDispatch)
+
+std::string MoeDispatch::toInlineString(int indent_size) const {
+  std::stringstream ss;
+  indent(ss, indent_size) << "Dispatch " << name() << " ("
+                          << "num_experts=" << numExperts() << ", "
+                          << "backend=" << backend() << ", "
+                          << "in=" << inX() << ", "
+                          << "topk_idx=" << inTopkIdx() << ", "
+                          << "topk_weights=" << inTopkWeights() << ", "
+                          << "is_token_in_rank=" << inIsTokenInRank() << ", "
+                          << "out=" << outX() << ")";
+  return ss.str();
+}
+
+std::string MoeDispatch::toString(int indent_size) const {
+  return toInlineString(indent_size) + "\n";
+}
+
+void MoeDispatch::validate() {
+  NVF_CHECK(numExperts() > 0, "num_experts must be positive.");
+  NVF_CHECK(inX()->isA<TensorView>(), "in_x must be a TensorView.");
+  NVF_CHECK(inTopkIdx()->isA<TensorView>(), "topk_idx must be a TensorView.");
+  NVF_CHECK(
+      inTopkIdx()->getDataType().has_value() &&
+          isIntegralType(*inTopkIdx()->getDataType()),
+      "topk_idx must be integral.");
+  NVF_CHECK(
+      inTopkWeights()->getDataType().has_value() &&
+          isFloatingPointType(*inTopkWeights()->getDataType()),
+      "topk_weights must be floating point.");
+  NVF_CHECK(
+      inIsTokenInRank()->getDataType().has_value() &&
+          inIsTokenInRank()->getDataType() == DataType::Bool,
+      "is_token_in_rank must be Bool.");
+  NVF_CHECK(
+      outTopkIdx()->getDataType().has_value() &&
+          isIntegralType(*outTopkIdx()->getDataType()),
+      "out_topk_idx must be integral.");
+  NVF_CHECK(
+      outTopkWeights()->getDataType().has_value() &&
+          isFloatingPointType(*outTopkWeights()->getDataType()),
+      "out_topk_weights must be floating point.");
+  NVF_CHECK(
+      outSrcIdx()->getDataType().has_value() &&
+          isIntegralType(*outSrcIdx()->getDataType()),
+      "out_src_idx must be integral.");
+  NVF_CHECK(
+      outSrcRank()->getDataType().has_value() &&
+          isIntegralType(*outSrcRank()->getDataType()),
+      "out_src_rank must be integral.");
+  NVF_CHECK(
+      outTokensToRank()->getDataType().has_value() &&
+          isIntegralType(*outTokensToRank()->getDataType()),
+      "out_n_tokens_to_rank must be integral.");
+  NVF_CHECK(
+      outTokensFromRank()->getDataType().has_value() &&
+          isIntegralType(*outTokensFromRank()->getDataType()),
+      "out_n_tokens_from_rank must be integral.");
+}
+
+MoeCombine::MoeCombine(
+    IrBuilderPasskey passkey,
+    TensorView* out_x,
+    TensorView* out_topk_weights,
+    TensorView* in_x,
+    TensorView* in_topk_weights,
+    TensorView* in_src_idx,
+    TensorView* in_src_rank,
+    TensorView* in_n_tokens_to_rank,
+    TensorView* in_n_tokens_from_rank,
+    CommunicatorBackend backend)
+    : Expr(passkey) {
+  addInput(in_x);
+  addInput(in_topk_weights);
+  addInput(in_src_idx);
+  addInput(in_src_rank);
+  addInput(in_n_tokens_to_rank);
+  addInput(in_n_tokens_from_rank);
+  addOutput(out_x);
+  addOutput(out_topk_weights);
+  addDataAttribute(backend);
+  validate();
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(MoeCombine)
+
+std::string MoeCombine::toInlineString(int indent_size) const {
+  std::stringstream ss;
+  indent(ss, indent_size) << "Combine " << name() << " ("
+                          << "backend=" << backend() << ", "
+                          << "in=" << inX() << ", "
+                          << "topk_weights=" << inTopkWeights() << ", "
+                          << "src_idx=" << inSrcIdx() << ", "
+                          << "src_rank=" << inSrcRank() << ", "
+                          << "out=" << outX() << ")";
+  return ss.str();
+}
+
+std::string MoeCombine::toString(int indent_size) const {
+  return toInlineString(indent_size) + "\n";
+}
+
+void MoeCombine::validate() {
+  NVF_CHECK(inX()->isA<TensorView>(), "in_x must be a TensorView.");
+  NVF_CHECK(
+      inTopkWeights()->getDataType().has_value() &&
+          isFloatingPointType(*inTopkWeights()->getDataType()),
+      "in_topk_weights must be floating point.");
+  NVF_CHECK(
+      inSrcIdx()->getDataType().has_value() &&
+          isIntegralType(*inSrcIdx()->getDataType()),
+      "in_src_idx must be integral.");
+  NVF_CHECK(
+      inSrcRank()->getDataType().has_value() &&
+          isIntegralType(*inSrcRank()->getDataType()),
+      "in_src_rank must be integral.");
+  NVF_CHECK(
+      inTokensToRank()->getDataType().has_value() &&
+          isIntegralType(*inTokensToRank()->getDataType()),
+      "in_n_tokens_to_rank must be integral.");
+  NVF_CHECK(
+      inTokensFromRank()->getDataType().has_value() &&
+          isIntegralType(*inTokensFromRank()->getDataType()),
+      "in_n_tokens_from_rank must be integral.");
+  NVF_CHECK(
+      outTopkWeights()->getDataType().has_value() &&
+          isFloatingPointType(*outTopkWeights()->getDataType()),
+      "out_topk_weights must be floating point.");
+}
+
 namespace {
 c10::intrusive_ptr<c10d::Work> postBroadcast(
     Communication* communication,
diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h
index 1a7f1a1cc4c..f4a1abaf667 100644
--- a/csrc/multidevice/communication.h
+++ b/csrc/multidevice/communication.h
@@ -174,6 +174,183 @@ class P2PCommunication : public Expr {
   }
 };
 
+// Dispatch represents intra-node MoE token dispatch. It shuffles tokens from
+// the local rank to destination ranks based on explicit routing.
+//
+// Example shapes (topk=1):
+//   in_x: [T, H], in_topk_idx: [T] or [T, 1],
+//   in_topk_weights: [T] or [T, 1],
+//   in_is_token_in_rank: [T, R] (one-hot), num_experts = R * experts_per_rank.
+// out_src_idx/out_src_rank are returned for the combine step to restore the
+// original token order.
+// Outputs are recv-aligned tensors: out_x/out_topk_idx/out_topk_weights/
+// out_src_* with shape [T_recv, ...] and
+// out_n_tokens_to_rank/out_n_tokens_from_rank with shape [R].
+class MoeDispatch : public Expr {
+ public:
+  using Expr::Expr;
+
+  MoeDispatch(
+      IrBuilderPasskey passkey,
+      TensorView* out_x,
+      TensorView* out_topk_idx,
+      TensorView* out_topk_weights,
+      TensorView* out_src_idx,
+      TensorView* out_src_rank,
+      TensorView* out_n_tokens_to_rank,
+      TensorView* out_n_tokens_from_rank,
+      TensorView* in_x,
+      TensorView* in_topk_idx,
+      TensorView* in_topk_weights,
+      TensorView* in_is_token_in_rank,
+      int64_t num_experts,
+      CommunicatorBackend backend = CommunicatorBackend::kNccl);
+
+  TensorView* inIsTokenInRank() const {
+    return input(3)->as<TensorView>();
+  }
+
+  MoeDispatch(const MoeDispatch& other) = delete;
+  MoeDispatch& operator=(const MoeDispatch& other) = delete;
+  MoeDispatch(MoeDispatch&& other) = delete;
+  MoeDispatch& operator=(MoeDispatch&& other) = delete;
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+
+  const char* getOpString() const override {
+    return "MoeDispatch";
+  }
+
+  TensorView* outX() const {
+    return output(0)->as<TensorView>();
+  }
+
+  TensorView* outTopkIdx() const {
+    return output(1)->as<TensorView>();
+  }
+
+  TensorView* outTopkWeights() const {
+    return output(2)->as<TensorView>();
+  }
+
+  TensorView* outSrcIdx() const {
+    return output(3)->as<TensorView>();
+  }
+
+  TensorView* outSrcRank() const {
+    return output(4)->as<TensorView>();
+  }
+
+  TensorView* outTokensToRank() const {
+    return output(5)->as<TensorView>();
+  }
+
+  TensorView* outTokensFromRank() const {
+    return output(6)->as<TensorView>();
+  }
+
+  TensorView* inX() const {
+    return input(0)->as<TensorView>();
+  }
+
+  TensorView* inTopkIdx() const {
+    return input(1)->as<TensorView>();
+  }
+
+  TensorView* inTopkWeights() const {
+    return input(2)->as<TensorView>();
+  }
+
+  int64_t numExperts() const {
+    return attribute<int64_t>(0);
+  }
+
+  CommunicatorBackend backend() const {
+    return attribute<CommunicatorBackend>(1);
+  }
+
+ private:
+  void validate();
+};
+
+// Combine represents intra-node MoE token combine. It shuffles tokens back to
+// their source ranks using `in_src_rank` and `in_src_idx`.
+//
+// Example shapes (topk=1):
+//   in_x: [T_recv, H], in_topk_weights: [T_recv], in_src_idx: [T_recv],
+//   in_src_rank: [T_recv], in_n_tokens_to_rank: [R], in_n_tokens_from_rank: [R].
+// Outputs are source-aligned: out_x/out_topk_weights with shape [T_src, ...].
+class MoeCombine : public Expr {
+ public:
+  using Expr::Expr;
+
+  MoeCombine(
+      IrBuilderPasskey passkey,
+      TensorView* out_x,
+      TensorView* out_topk_weights,
+      TensorView* in_x,
+      TensorView* in_topk_weights,
+      TensorView* in_src_idx,
+      TensorView* in_src_rank,
+      TensorView* in_n_tokens_to_rank,
+      TensorView* in_n_tokens_from_rank,
+      CommunicatorBackend backend = CommunicatorBackend::kNccl);
+
+  MoeCombine(const MoeCombine& other) = delete;
+  MoeCombine& operator=(const MoeCombine& other) = delete;
+  MoeCombine(MoeCombine&& other) = delete;
+  MoeCombine& operator=(MoeCombine&& other) = delete;
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+
+  const char* getOpString() const override {
+    return "MoeCombine";
+  }
+
+  TensorView* outX() const {
+    return output(0)->as<TensorView>();
+  }
+
+  TensorView* outTopkWeights() const {
+    return output(1)->as<TensorView>();
+  }
+
+  TensorView* inX() const {
+    return input(0)->as<TensorView>();
+  }
+
+  TensorView* inTopkWeights() const {
+    return input(1)->as<TensorView>();
+  }
+
+  TensorView* inSrcIdx() const {
+    return input(2)->as<TensorView>();
+  }
+
+  TensorView* inSrcRank() const {
+    return input(3)->as<TensorView>();
+  }
+
+  TensorView* inTokensToRank() const {
+    return input(4)->as<TensorView>();
+  }
+
+  TensorView* inTokensFromRank() const {
+    return input(5)->as<TensorView>();
+  }
+
+  CommunicatorBackend backend() const {
+    return attribute<CommunicatorBackend>(0);
+  }
+
+ private:
+  void validate();
+};
+
 // The method "post" triggers the execution of the communication. This call is
 // non-blocking. The communication can be posted multiple times.
 // It is assumed that the current device_index (given by
diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp
new file mode 100644
index 00000000000..043b37fd421
--- /dev/null
+++ b/csrc/multidevice/dispatch_combine.cpp
@@ -0,0 +1,284 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+
+#include "multidevice/dispatch_combine.h"
+
+#include <ATen/ATen.h>
+
+#include "exceptions.h"
+#include "multidevice/communicator.h"
+
+namespace nvfuser {
+namespace {
+
+std::vector<int64_t> toSplitSizes(const at::Tensor& sizes_tensor) {
+  auto cpu_sizes = sizes_tensor.to(at::kCPU);
+  auto* ptr = cpu_sizes.data_ptr<int64_t>();
+  return std::vector<int64_t>(ptr, ptr + cpu_sizes.numel());
+}
+
+int64_t sumSplitSizes(const std::vector<int64_t>& splits) {
+  int64_t total = 0;
+  for (auto value : splits) {
+    total += value;
+  }
+  return total;
+}
+
+void waitWork(const c10::intrusive_ptr<c10d::Work>& work) {
+  if (work) {
+    work->wait();
+  }
+}
+
+} // namespace
+
+DispatchResult doMoeDispatch(
+    const at::Tensor& x,
+    const at::Tensor& topk_idx,
+    const at::Tensor& topk_weights,
+    const at::Tensor& is_token_in_rank,
+    int64_t num_experts,
+    Communicator* communicator,
+    CommunicatorBackend backend) {
+  NVF_CHECK(communicator != nullptr, "Dispatch requires a valid communicator.");
+  NVF_CHECK(x.is_cuda(), "Dispatch input x must be on CUDA.");
+  NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA.");
+  NVF_CHECK(topk_weights.is_cuda(), "Dispatch topk_weights must be on CUDA.");
+  NVF_CHECK(
+      topk_weights.is_floating_point(),
+      "Dispatch topk_weights must be floating point.");
+  NVF_CHECK(
+      is_token_in_rank.is_cuda(), "Dispatch is_token_in_rank must be on CUDA.");
+  NVF_CHECK(
+      x.device() == topk_idx.device(),
+      "Dispatch expects x and topk_idx on the same device.");
+  NVF_CHECK(
+      x.device() == topk_weights.device(),
+      "Dispatch expects x and topk_weights on the same device.");
+  NVF_CHECK(
+      x.device() == is_token_in_rank.device(),
+      "Dispatch expects x and is_token_in_rank on the same device.");
+  NVF_CHECK_EQ(
+      is_token_in_rank.dim(),
+      2,
+      "is_token_in_rank must be [tokens, ranks], got: ",
+      is_token_in_rank.sizes());
+  NVF_CHECK_EQ(x.dim(), 2, "Dispatch expects x to be 2D [tokens, hidden].");
+
+  const int64_t num_tokens = x.size(0);
+  const int64_t hidden = x.size(1);
+  const int64_t world_size = communicator->size();
+  const int64_t my_rank = communicator->deviceId();
+  NVF_CHECK_EQ(
+      is_token_in_rank.size(0),
+      num_tokens,
+      "is_token_in_rank first dim must match number of tokens.");
+  NVF_CHECK_EQ(
+      is_token_in_rank.size(1),
+      world_size,
+      "is_token_in_rank second dim must match world size.");
+  NVF_CHECK_EQ(
+      num_experts % world_size,
+      0,
+      "num_experts must be divisible by world size.");
+  const int64_t experts_per_rank = num_experts / world_size;
+
+  const bool topk_is_1d = topk_idx.dim() == 1 && topk_idx.size(0) == num_tokens;
+  const bool topk_is_2d = topk_idx.dim() == 2 &&
+      topk_idx.size(0) == num_tokens && topk_idx.size(1) == 1;
+  NVF_CHECK(
+      topk_is_1d || topk_is_2d,
+      "Only topk=1 supported. topk_idx must be shape [T] or [T, 1], got: ",
+      topk_idx.sizes());
+  auto topk_idx_flat = topk_idx.reshape({num_tokens});
+  const bool weights_is_1d =
+      topk_weights.dim() == 1 && topk_weights.size(0) == num_tokens;
+  const bool weights_is_2d = topk_weights.dim() == 2 &&
+      topk_weights.size(0) == num_tokens && topk_weights.size(1) == 1;
+  NVF_CHECK(
+      weights_is_1d || weights_is_2d,
+      "Only topk=1 supported. topk_weights must be shape [T] or [T, 1], got: ",
+      topk_weights.sizes());
+  auto topk_weights_flat = topk_weights.reshape({num_tokens});
+
+  // Determine destination rank per token (topk=1).
+  auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong);
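+  // Worked example (hypothetical 2-rank run, matching the example in
+  // dispatch_combine.h): with is_token_in_rank =
+  //   [[1, 0],
+  //    [0, 1],
+  //    [0, 1],
+  //    [0, 1]],
+  // rank_for_token evaluates to [0, 1, 1, 1], i.e. token 0 stays on rank 0
+  // and tokens 1..3 are routed to rank 1 by the exchange below.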
+  // Sort tokens by destination rank for contiguous alltoall slices.
+  auto sorted_indices = at::argsort(rank_for_token);
+
+  // Reorder payloads so alltoall can send contiguous chunks per rank.
+  auto send_x = x.index_select(0, sorted_indices);
+  auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices);
+  auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices);
+  // Track original token indices and source rank for the combine step.
+  auto send_src_idx = sorted_indices.to(at::kLong);
+  // All entries equal my_rank, so no reordering is needed.
+  auto send_src_rank = at::full(
+      {num_tokens},
+      my_rank,
+      at::TensorOptions().dtype(at::kLong).device(x.device()));
+
+  // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so
+  // we sync/copy here. GPU-initiated comms can avoid this extra sync.
+  auto rank_for_token_cpu = rank_for_token.to(at::kCPU);
+  auto n_tokens_to_rank_cpu =
+      at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong);
+  auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device());
+  auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank);
+
+  NVF_CHECK_EQ(
+      backend,
+      CommunicatorBackend::kNccl,
+      "Only NCCL backend is supported for MoeDispatch.");
+  CommunicatorBackend actual_backend = backend;
+  NVF_CHECK(
+      communicator->isBackendAvailable(actual_backend),
+      "Backend not available for dispatch: ",
+      actual_backend);
+  auto* pg = communicator->getWorld(actual_backend);
+  NVF_CHECK(pg != nullptr, "Dispatch backend is null.");
+
+  // Exchange per-rank token counts to build split sizes for alltoall.
+  std::vector<int64_t> one_split(world_size, 1);
+  waitWork(pg->alltoall_base(
+      n_tokens_from_rank, n_tokens_to_rank, one_split, one_split));
+
+  // Convert count tensors to CPU split vectors and size the receive buffers.
+  auto input_splits = toSplitSizes(n_tokens_to_rank);
+  auto output_splits = toSplitSizes(n_tokens_from_rank);
+  auto total_recv = sumSplitSizes(output_splits);
+
+  // Allocate receive buffers for payloads and metadata.
+  // TODO: support preallocated buffers.
+  auto recv_x = at::empty({total_recv, hidden}, x.options());
+  auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options());
+  auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options());
+  auto recv_src_idx = at::empty({total_recv}, send_src_idx.options());
+  auto recv_src_rank = at::empty({total_recv}, send_src_rank.options());
+
+  // Exchange payloads with per-rank splits via alltoall.
+  waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits));
+  waitWork(pg->alltoall_base(
+      recv_topk_idx, send_topk_idx, output_splits, input_splits));
+  waitWork(pg->alltoall_base(
+      recv_topk_weights, send_topk_weights, output_splits, input_splits));
+  waitWork(pg->alltoall_base(
+      recv_src_idx, send_src_idx, output_splits, input_splits));
+  waitWork(pg->alltoall_base(
+      recv_src_rank, send_src_rank, output_splits, input_splits));
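+  // Worked example (hypothetical 2-rank run in which both ranks route their
+  // tokens as [0, 1, 1, 1]): rank 0 has input_splits = [1, 3] and
+  // output_splits = [1, 1], so total_recv = 2; rank 1 has
+  // input_splits = [1, 3] and output_splits = [3, 3], so total_recv = 6.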
+  // Locally reorder by expert id so each rank processes contiguous experts.
+  auto local_expert = recv_topk_idx - my_rank * experts_per_rank;
+  auto expert_order = at::argsort(local_expert);
+  recv_x = recv_x.index_select(0, expert_order);
+  recv_topk_idx = recv_topk_idx.index_select(0, expert_order);
+  recv_topk_weights = recv_topk_weights.index_select(0, expert_order);
+  recv_src_idx = recv_src_idx.index_select(0, expert_order);
+  recv_src_rank = recv_src_rank.index_select(0, expert_order);
+
+  return DispatchResult{
+      recv_x,
+      recv_topk_idx,
+      recv_topk_weights,
+      recv_src_idx,
+      recv_src_rank,
+      n_tokens_to_rank,
+      n_tokens_from_rank};
+}
+
+CombineResult doMoeCombine(
+    const at::Tensor& x,
+    const at::Tensor& topk_weights,
+    const at::Tensor& src_idx,
+    const at::Tensor& src_rank,
+    const at::Tensor& n_tokens_to_rank,
+    const at::Tensor& n_tokens_from_rank,
+    Communicator* communicator,
+    CommunicatorBackend backend) {
+  NVF_CHECK(communicator != nullptr, "Combine requires a valid communicator.");
+  NVF_CHECK(x.is_cuda(), "Combine input x must be on CUDA.");
+  NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA.");
+  NVF_CHECK(
+      topk_weights.is_floating_point(),
+      "Combine topk_weights must be floating point.");
+  NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA.");
+  NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA.");
+  NVF_CHECK(
+      n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be on CUDA.");
+  NVF_CHECK(
+      n_tokens_from_rank.is_cuda(),
+      "Combine n_tokens_from_rank must be on CUDA.");
+  NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D [tokens, hidden].");
+  NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D.");
+  NVF_CHECK_EQ(src_rank.dim(), 1, "src_rank must be 1D.");
+  const bool weights_is_1d =
+      topk_weights.dim() == 1 && topk_weights.size(0) == x.size(0);
+  const bool weights_is_2d = topk_weights.dim() == 2 &&
+      topk_weights.size(0) == x.size(0) && topk_weights.size(1) == 1;
+  NVF_CHECK(
+      weights_is_1d || weights_is_2d,
+      "topk_weights must be shape [T] or [T, 1], got: ",
+      topk_weights.sizes());
+  auto topk_weights_flat = topk_weights.reshape({x.size(0)});
+  NVF_CHECK_EQ(
+      src_idx.size(0), x.size(0), "src_idx size must match x first dimension.");
+  NVF_CHECK_EQ(
+      src_rank.size(0),
+      x.size(0),
+      "src_rank size must match x first dimension.");
+  NVF_CHECK_EQ(
+      n_tokens_to_rank.numel(),
+      communicator->size(),
+      "n_tokens_to_rank must match world size.");
+  NVF_CHECK_EQ(
+      n_tokens_from_rank.numel(),
+      communicator->size(),
+      "n_tokens_from_rank must match world size.");
+
+  // Sort by source rank so alltoall can send contiguous chunks per rank.
+  auto sorted_indices = at::argsort(src_rank);
+  auto send_x = x.index_select(0, sorted_indices);
+  auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices);
+  auto send_src_idx = src_idx.index_select(0, sorted_indices);
+
+  // Split sizes come from the dispatch counts, with send/recv roles swapped.
+  auto input_splits = toSplitSizes(n_tokens_from_rank);
+  auto output_splits = toSplitSizes(n_tokens_to_rank);
+  auto total_recv = sumSplitSizes(output_splits);
+  auto hidden = x.size(1);
+
+  NVF_CHECK(
+      backend == CommunicatorBackend::kNccl,
+      "Only NCCL backend is supported for MoeCombine.");
+  CommunicatorBackend actual_backend = backend;
+  NVF_CHECK(
+      communicator->isBackendAvailable(actual_backend),
+      "Backend not available for combine: ",
+      actual_backend);
+  auto* pg = communicator->getWorld(actual_backend);
+  NVF_CHECK(pg != nullptr, "Combine backend is null.");
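+  // Worked example (hypothetical, continuing the 2-rank run above): rank 1
+  // sends its 6 received tokens back with input_splits = [3, 3] and receives
+  // output_splits = [1, 3] worth of its own tokens, i.e. total_recv = 4,
+  // which is that rank's original token count.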
+  // Allocate receive buffers and exchange payloads back to source ranks.
+  auto recv_x = at::empty({total_recv, hidden}, x.options());
+  auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options());
+  auto recv_src_idx = at::empty({total_recv}, src_idx.options());
+
+  waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits));
+  waitWork(pg->alltoall_base(
+      recv_topk_weights, send_topk_weights, output_splits, input_splits));
+  waitWork(pg->alltoall_base(
+      recv_src_idx, send_src_idx, output_splits, input_splits));
+
+  // Scatter by original token index to restore local order.
+  auto combined_x = at::empty({total_recv, hidden}, x.options());
+  combined_x.index_copy_(0, recv_src_idx, recv_x);
+  auto combined_topk_weights =
+      at::empty({total_recv}, topk_weights_flat.options());
+  combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights);
+
+  return CombineResult{combined_x, combined_topk_weights};
+}
+
+} // namespace nvfuser
diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h
new file mode 100644
index 00000000000..9c4f4e6a62a
--- /dev/null
+++ b/csrc/multidevice/dispatch_combine.h
@@ -0,0 +1,125 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <ATen/ATen.h>
+
+#include "multidevice/communicator.h"
+#include "visibility.h"
+
+namespace nvfuser {
+
+struct DispatchResult {
+  at::Tensor recv_x; // Dispatched tokens received on this rank.
+  at::Tensor recv_topk_idx; // Expert ids aligned with recv_x.
+  at::Tensor recv_topk_weights; // Gating weights aligned with recv_x.
+  at::Tensor recv_src_idx; // Source token indices for combine.
+  at::Tensor recv_src_rank; // Source ranks for combine.
+  at::Tensor n_tokens_to_rank; // Tokens sent to each rank (this rank's view).
+  at::Tensor n_tokens_from_rank; // Tokens received from each rank.
+};
+
+struct CombineResult {
+  at::Tensor combined_x; // Combined tokens back in original order.
+  at::Tensor combined_topk_weights; // Combined gating weights per token.
+};
+
+// Dispatch MoE tokens to the owning ranks. Only k=1 is supported for now.
+//
+// Args:
+//   x: Token embeddings on this rank, shape [T, H].
+//   topk_idx: Global expert ids per token (topk=1), shape [T] or [T, 1].
+//   topk_weights: Gating weights per token (topk=1), shape [T] or [T, 1].
+//   is_token_in_rank: One-hot token-to-rank assignment, shape [T, R], enabling
+//     non-trivial device meshes or uneven expert-to-rank mappings.
+//   num_experts: Total experts across all ranks (must be divisible by R).
+//   communicator: Communicator for alltoall exchange.
+//   backend: Communication backend (only NCCL is supported for now).
+//
+// Returns:
+//   DispatchResult with recv_* tensors on this rank.
+//
+// Example:
+//   world_size=2, num_experts=4, T=4, H=2, topk=1
+//   Experts are partitioned by rank:
+//     rank0 owns experts {0, 1}, rank1 owns experts {2, 3}
+//   Rank0 holds tokens 0,1 and rank1 holds tokens 2,3 in x:
+//     rank0 x = [x0, x1], rank1 x = [x2, x3]
+//   token->rank: [0, 1, 1, 1] (rank0 keeps x0, sends x1; rank1 keeps x2,x3)
+//   is_token_in_rank =
+//     [[1, 0],
+//      [0, 1],
+//      [0, 1],
+//      [0, 1]]
+//   topk_idx = [0, 2, 3, 2] (global expert ids)
+//   After dispatch on rank0:
+//     recv_x has token {0}
+//     recv_topk_idx aligned with recv_x (e.g., [0])
+//     recv_topk_weights aligned with recv_x (e.g., [1.0])
+//     recv_src_idx tells original token positions (e.g., [0])
+//   After dispatch on rank1:
+//     recv_x has tokens {1, 2, 3}
+//     recv_topk_idx aligned with recv_x (e.g., [2, 2, 3]). Tokens are grouped
+//       by expert id for local expert processing.
+//     recv_src_idx tells original token positions (e.g., [1, 2, 3])
+//
+//   auto out = doMoeDispatch(
+//       x,
+//       topk_idx,
+//       topk_weights,
+//       is_token_in_rank,
+//       4,
+//       comm,
+//       CommunicatorBackend::kNccl);
+NVF_API DispatchResult doMoeDispatch(
+    const at::Tensor& x, // [T, H]
+    const at::Tensor& topk_idx, // [T] or [T, 1]
+    const at::Tensor& topk_weights, // [T] or [T, 1]
+    const at::Tensor& is_token_in_rank, // [T, R]
+    int64_t num_experts,
+    Communicator* communicator,
+    CommunicatorBackend backend);
+
+// Combine dispatched MoE results back to original token order.
+//
+// Args:
+//   x: Token embeddings after expert compute, shape [T_recv, H].
+//   topk_weights: Gating weights aligned with x, shape [T_recv] or [T_recv, 1].
+//   src_idx: Original token indices for each row of x, shape [T_recv].
+//   src_rank: Original source rank per token, shape [T_recv].
+//   n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R].
+//   n_tokens_from_rank: Tokens received from each rank (from dispatch),
+//     shape [R].
+//   communicator: Communicator for alltoall exchange.
+//   backend: Communication backend (only NCCL is supported for now).
+//
+// Returns:
+//   CombineResult with tokens restored to original order on this rank.
+//
+// Example:
+//   // Continuing the dispatch example (experts partitioned by rank):
+//   //   rank0 owns experts {0, 1}, rank1 owns experts {2, 3}
+//   // After expert compute:
+//   //   rank0 recv_x has token {0} with src_idx = [0], src_rank = [0]
+//   //   rank1 recv_x has tokens {1, 2, 3} with src_idx = [1, 2, 3],
+//   //     src_rank = [0, 1, 1]
+//   //   n_tokens_to_rank and n_tokens_from_rank are [R] counts per rank.
+//   // Combine scatters results back to original token order per rank.
+//   auto combined = doMoeCombine(
+//       x, topk_weights, src_idx, src_rank, n_tokens_to_rank,
+//       n_tokens_from_rank, comm, CommunicatorBackend::kNccl);
+NVF_API CombineResult doMoeCombine(
+    const at::Tensor& x,
+    const at::Tensor& topk_weights,
+    const at::Tensor& src_idx,
+    const at::Tensor& src_rank,
+    const at::Tensor& n_tokens_to_rank,
+    const at::Tensor& n_tokens_from_rank,
+    Communicator* communicator,
+    CommunicatorBackend backend);
+
+} // namespace nvfuser
diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp
new file mode 100644
index 00000000000..22db18066f3
--- /dev/null
+++ b/tests/cpp/test_multidevice_dispatch_combine.cpp
@@ -0,0 +1,132 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+#include <c10/core/TensorOptions.h>
+
+#include "fusion.h"
+#include "host_ir/container.h"
+#include "host_ir/evaluator.h"
+#include "multidevice/communication.h"
+#include "tests/cpp/multidevice.h"
+
+namespace nvfuser {
+namespace hir {
+
+using DispatchCombineTest = MultiDeviceTest;
+
+TEST_F(DispatchCombineTest, DispatchCombineTop1) {
+  if (!communicator_->is_available() || communicator_->size() < 2) {
+    GTEST_SKIP() << "This test needs at least 2 ranks.";
+  }
+
+  const int64_t world_size = communicator_->size();
+  const int64_t my_rank = communicator_->deviceId();
+  constexpr int64_t kNumExpertsPerRank = 2;
+  const int64_t num_experts = world_size * kNumExpertsPerRank;
+  constexpr int64_t kNumTokens = 4;
+  constexpr int64_t kHidden = 4;
+
+  auto hic = std::make_unique<HostIrContainer>();
+  FusionGuard fg(hic.get());
+
+  auto* in_x = makeSymbolicTensor(2);
+  auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int);
+  auto* in_topk_weights = makeSymbolicTensor(1);
+  auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool);
+
+  auto* recv_x = makeSymbolicTensor(2);
+  auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int);
+  auto* recv_topk_weights = makeSymbolicTensor(1);
+  auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int);
+  auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int);
+  auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int);
+  auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int);
+
+  auto* dispatch = IrBuilder::create<MoeDispatch>(
+      recv_x,
+      recv_topk_idx,
+      recv_topk_weights,
+      recv_src_idx,
+      recv_src_rank,
+      n_tokens_to_rank,
+      n_tokens_from_rank,
+      in_x,
+      in_topk_idx,
+      in_topk_weights,
+      in_is_token_in_rank,
+      num_experts,
+      CommunicatorBackend::kNccl);
+
+  auto* combined_x = makeSymbolicTensor(2);
+  auto* combined_topk_weights = makeSymbolicTensor(1);
+  auto* combine = IrBuilder::create<MoeCombine>(
+      combined_x,
+      combined_topk_weights,
+      recv_x,
+      recv_topk_weights,
+      recv_src_idx,
+      recv_src_rank,
+      n_tokens_to_rank,
+      n_tokens_from_rank,
+      CommunicatorBackend::kNccl);
+
+  hic->pushBackTopLevelExprs(dispatch);
+  hic->pushBackTopLevelExprs(combine);
+
+  hic->addInput(in_x);
+  hic->addInput(in_topk_idx);
+  hic->addInput(in_topk_weights);
+  hic->addInput(in_is_token_in_rank);
+  hic->addOutput(combined_x);
+  hic->addOutput(combined_topk_weights);
+
+  HostIrEvaluator hie(std::move(hic), communicator_);
+
+  auto float_options =
+      at::TensorOptions().device(communicator_->device()).dtype(at::kFloat);
+  auto int_options =
+      at::TensorOptions().device(communicator_->device()).dtype(at::kLong);
+
+  auto x = at::arange(kNumTokens * kHidden, float_options)
+               .reshape({kNumTokens, kHidden}) +
+      static_cast<double>(my_rank) * 1000.0;
+  auto topk_idx = at::zeros({kNumTokens}, int_options);
+  auto topk_weights =
+      at::arange(kNumTokens, float_options) + static_cast<double>(my_rank);
+
+  // Asymmetric example:
+  // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens.
+  auto rank_ids = at::arange(world_size, int_options);
+  auto token_rank = at::tensor({0, 1, 1, 1}, int_options);
+  auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids);
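+  // For example, with exactly 2 ranks every rank routes its tokens as
+  // [0, 1, 1, 1], so after dispatch rank0 holds 2 tokens (one from each rank)
+  // and rank1 holds 6; combine must hand back exactly kNumTokens = 4 rows per
+  // rank in the original order, which the checks at the end verify.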
+  // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1.
+  topk_idx.index_put_({0}, 0);
+  topk_idx.index_put_({1}, kNumExpertsPerRank);
+  topk_idx.index_put_({2}, kNumExpertsPerRank + 1);
+  topk_idx.index_put_({3}, kNumExpertsPerRank);
+
+  auto outputs = hie.runWithInput(
+      {{in_x, x},
+       {in_topk_idx, topk_idx},
+       {in_topk_weights, topk_weights},
+       {in_is_token_in_rank, is_token_in_rank}});
+  auto combined = outputs[0].as<at::Tensor>();
+  auto combined_weights = outputs[1].as<at::Tensor>();
+
+  EXPECT_TRUE(at::allclose(combined, x))
+      << "Dispatch/Combine mismatch on rank " << my_rank;
+  EXPECT_TRUE(at::allclose(combined_weights, topk_weights))
+      << "Dispatch/Combine topk_weights mismatch on rank " << my_rank;
+}
+
+} // namespace hir
+} // namespace nvfuser