Skip to content

Commit 33c3d6a

Browse files
committed
[Transform] Preserve all reachable functions in FuseTIR
Prior to this commit, the `FuseTIR` pass had custom logic to determine whether a `PrimFunc` should be kept in the output `IRModule`. This commit replaces this check in `FuseTIR` with a post-processing by `DeadCodeElimination`.
1 parent bbbc895 commit 33c3d6a

File tree

1 file changed

+82
-94
lines changed

1 file changed

+82
-94
lines changed

src/relax/transform/fuse_tir.cc

Lines changed: 82 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -878,54 +878,31 @@ std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var&
878878
*/
879879
class TIRFuseMutator : public ExprMutator {
880880
public:
881-
static IRModule Transform(const IRModule& mod) {
882-
Map<GlobalVar, BaseFunc> funcs_to_keep;
883-
for (const auto& [gv, func] : mod->functions) {
884-
// 1. If a TIR function has global symbol, we keep the function.
885-
// 2. Always keep ExternFunc.
886-
if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
887-
if (prim_func->GetAttr<String>("global_symbol").defined()) {
888-
funcs_to_keep.Set(gv, func);
889-
}
890-
} else if (func->IsInstance<ExternFuncNode>()) {
891-
funcs_to_keep.Set(gv, func);
892-
}
893-
}
894-
// Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder.
895-
TIRFuseMutator mutator(mod);
896-
// Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
897-
for (const auto& [gv, func] : mod->functions) {
898-
// Only fuse primitive relax functions
899-
if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
900-
tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
901-
mutator.fused_tir_funcs_.Set(gv, fused_tir);
902-
}
903-
}
904-
905-
// Step 2. Update all non-primitive relax functions and add it, with the dependent function,
906-
// into the new IRModule
907-
for (const auto& [gv, func] : mod->functions) {
881+
static IRModule Transform(const IRModule& orig_mod) {
882+
// Track the original IRModule separate from the BlockBuilder's
883+
// partially transformed module. Preserving access to the
884+
// pre-fused Primitive function is necessary for context-dependent
885+
// fusion (e.g. producing two fused versions of the same primitive
886+
// function, depending on whether it uses `R.call_tir` or
887+
// `R.call_tir_inplace`).
888+
TIRFuseMutator mutator(orig_mod);
889+
890+
// Any non-primitive relax functions should be inspected for calls
891+
// into primitve relax functions. The primitive relax functions
892+
// themselves will be updated as part of `mutator.VisitExpr`,
893+
// while the non-primitive relax functions are updated here.
894+
for (const auto& [gv, func] : orig_mod->functions) {
908895
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
909896
relax::Function update_func = Downcast<Function>(mutator.VisitExpr(func));
910-
mutator.builder_->AddFunction(update_func, gv->name_hint);
897+
mutator.builder_->UpdateFunction(gv, update_func);
911898
}
912899
}
913900

914-
// Step 3. Add all functions that need to be kept.
915-
auto modified_mod = mutator.builder_->GetContextIRModule();
916-
for (const auto& [gv, func] : funcs_to_keep) {
917-
if (!modified_mod->ContainGlobalVar(gv->name_hint)) {
918-
modified_mod->Add(gv, func);
919-
}
920-
}
921-
922-
// Step 4. Copy over module attributes and return.
923-
if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict);
924-
return modified_mod;
901+
return mutator.builder_->Finalize();
925902
}
926903

927904
private:
928-
explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {}
905+
explicit TIRFuseMutator(const IRModule& mod) : ExprMutator(mod), orig_mod_(mod) {}
929906

930907
using ExprMutator::VisitExpr_;
931908

@@ -950,63 +927,67 @@ class TIRFuseMutator : public ExprMutator {
950927

951928
if (call->op->IsInstance<GlobalVarNode>()) {
952929
// Case 1. It is a relax cross function call
953-
GlobalVar old_gv = Downcast<GlobalVar>(call->op);
954-
auto relax_func = Downcast<Function>(mod_->Lookup(old_gv));
955-
auto it = fused_tir_funcs_.find(old_gv);
956-
if (it != fused_tir_funcs_.end()) {
957-
const tir::PrimFunc& fused_tir = (*it).second;
958-
// Case 1.1. It calls a primitive relax function, update the call into a call_tir
959-
GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint);
960-
// Step a. Flatten all args since call_tir does not support Tuple value.
961-
Array<Expr> arg_list;
962-
Array<PrimExpr> tir_vars;
963-
for (size_t i = 0; i < call->args.size(); ++i) {
964-
auto arg = call->args[i];
965-
auto sinfo = GetStructInfo(arg);
966-
967-
ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
968-
!sinfo.as<TupleStructInfoNode>())
969-
<< "InternalError: "
970-
<< "All tuple parameters should be expanded before this point in FuseTIR. "
971-
<< "However, argument " << arg << " with struct info " << arg->struct_info_
972-
<< " is passed as argument " << i << " to Primitive Relax function " << old_gv
973-
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
974-
<< relax_func->params[i]->struct_info_;
975-
976-
if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
977-
CHECK(shape->values.defined())
978-
<< "FuseTIR requires all shape input has struct_info value.";
979-
for (const PrimExpr& prim_value : shape->values.value()) {
980-
CHECK(prim_value->IsInstance<tir::VarNode>())
981-
<< "All shape inputs are expected to be single tir var.";
982-
tir_vars.push_back(prim_value);
983-
}
984-
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
985-
CHECK(prim_value->value.defined())
986-
<< "FuseTIR requires all R.Prim arguments to have a known value.";
987-
PrimExpr expr = prim_value->value.value();
988-
CHECK(expr->IsInstance<tir::VarNode>())
989-
<< "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var.";
990-
tir_vars.push_back(expr);
991-
992-
} else {
993-
arg_list.push_back(arg);
930+
GlobalVar callee_gvar = Downcast<GlobalVar>(call->op);
931+
auto relax_func = Downcast<Function>(orig_mod_->Lookup(callee_gvar));
932+
933+
if (!relax_func->HasNonzeroAttr(attr::kPrimitive)) {
934+
// The callee is not a primitive function, no need to fuse.
935+
return call;
936+
}
937+
938+
tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(orig_mod_, callee_gvar);
939+
940+
// Case 1.1. It calls a primitive relax function, update the call into a call_tir
941+
// GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint);
942+
this->builder_->UpdateFunction(callee_gvar, fused_tir);
943+
944+
// Step a. Flatten all args since call_tir does not support Tuple value.
945+
Array<Expr> arg_list;
946+
Array<PrimExpr> tir_vars;
947+
for (size_t i = 0; i < call->args.size(); ++i) {
948+
auto arg = call->args[i];
949+
auto sinfo = GetStructInfo(arg);
950+
951+
ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
952+
!sinfo.as<TupleStructInfoNode>())
953+
<< "InternalError: "
954+
<< "All tuple parameters should be expanded before this point in FuseTIR. "
955+
<< "However, argument " << arg << " with struct info " << arg->struct_info_
956+
<< " is passed as argument " << i << " to Primitive Relax function " << callee_gvar
957+
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
958+
<< relax_func->params[i]->struct_info_;
959+
960+
if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
961+
CHECK(shape->values.defined())
962+
<< "FuseTIR requires all shape input has struct_info value.";
963+
for (const PrimExpr& prim_value : shape->values.value()) {
964+
CHECK(prim_value->IsInstance<tir::VarNode>())
965+
<< "All shape inputs are expected to be single tir var.";
966+
tir_vars.push_back(prim_value);
994967
}
968+
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
969+
CHECK(prim_value->value.defined())
970+
<< "FuseTIR requires all R.Prim arguments to have a known value.";
971+
PrimExpr expr = prim_value->value.value();
972+
CHECK(expr->IsInstance<tir::VarNode>())
973+
<< "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var.";
974+
tir_vars.push_back(expr);
975+
976+
} else {
977+
arg_list.push_back(arg);
995978
}
996-
// Step b. Create call_tir
997-
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
998-
if (!tir_vars.empty()) {
999-
call_args.push_back(ShapeExpr(tir_vars));
1000-
}
1001-
return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)});
1002-
} else {
1003-
// Case 1.2. The callee function is not primitive, nothing to do.
1004-
return call;
1005979
}
980+
// Step b. Create call_tir
981+
Array<Expr> call_args = {callee_gvar, Tuple(arg_list)};
982+
if (!tir_vars.empty()) {
983+
call_args.push_back(ShapeExpr(tir_vars));
984+
}
985+
return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)});
986+
1006987
} else if (call->op == call_tir_op_) {
1007988
// Case 2. It is a call_tir, re-emit the PrimFunc.
1008989
if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
1009-
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
990+
tir::PrimFunc func = Downcast<tir::PrimFunc>(orig_mod_->Lookup(GetRef<GlobalVar>(gv)));
1010991
GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
1011992
Array<Expr> new_args = call->args;
1012993
new_args.Set(0, new_gv);
@@ -1020,9 +1001,7 @@ class TIRFuseMutator : public ExprMutator {
10201001

10211002
private:
10221003
/*! \brief The IRModule */
1023-
const IRModule& mod_;
1024-
/*! \brief The map from global var of primitive relax function to generated prim func. */
1025-
Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
1004+
const IRModule& orig_mod_;
10261005
};
10271006

10281007
IRModule FuseTIR(IRModule mod) {
@@ -1044,6 +1023,15 @@ Pass FuseTIR() {
10441023
ExpandTupleArguments(),
10451024
RemoveUnusedParameters(),
10461025
inner_pass,
1026+
// At each call site, TIRFuseMutator inserts calls to
1027+
// attr::kPrimitive relax functions with calls to the fused
1028+
// TIR functions. However, until we reach the end of the
1029+
// module, we don't know whether the unfused TIR functions
1030+
// were called from contexts other than the attr::kPrimitive
1031+
// relax function, and cannot remove them entirely.
1032+
// Post-processing with `DeadCodeElimination` ensures that
1033+
// we do not keep any unreachable pre-fused TIR PrimFuncs.
1034+
DeadCodeElimination({}),
10471035
},
10481036
"FuseTIR");
10491037
}

0 commit comments

Comments
 (0)