@@ -961,57 +961,73 @@ std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var&
961961 */
962962class TIRFuseMutator : public ExprMutator {
963963 public:
964- static IRModule Transform (const IRModule& mod) {
965- Map<GlobalVar, BaseFunc> funcs_to_keep;
966- for ( const auto & [gv, func] : mod-> functions ) {
967- // 1. If a TIR function has global symbol, we keep the function.
968- // 2. Always keep ExternFunc.
969- if (const auto * prim_func = func. as <tir::PrimFuncNode>( )) {
970- if (prim_func-> GetAttr <String>( " global_symbol " ). defined ()) {
971- funcs_to_keep .Set (gv , func);
964+ static IRModule Transform (IRModule mod) {
965+ // Collect all primitive relax functions
966+ Map<GlobalVar, Function> primitive_relax;
967+ for ( const auto & [gvar, base_func] : mod-> functions ) {
968+ // Only fuse primitive relax functions
969+ if (base_func-> HasNonzeroAttr (attr:: kPrimitive )) {
970+ if (auto func = base_func. as <relax::Function> ()) {
971+ primitive_relax .Set (gvar , func. value () );
972972 }
973- } else if (func->IsInstance <ExternFuncNode>()) {
974- funcs_to_keep.Set (gv, func);
975973 }
976974 }
975+
976+ if (primitive_relax.empty ()) {
977+ return mod;
978+ }
979+
980+ mod.CopyOnWrite ();
981+
982+ IRModule updates;
983+ std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements;
984+
977985 // Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder.
978- TIRFuseMutator mutator (mod);
986+
979987 // Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
980- for (const auto & [gv, func] : mod->functions ) {
981- // Only fuse primitive relax functions
982- if (func->IsInstance <relax::FunctionNode>() && func->HasNonzeroAttr (attr::kPrimitive )) {
983- const auto & [prim_func, indices] = FusedTIRConstructor::GetFusedTIR (mod, gv);
984- mutator.fused_tir_funcs_ .Set (gv, prim_func);
985- if (!indices.empty ()) {
986- mutator.inplace_indices_ .Set (gv, indices);
987- }
988- }
988+ for (const auto & [old_gvar, func] : primitive_relax) {
989+ const auto & [prim_func, indices] = FusedTIRConstructor::GetFusedTIR (mod, old_gvar);
990+
991+ GlobalVar new_gvar (old_gvar->name_hint );
992+ UpdateStructInfo (new_gvar,
993+ FuncStructInfo::OpaqueFunc (StructInfoFromType (prim_func->ret_type )));
994+
995+ mod->Remove (old_gvar);
996+ updates->Add (new_gvar, prim_func);
997+ replacements[old_gvar] = Replacement{new_gvar, func, indices};
989998 }
990999
1000+ TIRFuseMutator mutator (replacements);
1001+
9911002 // Step 2. Update all non-primitive relax functions and add it, with the dependent function,
9921003 // into the new IRModule
1004+
9931005 for (const auto & [gv, func] : mod->functions ) {
994- if (func->IsInstance <relax::FunctionNode>() && !func->HasNonzeroAttr (attr::kPrimitive )) {
1006+ if (func->IsInstance <relax::FunctionNode>()) {
1007+ ICHECK (!func->HasNonzeroAttr (attr::kPrimitive ))
1008+ << " Module should not contain any primitive relax functions at this point" ;
9951009 relax::Function update_func = Downcast<Function>(mutator.VisitExpr (func));
996- mutator.builder_ ->AddFunction (update_func, gv->name_hint );
997- }
998- }
999-
1000- // Step 3. Add all functions that need to be kept.
1001- auto modified_mod = mutator.builder_ ->GetContextIRModule ();
1002- for (const auto & [gv, func] : funcs_to_keep) {
1003- if (!modified_mod->ContainGlobalVar (gv->name_hint )) {
1004- modified_mod->Add (gv, func);
1010+ if (!update_func.same_as (func)) {
1011+ updates->Add (gv, update_func);
1012+ }
10051013 }
10061014 }
10071015
1008- // Step 4. Copy over module attributes and return.
1009- if ( mod->attrs . defined ()) modified_mod = WithAttrs (modified_mod, mod-> attrs -> dict );
1010- return modified_mod ;
1016+ // Step 4. Copy over updated functions and return.
1017+ mod->Update (updates );
1018+ return mod ;
10111019 }
10121020
10131021 private:
1014- explicit TIRFuseMutator (const IRModule& mod) : mod_(mod) {}
1022+ struct Replacement {
1023+ GlobalVar fused_tir_gvar;
1024+ Function original_function;
1025+ Array<Integer> inplace_indices;
1026+ };
1027+
1028+ explicit TIRFuseMutator (
1029+ std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements)
1030+ : replacements_(replacements) {}
10151031
10161032 using ExprMutator::VisitExpr_;
10171033
@@ -1035,92 +1051,86 @@ class TIRFuseMutator : public ExprMutator {
10351051
10361052 Call call = Downcast<Call>(builder_->Normalize (ExprMutator::VisitExpr_ (op)));
10371053
1038- if (call->op ->IsInstance <GlobalVarNode>()) {
1039- // Case 1. It is a relax cross function call
1040- GlobalVar old_gv = Downcast<GlobalVar>(call->op );
1041- auto relax_func = Downcast<Function>(mod_->Lookup (old_gv));
1042- auto it = fused_tir_funcs_.find (old_gv);
1043- if (it != fused_tir_funcs_.end ()) {
1044- const tir::PrimFunc& fused_tir = (*it).second ;
1045- // Case 1.1. It calls a primitive relax function, update the call into a call_tir
1046- GlobalVar fused_tir_gv = this ->builder_ ->AddFunction (fused_tir, old_gv->name_hint );
1047- // Step a. Flatten all args since call_tir does not support Tuple value.
1048- Array<Expr> arg_list;
1049- Array<PrimExpr> tir_vars;
1050- for (size_t i = 0 ; i < call->args .size (); ++i) {
1051- auto arg = call->args [i];
1052- auto sinfo = GetStructInfo (arg);
1053-
1054- ICHECK (!relax_func->params [i]->struct_info_ ->IsInstance <TupleStructInfoNode>() &&
1055- !sinfo.as <TupleStructInfoNode>())
1056- << " InternalError: "
1057- << " All tuple parameters should be expanded before this point in FuseTIR. "
1058- << " However, argument " << arg << " with struct info " << arg->struct_info_
1059- << " is passed as argument " << i << " to Primitive Relax function " << old_gv
1060- << " , which expects parameter " << relax_func->params [i] << " to have struct info "
1061- << relax_func->params [i]->struct_info_ ;
1062-
1063- if (const auto * shape = sinfo.as <ShapeStructInfoNode>()) {
1064- CHECK (shape->values .defined ())
1065- << " FuseTIR requires all shape input has struct_info value." ;
1066- for (const PrimExpr& prim_value : shape->values .value ()) {
1067- CHECK (prim_value->IsInstance <tir::VarNode>())
1068- << " All shape inputs are expected to be single tir var." ;
1069- tir_vars.push_back (prim_value);
1070- }
1071- } else if (const auto * prim_value = sinfo.as <PrimStructInfoNode>()) {
1072- CHECK (prim_value->value .defined ())
1073- << " FuseTIR requires all R.Prim arguments to have a known value." ;
1074- PrimExpr expr = prim_value->value .value ();
1075- CHECK (expr->IsInstance <tir::VarNode>()) << " FuseTIR currently requires all R.Prim "
1076- " arguments to provide a single tir::Var." ;
1077- tir_vars.push_back (expr);
1078-
1079- } else {
1080- arg_list.push_back (arg);
1081- }
1082- }
1083- // Step b. Create call_tir or call_tir_inplace
1084- Array<Expr> call_args = {fused_tir_gv, Tuple (arg_list)};
1085- if (!tir_vars.empty ()) {
1086- call_args.push_back (ShapeExpr (tir_vars));
1087- }
1088- Op call_op = call_tir_op_;
1089- Attrs call_attrs = call->attrs ;
1090- if (auto it = inplace_indices_.find (old_gv); it != inplace_indices_.end ()) {
1091- call_op = call_tir_inplace_op_;
1092- auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
1093- inplace_attrs->inplace_indices = (*it).second ;
1094- call_attrs = Attrs (inplace_attrs);
1054+ auto opt_gvar = call->op .as <GlobalVar>();
1055+ if (!opt_gvar) {
1056+ // Case 1. The Call isn't a relax-to-relax function call, no need to update.
1057+ return call;
1058+ }
1059+ GlobalVar old_gvar = opt_gvar.value ();
1060+
1061+ auto it = replacements_.find (old_gvar);
1062+ if (it == replacements_.end ()) {
1063+ // Case 2. The callee function is not a primitive relax
1064+ // function, no need to update.
1065+ return call;
1066+ }
1067+ const Replacement& replacement = it->second ;
1068+ const GlobalVar& fused_tir_gv = replacement.fused_tir_gvar ;
1069+ const Function& relax_func = replacement.original_function ;
1070+
1071+ // Case 3. It calls a primitive relax function, update the call
1072+ // into a call_tir or call_tir_inplace.
1073+
1074+ // Step a. Collect all relax/symbolic arguments. Tuple arguments
1075+ // are not supported by PrimFunc, so this step verifies that
1076+ // ExpandTupleArguments has already removed them.
1077+ Array<Expr> arg_list;
1078+ Array<PrimExpr> tir_vars;
1079+ for (size_t i = 0 ; i < call->args .size (); ++i) {
1080+ auto arg = call->args [i];
1081+ auto sinfo = GetStructInfo (arg);
1082+
1083+ ICHECK (!relax_func->params [i]->struct_info_ ->IsInstance <TupleStructInfoNode>() &&
1084+ !sinfo.as <TupleStructInfoNode>())
1085+ << " InternalError: "
1086+ << " All tuple parameters should be expanded before this point in FuseTIR. "
1087+ << " However, argument " << arg << " with struct info " << arg->struct_info_
1088+ << " is passed as argument " << i << " to Primitive Relax function " << old_gvar
1089+ << " , which expects parameter " << relax_func->params [i] << " to have struct info "
1090+ << relax_func->params [i]->struct_info_ ;
1091+
1092+ if (const auto * shape = sinfo.as <ShapeStructInfoNode>()) {
1093+ CHECK (shape->values .defined ()) << " FuseTIR requires all shape input has struct_info value." ;
1094+ for (const PrimExpr& prim_value : shape->values .value ()) {
1095+ CHECK (prim_value->IsInstance <tir::VarNode>())
1096+ << " All shape inputs are expected to be single tir var." ;
1097+ tir_vars.push_back (prim_value);
10951098 }
1096- return Call (call_op, call_args, call_attrs, {GetStructInfo (call)});
1099+ } else if (const auto * prim_value = sinfo.as <PrimStructInfoNode>()) {
1100+ CHECK (prim_value->value .defined ())
1101+ << " FuseTIR requires all R.Prim arguments to have a known value." ;
1102+ PrimExpr expr = prim_value->value .value ();
1103+ CHECK (expr->IsInstance <tir::VarNode>()) << " FuseTIR currently requires all R.Prim "
1104+ " arguments to provide a single tir::Var." ;
1105+ tir_vars.push_back (expr);
1106+
10971107 } else {
1098- // Case 1.2. The callee function is not primitive, nothing to do.
1099- return call;
1100- }
1101- } else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) {
1102- // Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc.
1103- if (const auto * gv = call->args [0 ].as <GlobalVarNode>()) {
1104- tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup (GetRef<GlobalVar>(gv)));
1105- GlobalVar new_gv = this ->builder_ ->AddFunction (func, gv->name_hint );
1106- Array<Expr> new_args = call->args ;
1107- new_args.Set (0 , new_gv);
1108- return Call (call->op , new_args, call->attrs , call->sinfo_args , call->span );
1108+ arg_list.push_back (arg);
11091109 }
11101110 }
11111111
1112- // Case 3. CallNode in other types. Leave it as it is.
1113- return call;
1112+ // Step b. Create call_tir or call_tir_inplace
1113+ Array<Expr> call_args = {fused_tir_gv, Tuple (arg_list)};
1114+ if (!tir_vars.empty ()) {
1115+ call_args.push_back (ShapeExpr (tir_vars));
1116+ }
1117+ Op call_op = call_tir_op_;
1118+ Attrs call_attrs = call->attrs ;
1119+ if (replacement.inplace_indices .size ()) {
1120+ call_op = call_tir_inplace_op_;
1121+ auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
1122+ inplace_attrs->inplace_indices = replacement.inplace_indices ;
1123+ call_attrs = Attrs (inplace_attrs);
1124+ }
1125+ return Call (call_op, call_args, call_attrs, {GetStructInfo (call)});
11141126 }
11151127
11161128 private:
1117- /* ! \brief The IRModule */
1118- const IRModule& mod_;
1119- /* ! \brief The map from global var of primitive relax function to generated prim func. */
1120- Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
1121- /* ! \brief The map from global var of primitive relax function to in-place indices
1122- * (if there are any). */
1123- Map<GlobalVar, Array<Integer>> inplace_indices_;
1129+ /* ! \brief The map from global var to how it should be replaced
1130+ *
1131+ * Has one entry for each primitive relax function in the IRModule.
1132+ */
1133+ std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements_;
11241134};
11251135
11261136IRModule FuseTIR (IRModule mod) {
@@ -1142,6 +1152,7 @@ Pass FuseTIR() {
11421152 ExpandTupleArguments (),
11431153 RemoveUnusedParameters (),
11441154 inner_pass,
1155+ DeadCodeElimination (),
11451156 },
11461157 " FuseTIR" );
11471158}
0 commit comments