Skip to content

Commit 274c368

Browse files
authored
[Bugfix][Transform] Keep private non-primitive functions in FuseTIR (#16565)
Prior to this commit, private non-primitive relax functions would be discarded by `FuseTIR`. If any calls to these functions exist, the resulting `IRModule` would be ill-formed. This commit updates `FuseTIR` so that it only applies updates to functions with `attr::kPrimitive`, and to the call sites of those functions. To retain backwards compatibility, `DeadCodeElimination` is applied as a post-processing step.
1 parent 2b813ec commit 274c368

3 files changed

Lines changed: 187 additions & 113 deletions

File tree

include/tvm/relax/transform.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,12 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2);
578578
*
579579
* Any binding blocks that are left empty will be removed by the normalizer.
580580
*
581+
* \param entry_functions Names of functions that should be considered
582+
* as entry points, in addition to any externally exposed functions.
583+
*
581584
* \return The Pass.
582585
*/
583-
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
586+
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions = {});
584587

585588
/*!
586589
* \brief Pass that changes calls to operators that can be done in-place

src/relax/transform/fuse_tir.cc

Lines changed: 123 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -961,57 +961,73 @@ std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var&
961961
*/
962962
class TIRFuseMutator : public ExprMutator {
963963
public:
964-
static IRModule Transform(const IRModule& mod) {
965-
Map<GlobalVar, BaseFunc> funcs_to_keep;
966-
for (const auto& [gv, func] : mod->functions) {
967-
// 1. If a TIR function has global symbol, we keep the function.
968-
// 2. Always keep ExternFunc.
969-
if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
970-
if (prim_func->GetAttr<String>("global_symbol").defined()) {
971-
funcs_to_keep.Set(gv, func);
964+
static IRModule Transform(IRModule mod) {
965+
// Collect all primitive relax functions
966+
Map<GlobalVar, Function> primitive_relax;
967+
for (const auto& [gvar, base_func] : mod->functions) {
968+
// Only fuse primitive relax functions
969+
if (base_func->HasNonzeroAttr(attr::kPrimitive)) {
970+
if (auto func = base_func.as<relax::Function>()) {
971+
primitive_relax.Set(gvar, func.value());
972972
}
973-
} else if (func->IsInstance<ExternFuncNode>()) {
974-
funcs_to_keep.Set(gv, func);
975973
}
976974
}
975+
976+
if (primitive_relax.empty()) {
977+
return mod;
978+
}
979+
980+
mod.CopyOnWrite();
981+
982+
IRModule updates;
983+
std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements;
984+
977985
// Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder.
978-
TIRFuseMutator mutator(mod);
986+
979987
// Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
980-
for (const auto& [gv, func] : mod->functions) {
981-
// Only fuse primitive relax functions
982-
if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
983-
const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, gv);
984-
mutator.fused_tir_funcs_.Set(gv, prim_func);
985-
if (!indices.empty()) {
986-
mutator.inplace_indices_.Set(gv, indices);
987-
}
988-
}
988+
for (const auto& [old_gvar, func] : primitive_relax) {
989+
const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, old_gvar);
990+
991+
GlobalVar new_gvar(old_gvar->name_hint);
992+
UpdateStructInfo(new_gvar,
993+
FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)));
994+
995+
mod->Remove(old_gvar);
996+
updates->Add(new_gvar, prim_func);
997+
replacements[old_gvar] = Replacement{new_gvar, func, indices};
989998
}
990999

1000+
TIRFuseMutator mutator(replacements);
1001+
9911002
// Step 2. Update all non-primitive relax functions and add it, with the dependent function,
9921003
// into the new IRModule
1004+
9931005
for (const auto& [gv, func] : mod->functions) {
994-
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
1006+
if (func->IsInstance<relax::FunctionNode>()) {
1007+
ICHECK(!func->HasNonzeroAttr(attr::kPrimitive))
1008+
<< "Module should not contain any primitive relax functions at this point";
9951009
relax::Function update_func = Downcast<Function>(mutator.VisitExpr(func));
996-
mutator.builder_->AddFunction(update_func, gv->name_hint);
997-
}
998-
}
999-
1000-
// Step 3. Add all functions that need to be kept.
1001-
auto modified_mod = mutator.builder_->GetContextIRModule();
1002-
for (const auto& [gv, func] : funcs_to_keep) {
1003-
if (!modified_mod->ContainGlobalVar(gv->name_hint)) {
1004-
modified_mod->Add(gv, func);
1010+
if (!update_func.same_as(func)) {
1011+
updates->Add(gv, update_func);
1012+
}
10051013
}
10061014
}
10071015

1008-
// Step 4. Copy over module attributes and return.
1009-
if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict);
1010-
return modified_mod;
1016+
// Step 4. Copy over updated functions and return.
1017+
mod->Update(updates);
1018+
return mod;
10111019
}
10121020

10131021
private:
1014-
explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {}
1022+
struct Replacement {
1023+
GlobalVar fused_tir_gvar;
1024+
Function original_function;
1025+
Array<Integer> inplace_indices;
1026+
};
1027+
1028+
explicit TIRFuseMutator(
1029+
std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements)
1030+
: replacements_(replacements) {}
10151031

10161032
using ExprMutator::VisitExpr_;
10171033

@@ -1035,92 +1051,86 @@ class TIRFuseMutator : public ExprMutator {
10351051

10361052
Call call = Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));
10371053

1038-
if (call->op->IsInstance<GlobalVarNode>()) {
1039-
// Case 1. It is a relax cross function call
1040-
GlobalVar old_gv = Downcast<GlobalVar>(call->op);
1041-
auto relax_func = Downcast<Function>(mod_->Lookup(old_gv));
1042-
auto it = fused_tir_funcs_.find(old_gv);
1043-
if (it != fused_tir_funcs_.end()) {
1044-
const tir::PrimFunc& fused_tir = (*it).second;
1045-
// Case 1.1. It calls a primitive relax function, update the call into a call_tir
1046-
GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint);
1047-
// Step a. Flatten all args since call_tir does not support Tuple value.
1048-
Array<Expr> arg_list;
1049-
Array<PrimExpr> tir_vars;
1050-
for (size_t i = 0; i < call->args.size(); ++i) {
1051-
auto arg = call->args[i];
1052-
auto sinfo = GetStructInfo(arg);
1053-
1054-
ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
1055-
!sinfo.as<TupleStructInfoNode>())
1056-
<< "InternalError: "
1057-
<< "All tuple parameters should be expanded before this point in FuseTIR. "
1058-
<< "However, argument " << arg << " with struct info " << arg->struct_info_
1059-
<< " is passed as argument " << i << " to Primitive Relax function " << old_gv
1060-
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
1061-
<< relax_func->params[i]->struct_info_;
1062-
1063-
if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
1064-
CHECK(shape->values.defined())
1065-
<< "FuseTIR requires all shape input has struct_info value.";
1066-
for (const PrimExpr& prim_value : shape->values.value()) {
1067-
CHECK(prim_value->IsInstance<tir::VarNode>())
1068-
<< "All shape inputs are expected to be single tir var.";
1069-
tir_vars.push_back(prim_value);
1070-
}
1071-
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
1072-
CHECK(prim_value->value.defined())
1073-
<< "FuseTIR requires all R.Prim arguments to have a known value.";
1074-
PrimExpr expr = prim_value->value.value();
1075-
CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently requires all R.Prim "
1076-
"arguments to provide a single tir::Var.";
1077-
tir_vars.push_back(expr);
1078-
1079-
} else {
1080-
arg_list.push_back(arg);
1081-
}
1082-
}
1083-
// Step b. Create call_tir or call_tir_inplace
1084-
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
1085-
if (!tir_vars.empty()) {
1086-
call_args.push_back(ShapeExpr(tir_vars));
1087-
}
1088-
Op call_op = call_tir_op_;
1089-
Attrs call_attrs = call->attrs;
1090-
if (auto it = inplace_indices_.find(old_gv); it != inplace_indices_.end()) {
1091-
call_op = call_tir_inplace_op_;
1092-
auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
1093-
inplace_attrs->inplace_indices = (*it).second;
1094-
call_attrs = Attrs(inplace_attrs);
1054+
auto opt_gvar = call->op.as<GlobalVar>();
1055+
if (!opt_gvar) {
1056+
// Case 1. The Call isn't a relax-to-relax function call, no need to update.
1057+
return call;
1058+
}
1059+
GlobalVar old_gvar = opt_gvar.value();
1060+
1061+
auto it = replacements_.find(old_gvar);
1062+
if (it == replacements_.end()) {
1063+
// Case 2. The callee function is not a primitive relax
1064+
// function, no need to update.
1065+
return call;
1066+
}
1067+
const Replacement& replacement = it->second;
1068+
const GlobalVar& fused_tir_gv = replacement.fused_tir_gvar;
1069+
const Function& relax_func = replacement.original_function;
1070+
1071+
// Case 3. It calls a primitive relax function, update the call
1072+
// into a call_tir or call_tir_inplace.
1073+
1074+
// Step a. Collect all relax/symbolic arguments. Tuple arguments
1075+
// are not supported by PrimFunc, so this step verifies that
1076+
// ExpandTupleArguments has already removed them.
1077+
Array<Expr> arg_list;
1078+
Array<PrimExpr> tir_vars;
1079+
for (size_t i = 0; i < call->args.size(); ++i) {
1080+
auto arg = call->args[i];
1081+
auto sinfo = GetStructInfo(arg);
1082+
1083+
ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
1084+
!sinfo.as<TupleStructInfoNode>())
1085+
<< "InternalError: "
1086+
<< "All tuple parameters should be expanded before this point in FuseTIR. "
1087+
<< "However, argument " << arg << " with struct info " << arg->struct_info_
1088+
<< " is passed as argument " << i << " to Primitive Relax function " << old_gvar
1089+
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
1090+
<< relax_func->params[i]->struct_info_;
1091+
1092+
if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
1093+
CHECK(shape->values.defined()) << "FuseTIR requires all shape input has struct_info value.";
1094+
for (const PrimExpr& prim_value : shape->values.value()) {
1095+
CHECK(prim_value->IsInstance<tir::VarNode>())
1096+
<< "All shape inputs are expected to be single tir var.";
1097+
tir_vars.push_back(prim_value);
10951098
}
1096-
return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
1099+
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
1100+
CHECK(prim_value->value.defined())
1101+
<< "FuseTIR requires all R.Prim arguments to have a known value.";
1102+
PrimExpr expr = prim_value->value.value();
1103+
CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently requires all R.Prim "
1104+
"arguments to provide a single tir::Var.";
1105+
tir_vars.push_back(expr);
1106+
10971107
} else {
1098-
// Case 1.2. The callee function is not primitive, nothing to do.
1099-
return call;
1100-
}
1101-
} else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) {
1102-
// Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc.
1103-
if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
1104-
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
1105-
GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
1106-
Array<Expr> new_args = call->args;
1107-
new_args.Set(0, new_gv);
1108-
return Call(call->op, new_args, call->attrs, call->sinfo_args, call->span);
1108+
arg_list.push_back(arg);
11091109
}
11101110
}
11111111

1112-
// Case 3. CallNode in other types. Leave it as it is.
1113-
return call;
1112+
// Step b. Create call_tir or call_tir_inplace
1113+
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
1114+
if (!tir_vars.empty()) {
1115+
call_args.push_back(ShapeExpr(tir_vars));
1116+
}
1117+
Op call_op = call_tir_op_;
1118+
Attrs call_attrs = call->attrs;
1119+
if (replacement.inplace_indices.size()) {
1120+
call_op = call_tir_inplace_op_;
1121+
auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
1122+
inplace_attrs->inplace_indices = replacement.inplace_indices;
1123+
call_attrs = Attrs(inplace_attrs);
1124+
}
1125+
return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
11141126
}
11151127

11161128
private:
1117-
/*! \brief The IRModule */
1118-
const IRModule& mod_;
1119-
/*! \brief The map from global var of primitive relax function to generated prim func. */
1120-
Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
1121-
/*! \brief The map from global var of primitive relax function to in-place indices
1122-
* (if there are any). */
1123-
Map<GlobalVar, Array<Integer>> inplace_indices_;
1129+
/*! \brief The map from global var to how it should be replaced
1130+
*
1131+
* Has one entry for each primitive relax function in the IRModule.
1132+
*/
1133+
std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements_;
11241134
};
11251135

11261136
IRModule FuseTIR(IRModule mod) {
@@ -1142,6 +1152,7 @@ Pass FuseTIR() {
11421152
ExpandTupleArguments(),
11431153
RemoveUnusedParameters(),
11441154
inner_pass,
1155+
DeadCodeElimination(),
11451156
},
11461157
"FuseTIR");
11471158
}

tests/python/relax/test_transform_fuse_tir.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,5 +2254,65 @@ def main(
22542254
_check(Module, Expected)
22552255

22562256

2257+
def test_private_nonprimitive_func():
2258+
"""Input IRModule may contain calls to non-primitive functions
2259+
2260+
This is a regression test. Prior implementations did not preserve
2261+
relax-to-relax function calls.
2262+
"""
2263+
2264+
@I.ir_module
2265+
class Before:
2266+
@R.function
2267+
def main(
2268+
input_ids: R.Tensor((1,), dtype="int32"),
2269+
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
2270+
) -> R.Tensor((1, 4096), dtype="float16"):
2271+
cls = Before
2272+
with R.dataflow():
2273+
gv = cls.fused_func(input_ids, input_embeds)
2274+
R.output(gv)
2275+
return gv
2276+
2277+
@R.function(private=True)
2278+
def fused_func(
2279+
input_ids: R.Tensor((1,), dtype="int32"),
2280+
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
2281+
) -> R.Tensor((1, 4096), dtype="float16"):
2282+
cls = Before
2283+
with R.dataflow():
2284+
lv = R.call_tir(
2285+
cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
2286+
)
2287+
gv = R.call_tir(
2288+
cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16")
2289+
)
2290+
R.output(gv)
2291+
return gv
2292+
2293+
@T.prim_func(private=True)
2294+
def add(
2295+
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
2296+
Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
2297+
):
2298+
for i, j in T.grid(T.int64(4096), T.int64(4096)):
2299+
with T.block("add"):
2300+
vi, vj = T.axis.remap("SS", [i, j])
2301+
Out[vi, vj] = A[vi, vj] + T.float16(1.0)
2302+
2303+
@T.prim_func(private=True)
2304+
def take(
2305+
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
2306+
B: T.Buffer((T.int64(1),), "int32"),
2307+
T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
2308+
):
2309+
for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
2310+
with T.block("T_take"):
2311+
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
2312+
T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]
2313+
2314+
_check(Before, Before)
2315+
2316+
22572317
if __name__ == "__main__":
22582318
tvm.testing.main()

0 commit comments

Comments
 (0)