2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -235,6 +235,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/multidevice/communication.cpp
${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp
${NVFUSER_SRCS_DIR}/multidevice/cuda_p2p.cpp
${NVFUSER_SRCS_DIR}/multidevice/dispatch_combine.cpp
${NVFUSER_SRCS_DIR}/multidevice/ipc_handle.cpp
${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp
${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp
@@ -1143,6 +1144,7 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/tests/cpp/multidevice.cpp
${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
2 changes: 2 additions & 0 deletions csrc/dispatch.h
@@ -118,6 +118,8 @@ class Val;
f(Merge); \
f(Partition); \
f(Combine); \
f(MoeDispatch); \
f(MoeCombine); \
f(Swizzle); \
f(Resize); \
f(MatmulOp); \
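Note: the two new f(MoeDispatch); and f(MoeCombine); entries register the node types with nvFuser's macro-driven dispatch, which is what makes the handle(MoeDispatch*) and handle(MoeCombine*) overrides later in this PR possible. Below is a minimal, hedged sketch of the X-macro pattern this relies on; the macro and class names are placeholders, not the actual definitions in csrc/dispatch.h.

// Illustrative X-macro sketch only; names below are placeholders, not the
// real csrc/dispatch.h machinery. Each f(T) entry in the list expands into a
// per-type hook, so adding f(MoeDispatch) and f(MoeCombine) is what lets a
// visitor such as HostIrEvaluator override handle(MoeDispatch*) and
// handle(MoeCombine*).
class MoeDispatch;
class MoeCombine;

#define SKETCH_FOR_ALL_EXPRS(f) \
  f(MoeDispatch);               \
  f(MoeCombine);

class OptOutDispatchSketch {
 public:
  virtual ~OptOutDispatchSketch() = default;

// Expands to: virtual void handle(MoeDispatch*) {} and the MoeCombine twin.
#define SKETCH_DECLARE_HANDLE(T) virtual void handle(T*) {}
  SKETCH_FOR_ALL_EXPRS(SKETCH_DECLARE_HANDLE)
#undef SKETCH_DECLARE_HANDLE
};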
61 changes: 61 additions & 0 deletions csrc/host_ir/evaluator.cpp
@@ -25,6 +25,7 @@
#include "multidevice/allocation_utils.h"
#include "multidevice/communication.h"
#include "multidevice/cuda_p2p.h"
#include "multidevice/dispatch_combine.h"
#include "multidevice/execution_utils.h"
#include "multidevice/symmetric_tensor.h"
#include "multidevice/utils.h"
@@ -386,6 +387,66 @@ void HostIrEvaluator::handle(P2PCommunication* communication) {
}
}

void HostIrEvaluator::handle(MoeDispatch* dispatch) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
"A valid communicator must be provided");

auto x = getKnownConcreteValue(dispatch->inX()).as<at::Tensor>();
auto topk_idx = getKnownConcreteValue(dispatch->inTopkIdx()).as<at::Tensor>();
auto topk_weights =
getKnownConcreteValue(dispatch->inTopkWeights()).as<at::Tensor>();
auto is_token_in_rank =
getKnownConcreteValue(dispatch->inIsTokenInRank()).as<at::Tensor>();

auto result = doMoeDispatch(
x,
topk_idx,
topk_weights,
is_token_in_rank,
dispatch->numExperts(),
communicator_,
dispatch->backend());

expr_evaluator_.bind(dispatch->outX(), result.recv_x);
expr_evaluator_.bind(dispatch->outTopkIdx(), result.recv_topk_idx);
expr_evaluator_.bind(dispatch->outTopkWeights(), result.recv_topk_weights);
expr_evaluator_.bind(dispatch->outSrcIdx(), result.recv_src_idx);
expr_evaluator_.bind(dispatch->outSrcRank(), result.recv_src_rank);
expr_evaluator_.bind(dispatch->outTokensToRank(), result.n_tokens_to_rank);
expr_evaluator_.bind(
dispatch->outTokensFromRank(), result.n_tokens_from_rank);
}

void HostIrEvaluator::handle(MoeCombine* combine) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
"A valid communicator must be provided");

auto x = getKnownConcreteValue(combine->inX()).as<at::Tensor>();
auto topk_weights =
getKnownConcreteValue(combine->inTopkWeights()).as<at::Tensor>();
auto src_idx = getKnownConcreteValue(combine->inSrcIdx()).as<at::Tensor>();
auto src_rank = getKnownConcreteValue(combine->inSrcRank()).as<at::Tensor>();
auto n_tokens_to_rank =
getKnownConcreteValue(combine->inTokensToRank()).as<at::Tensor>();
auto n_tokens_from_rank =
getKnownConcreteValue(combine->inTokensFromRank()).as<at::Tensor>();

auto result = doMoeCombine(
x,
topk_weights,
src_idx,
src_rank,
n_tokens_to_rank,
n_tokens_from_rank,
communicator_,
combine->backend());

expr_evaluator_.bind(combine->outX(), result.combined_x);
expr_evaluator_.bind(combine->outTopkWeights(), result.combined_topk_weights);
}

void HostIrEvaluator::handle(Wait* wait) {
Expr* expr = wait->communication();
auto* p2p_comm = dynamic_cast<P2PCommunication*>(expr);
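The two handlers above forward to doMoeDispatch and doMoeCombine from the new multidevice/dispatch_combine.h and bind each field of the returned result to the matching IR output. For readability, here is a hedged sketch of interfaces consistent with those call sites; the struct names, parameter passing, and field comments are inferred from the code above rather than copied from the actual header.

// Sketch inferred from the handle(MoeDispatch*) / handle(MoeCombine*) call
// sites above; the authoritative declarations live in
// csrc/multidevice/dispatch_combine.h.
#include <ATen/ATen.h>

namespace nvfuser {

class Communicator;
enum class CommunicatorBackend;  // assumption: an opaque scoped-enum declaration suffices here

struct MoeDispatchResult {
  at::Tensor recv_x;              // tokens received by this rank
  at::Tensor recv_topk_idx;       // expert indices travelling with those tokens
  at::Tensor recv_topk_weights;   // gating weights travelling with those tokens
  at::Tensor recv_src_idx;        // per-token index on the sending rank
  at::Tensor recv_src_rank;       // per-token sending rank
  at::Tensor n_tokens_to_rank;    // send counts, one entry per rank
  at::Tensor n_tokens_from_rank;  // receive counts, one entry per rank
};

struct MoeCombineResult {
  at::Tensor combined_x;             // tokens routed back to their source rank
  at::Tensor combined_topk_weights;  // weights returned alongside them
};

MoeDispatchResult doMoeDispatch(
    const at::Tensor& x,
    const at::Tensor& topk_idx,
    const at::Tensor& topk_weights,
    const at::Tensor& is_token_in_rank,
    int64_t num_experts,
    Communicator* communicator,
    CommunicatorBackend backend);

MoeCombineResult doMoeCombine(
    const at::Tensor& x,
    const at::Tensor& topk_weights,
    const at::Tensor& src_idx,
    const at::Tensor& src_rank,
    const at::Tensor& n_tokens_to_rank,
    const at::Tensor& n_tokens_from_rank,
    Communicator* communicator,
    CommunicatorBackend backend);

}  // namespace nvfuser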
2 changes: 2 additions & 0 deletions csrc/host_ir/evaluator.h
@@ -98,6 +98,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch {
void handle(LaunchKernel*) override;
void handle(Communication*) override;
void handle(P2PCommunication*) override;
void handle(MoeDispatch*) override;
void handle(MoeCombine*) override;
void handle(Wait*) override;
void handle(kir::ForLoop*) override;
void handle(hir::ForLoop*) override;
163 changes: 163 additions & 0 deletions csrc/multidevice/communication.cpp
@@ -321,6 +321,169 @@ std::string P2PCommunication::toString(int indent_size) const {
return toInlineString(indent_size) + "\n";
}

MoeDispatch::MoeDispatch(
IrBuilderPasskey passkey,
TensorView* out_x,
TensorView* out_topk_idx,
TensorView* out_topk_weights,
TensorView* out_src_idx,
TensorView* out_src_rank,
TensorView* out_n_tokens_to_rank,
TensorView* out_n_tokens_from_rank,
TensorView* in_x,
TensorView* in_topk_idx,
TensorView* in_topk_weights,
TensorView* in_is_token_in_rank,
int64_t num_experts,
CommunicatorBackend backend)
: Expr(passkey) {
addInput(in_x);
addInput(in_topk_idx);
addInput(in_topk_weights);
addInput(in_is_token_in_rank);
addOutput(out_x);
addOutput(out_topk_idx);
addOutput(out_topk_weights);
addOutput(out_src_idx);
addOutput(out_src_rank);
addOutput(out_n_tokens_to_rank);
addOutput(out_n_tokens_from_rank);
addDataAttribute(num_experts);
addDataAttribute(backend);
validate();
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MoeDispatch)

std::string MoeDispatch::toInlineString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "Dispatch " << name() << " ("
<< "num_experts=" << numExperts() << ", "
<< "backend=" << backend() << ", "
<< "in=" << inX() << ", "
<< "topk_idx=" << inTopkIdx() << ", "
<< "topk_weights=" << inTopkWeights() << ", "
<< "is_token_in_rank=" << inIsTokenInRank() << ", "
<< "out=" << outX() << ")";
return ss.str();
}

std::string MoeDispatch::toString(int indent_size) const {
return toInlineString(indent_size) + "\n";
}

void MoeDispatch::validate() {
NVF_CHECK(numExperts() > 0, "num_experts must be positive.");
NVF_CHECK(inX()->isA<TensorView>(), "in_x must be a TensorView.");
NVF_CHECK(inTopkIdx()->isA<TensorView>(), "topk_idx must be a TensorView.");
NVF_CHECK(
inTopkIdx()->getDataType().has_value() &&
isIntegralType(*inTopkIdx()->getDataType()),
"topk_idx must be integral.");
NVF_CHECK(
inTopkWeights()->getDataType().has_value() &&
isFloatingPointType(*inTopkWeights()->getDataType()),
"topk_weights must be floating point.");
NVF_CHECK(
inIsTokenInRank()->getDataType().has_value() &&
inIsTokenInRank()->getDataType() == DataType::Bool,
"is_token_in_rank must be Bool.");
NVF_CHECK(
outTopkIdx()->getDataType().has_value() &&
isIntegralType(*outTopkIdx()->getDataType()),
"out_topk_idx must be integral.");
NVF_CHECK(
outTopkWeights()->getDataType().has_value() &&
isFloatingPointType(*outTopkWeights()->getDataType()),
"out_topk_weights must be floating point.");
NVF_CHECK(
outSrcIdx()->getDataType().has_value() &&
isIntegralType(*outSrcIdx()->getDataType()),
"out_src_idx must be integral.");
NVF_CHECK(
outSrcRank()->getDataType().has_value() &&
isIntegralType(*outSrcRank()->getDataType()),
"out_src_rank must be integral.");
NVF_CHECK(
outTokensToRank()->getDataType().has_value() &&
isIntegralType(*outTokensToRank()->getDataType()),
"out_n_tokens_to_rank must be integral.");
NVF_CHECK(
outTokensFromRank()->getDataType().has_value() &&
isIntegralType(*outTokensFromRank()->getDataType()),
"out_n_tokens_from_rank must be integral.");
}

MoeCombine::MoeCombine(
IrBuilderPasskey passkey,
TensorView* out_x,
TensorView* out_topk_weights,
TensorView* in_x,
TensorView* in_topk_weights,
TensorView* in_src_idx,
TensorView* in_src_rank,
TensorView* in_n_tokens_to_rank,
TensorView* in_n_tokens_from_rank,
CommunicatorBackend backend)
: Expr(passkey) {
addInput(in_x);
addInput(in_topk_weights);
addInput(in_src_idx);
addInput(in_src_rank);
addInput(in_n_tokens_to_rank);
addInput(in_n_tokens_from_rank);
addOutput(out_x);
addOutput(out_topk_weights);
addDataAttribute(backend);
validate();
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MoeCombine)

std::string MoeCombine::toInlineString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "Combine " << name() << " ("
<< "backend=" << backend() << ", "
<< "in=" << inX() << ", "
<< "topk_weights=" << inTopkWeights() << ", "
<< "src_idx=" << inSrcIdx() << ", "
<< "src_rank=" << inSrcRank() << ", "
<< "out=" << outX() << ")";
return ss.str();
}

std::string MoeCombine::toString(int indent_size) const {
return toInlineString(indent_size) + "\n";
}

void MoeCombine::validate() {
NVF_CHECK(inX()->isA<TensorView>(), "in_x must be a TensorView.");
NVF_CHECK(
inTopkWeights()->getDataType().has_value() &&
isFloatingPointType(*inTopkWeights()->getDataType()),
"in_topk_weights must be floating point.");
NVF_CHECK(
inSrcIdx()->getDataType().has_value() &&
isIntegralType(*inSrcIdx()->getDataType()),
"in_src_idx must be integral.");
NVF_CHECK(
inSrcRank()->getDataType().has_value() &&
isIntegralType(*inSrcRank()->getDataType()),
"in_src_rank must be integral.");
NVF_CHECK(
inTokensToRank()->getDataType().has_value() &&
isIntegralType(*inTokensToRank()->getDataType()),
"in_n_tokens_to_rank must be integral.");
NVF_CHECK(
inTokensFromRank()->getDataType().has_value() &&
isIntegralType(*inTokensFromRank()->getDataType()),
"in_n_tokens_from_rank must be integral.");
NVF_CHECK(
outTopkWeights()->getDataType().has_value() &&
isFloatingPointType(*outTopkWeights()->getDataType()),
"out_topk_weights must be floating point.");
}

namespace {
c10::intrusive_ptr<c10d::Work> postBroadcast(
Communication* communication,
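Usage note: both constructors follow the usual nvFuser IR pattern (see NVFUSER_DEFINE_CLONE_AND_CREATE above), so a host IR program would typically instantiate the nodes through IrBuilder::create with the outputs listed before the inputs, in the same order as the constructor signatures. The sketch below is hedged: every TensorView* variable and the backend value are placeholders assumed to already exist in the enclosing container, and feeding expert_out_x into the combine is only illustrative of the intended dataflow, not code from this PR.

// Hedged construction sketch; out_x, in_x, expert_out_x, backend, etc. are
// placeholders, not identifiers from this PR.
auto* dispatch = IrBuilder::create<MoeDispatch>(
    out_x,
    out_topk_idx,
    out_topk_weights,
    out_src_idx,
    out_src_rank,
    out_n_tokens_to_rank,
    out_n_tokens_from_rank,
    in_x,
    in_topk_idx,
    in_topk_weights,
    in_is_token_in_rank,
    /*num_experts=*/8,
    backend);

// After the per-expert computation produces expert_out_x, the bookkeeping
// outputs of the dispatch feed the combine that routes tokens back.
auto* combine = IrBuilder::create<MoeCombine>(
    combined_x,
    combined_topk_weights,
    expert_out_x,
    out_topk_weights,
    out_src_idx,
    out_src_rank,
    out_n_tokens_to_rank,
    out_n_tokens_from_rank,
    backend);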