Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/multidevice/communication.cpp
${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp
${NVFUSER_SRCS_DIR}/multidevice/cuda_p2p.cpp
${NVFUSER_SRCS_DIR}/multidevice/dispatch_combine.cpp
${NVFUSER_SRCS_DIR}/multidevice/ipc_handle.cpp
${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp
${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp
Expand Down Expand Up @@ -1143,6 +1144,7 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/tests/cpp/multidevice.cpp
${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
Expand Down
2 changes: 2 additions & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class Val;
f(Merge); \
f(Partition); \
f(Combine); \
f(MoeDispatch); \
f(MoeCombine); \
f(Swizzle); \
f(Resize); \
f(MatmulOp); \
Expand Down
53 changes: 53 additions & 0 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "multidevice/allocation_utils.h"
#include "multidevice/communication.h"
#include "multidevice/cuda_p2p.h"
#include "multidevice/dispatch_combine.h"
#include "multidevice/execution_utils.h"
#include "multidevice/symmetric_tensor.h"
#include "multidevice/utils.h"
Expand Down Expand Up @@ -386,6 +387,58 @@ void HostIrEvaluator::handle(P2PCommunication* communication) {
}
}

void HostIrEvaluator::handle(MoeDispatch* dispatch) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
"A valid communicator must be provided");

auto x = getKnownConcreteValue(dispatch->inX()).as<at::Tensor>();
auto topk_idx = getKnownConcreteValue(dispatch->inTopkIdx()).as<at::Tensor>();
auto is_token_in_rank =
getKnownConcreteValue(dispatch->inIsTokenInRank()).as<at::Tensor>();

auto result = doMoeDispatch(
x,
topk_idx,
is_token_in_rank,
dispatch->numExperts(),
communicator_,
dispatch->backend());

expr_evaluator_.bind(dispatch->outX(), result.recv_x);
expr_evaluator_.bind(dispatch->outTopkIdx(), result.recv_topk_idx);
expr_evaluator_.bind(dispatch->outSrcIdx(), result.recv_src_idx);
expr_evaluator_.bind(dispatch->outSrcRank(), result.recv_src_rank);
expr_evaluator_.bind(dispatch->outTokensToRank(), result.n_tokens_to_rank);
expr_evaluator_.bind(
dispatch->outTokensFromRank(), result.n_tokens_from_rank);
}

void HostIrEvaluator::handle(MoeCombine* combine) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
"A valid communicator must be provided");

auto x = getKnownConcreteValue(combine->inX()).as<at::Tensor>();
auto src_idx = getKnownConcreteValue(combine->inSrcIdx()).as<at::Tensor>();
auto src_rank = getKnownConcreteValue(combine->inSrcRank()).as<at::Tensor>();
auto n_tokens_to_rank =
getKnownConcreteValue(combine->inTokensToRank()).as<at::Tensor>();
auto n_tokens_from_rank =
getKnownConcreteValue(combine->inTokensFromRank()).as<at::Tensor>();

auto result = doMoeCombine(
x,
src_idx,
src_rank,
n_tokens_to_rank,
n_tokens_from_rank,
communicator_,
combine->backend());

expr_evaluator_.bind(combine->outX(), result.combined_x);
}

void HostIrEvaluator::handle(Wait* wait) {
Expr* expr = wait->communication();
auto* p2p_comm = dynamic_cast<P2PCommunication*>(expr);
Expand Down
2 changes: 2 additions & 0 deletions csrc/host_ir/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch {
void handle(LaunchKernel*) override;
void handle(Communication*) override;
void handle(P2PCommunication*) override;
void handle(MoeDispatch*) override;
void handle(MoeCombine*) override;
void handle(Wait*) override;
void handle(kir::ForLoop*) override;
void handle(hir::ForLoop*) override;
Expand Down
137 changes: 137 additions & 0 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,143 @@ std::string P2PCommunication::toString(int indent_size) const {
return toInlineString(indent_size) + "\n";
}

MoeDispatch::MoeDispatch(
IrBuilderPasskey passkey,
TensorView* out_x,
TensorView* out_topk_idx,
TensorView* out_src_idx,
TensorView* out_src_rank,
TensorView* out_n_tokens_to_rank,
TensorView* out_n_tokens_from_rank,
TensorView* in_x,
TensorView* in_topk_idx,
TensorView* in_is_token_in_rank,
int64_t num_experts,
CommunicatorBackend backend)
: Expr(passkey) {
addInput(in_x);
addInput(in_topk_idx);
addInput(in_is_token_in_rank);
addOutput(out_x);
addOutput(out_topk_idx);
addOutput(out_src_idx);
addOutput(out_src_rank);
addOutput(out_n_tokens_to_rank);
addOutput(out_n_tokens_from_rank);
addDataAttribute(num_experts);
addDataAttribute(backend);
validate();
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MoeDispatch)

std::string MoeDispatch::toInlineString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "Dispatch " << name() << " ("
<< "num_experts=" << numExperts() << ", "
<< "backend=" << backend() << ", "
<< "in=" << inX() << ", "
<< "topk_idx=" << inTopkIdx() << ", "
<< "is_token_in_rank=" << inIsTokenInRank() << ", "
<< "out=" << outX() << ")";
return ss.str();
}

std::string MoeDispatch::toString(int indent_size) const {
return toInlineString(indent_size) + "\n";
}

void MoeDispatch::validate() {
NVF_CHECK(numExperts() > 0, "num_experts must be positive.");
NVF_CHECK(inX()->isA<TensorView>(), "in_x must be a TensorView.");
NVF_CHECK(inTopkIdx()->isA<TensorView>(), "topk_idx must be a TensorView.");
NVF_CHECK(
inTopkIdx()->getDataType().has_value() &&
isIntegralType(*inTopkIdx()->getDataType()),
"topk_idx must be integral.");
NVF_CHECK(
inIsTokenInRank()->getDataType().has_value() &&
inIsTokenInRank()->getDataType() == DataType::Bool,
"is_token_in_rank must be Bool.");
NVF_CHECK(
outTopkIdx()->getDataType().has_value() &&
isIntegralType(*outTopkIdx()->getDataType()),
"out_topk_idx must be integral.");
NVF_CHECK(
outSrcIdx()->getDataType().has_value() &&
isIntegralType(*outSrcIdx()->getDataType()),
"out_src_idx must be integral.");
NVF_CHECK(
outSrcRank()->getDataType().has_value() &&
isIntegralType(*outSrcRank()->getDataType()),
"out_src_rank must be integral.");
NVF_CHECK(
outTokensToRank()->getDataType().has_value() &&
isIntegralType(*outTokensToRank()->getDataType()),
"out_n_tokens_to_rank must be integral.");
NVF_CHECK(
outTokensFromRank()->getDataType().has_value() &&
isIntegralType(*outTokensFromRank()->getDataType()),
"out_n_tokens_from_rank must be integral.");
}

MoeCombine::MoeCombine(
IrBuilderPasskey passkey,
TensorView* out_x,
TensorView* in_x,
TensorView* in_src_idx,
TensorView* in_src_rank,
TensorView* in_n_tokens_to_rank,
TensorView* in_n_tokens_from_rank,
CommunicatorBackend backend)
: Expr(passkey) {
addInput(in_x);
addInput(in_src_idx);
addInput(in_src_rank);
addInput(in_n_tokens_to_rank);
addInput(in_n_tokens_from_rank);
addOutput(out_x);
addDataAttribute(backend);
validate();
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MoeCombine)

std::string MoeCombine::toInlineString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "Combine " << name() << " ("
<< "backend=" << backend() << ", "
<< "in=" << inX() << ", "
<< "src_idx=" << inSrcIdx() << ", "
<< "src_rank=" << inSrcRank() << ", "
<< "out=" << outX() << ")";
return ss.str();
}

std::string MoeCombine::toString(int indent_size) const {
return toInlineString(indent_size) + "\n";
}

void MoeCombine::validate() {
NVF_CHECK(inX()->isA<TensorView>(), "in_x must be a TensorView.");
NVF_CHECK(
inSrcIdx()->getDataType().has_value() &&
isIntegralType(*inSrcIdx()->getDataType()),
"in_src_idx must be integral.");
NVF_CHECK(
inSrcRank()->getDataType().has_value() &&
isIntegralType(*inSrcRank()->getDataType()),
"in_src_rank must be integral.");
NVF_CHECK(
inTokensToRank()->getDataType().has_value() &&
isIntegralType(*inTokensToRank()->getDataType()),
"in_n_tokens_to_rank must be integral.");
NVF_CHECK(
inTokensFromRank()->getDataType().has_value() &&
isIntegralType(*inTokensFromRank()->getDataType()),
"in_n_tokens_from_rank must be integral.");
}

namespace {
c10::intrusive_ptr<c10d::Work> postBroadcast(
Communication* communication,
Expand Down
Loading