@@ -878,54 +878,31 @@ std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var&
878878 */
879879class TIRFuseMutator : public ExprMutator {
880880 public:
881- static IRModule Transform (const IRModule& mod) {
882- Map<GlobalVar, BaseFunc> funcs_to_keep;
883- for (const auto & [gv, func] : mod->functions ) {
884- // 1. If a TIR function has global symbol, we keep the function.
885- // 2. Always keep ExternFunc.
886- if (const auto * prim_func = func.as <tir::PrimFuncNode>()) {
887- if (prim_func->GetAttr <String>(" global_symbol" ).defined ()) {
888- funcs_to_keep.Set (gv, func);
889- }
890- } else if (func->IsInstance <ExternFuncNode>()) {
891- funcs_to_keep.Set (gv, func);
892- }
893- }
894- // Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder.
895- TIRFuseMutator mutator (mod);
896- // Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
897- for (const auto & [gv, func] : mod->functions ) {
898- // Only fuse primitive relax functions
899- if (func->IsInstance <relax::FunctionNode>() && func->HasNonzeroAttr (attr::kPrimitive )) {
900- tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR (mod, gv);
901- mutator.fused_tir_funcs_ .Set (gv, fused_tir);
902- }
903- }
904-
905- // Step 2. Update all non-primitive relax functions and add it, with the dependent function,
906- // into the new IRModule
907- for (const auto & [gv, func] : mod->functions ) {
881+ static IRModule Transform (const IRModule& orig_mod) {
882+ // Track the original IRModule separate from the BlockBuilder's
883+ // partially transformed module. Preserving access to the
884+ // pre-fused Primitive function is necessary for context-dependent
885+ // fusion (e.g. producing two fused versions of the same primitive
886+ // function, depending on whether it uses `R.call_tir` or
887+ // `R.call_tir_inplace`).
888+ TIRFuseMutator mutator (orig_mod);
889+
890+ // Any non-primitive relax functions should be inspected for calls
891+ // into primitive relax functions. The primitive relax functions
892+ // themselves will be updated as part of `mutator.VisitExpr`,
893+ // while the non-primitive relax functions are updated here.
894+ for (const auto & [gv, func] : orig_mod->functions ) {
908895 if (func->IsInstance <relax::FunctionNode>() && !func->HasNonzeroAttr (attr::kPrimitive )) {
909896 relax::Function update_func = Downcast<Function>(mutator.VisitExpr (func));
910- mutator.builder_ ->AddFunction (update_func, gv-> name_hint );
897+ mutator.builder_ ->UpdateFunction (gv, update_func );
911898 }
912899 }
913900
914- // Step 3. Add all functions that need to be kept.
915- auto modified_mod = mutator.builder_ ->GetContextIRModule ();
916- for (const auto & [gv, func] : funcs_to_keep) {
917- if (!modified_mod->ContainGlobalVar (gv->name_hint )) {
918- modified_mod->Add (gv, func);
919- }
920- }
921-
922- // Step 4. Copy over module attributes and return.
923- if (mod->attrs .defined ()) modified_mod = WithAttrs (modified_mod, mod->attrs ->dict );
924- return modified_mod;
901+ return mutator.builder_ ->Finalize ();
925902 }
926903
927904 private:
928- explicit TIRFuseMutator (const IRModule& mod) : mod_ (mod) {}
905+ explicit TIRFuseMutator (const IRModule& mod) : ExprMutator(mod), orig_mod_ (mod) {}
929906
930907 using ExprMutator::VisitExpr_;
931908
@@ -950,63 +927,67 @@ class TIRFuseMutator : public ExprMutator {
950927
951928 if (call->op ->IsInstance <GlobalVarNode>()) {
952929 // Case 1. It is a relax cross function call
953- GlobalVar old_gv = Downcast<GlobalVar>(call->op );
954- auto relax_func = Downcast<Function>(mod_->Lookup (old_gv));
955- auto it = fused_tir_funcs_.find (old_gv);
956- if (it != fused_tir_funcs_.end ()) {
957- const tir::PrimFunc& fused_tir = (*it).second ;
958- // Case 1.1. It calls a primitive relax function, update the call into a call_tir
959- GlobalVar fused_tir_gv = this ->builder_ ->AddFunction (fused_tir, old_gv->name_hint );
960- // Step a. Flatten all args since call_tir does not support Tuple value.
961- Array<Expr> arg_list;
962- Array<PrimExpr> tir_vars;
963- for (size_t i = 0 ; i < call->args .size (); ++i) {
964- auto arg = call->args [i];
965- auto sinfo = GetStructInfo (arg);
966-
967- ICHECK (!relax_func->params [i]->struct_info_ ->IsInstance <TupleStructInfoNode>() &&
968- !sinfo.as <TupleStructInfoNode>())
969- << " InternalError: "
970- << " All tuple parameters should be expanded before this point in FuseTIR. "
971- << " However, argument " << arg << " with struct info " << arg->struct_info_
972- << " is passed as argument " << i << " to Primitive Relax function " << old_gv
973- << " , which expects parameter " << relax_func->params [i] << " to have struct info "
974- << relax_func->params [i]->struct_info_ ;
975-
976- if (const auto * shape = sinfo.as <ShapeStructInfoNode>()) {
977- CHECK (shape->values .defined ())
978- << " FuseTIR requires all shape input has struct_info value." ;
979- for (const PrimExpr& prim_value : shape->values .value ()) {
980- CHECK (prim_value->IsInstance <tir::VarNode>())
981- << " All shape inputs are expected to be single tir var." ;
982- tir_vars.push_back (prim_value);
983- }
984- } else if (const auto * prim_value = sinfo.as <PrimStructInfoNode>()) {
985- CHECK (prim_value->value .defined ())
986- << " FuseTIR requires all R.Prim arguments to have a known value." ;
987- PrimExpr expr = prim_value->value .value ();
988- CHECK (expr->IsInstance <tir::VarNode>())
989- << " FuseTIR currently requires all R.Prim arguments to provide a single tir::Var." ;
990- tir_vars.push_back (expr);
991-
992- } else {
993- arg_list.push_back (arg);
930+ GlobalVar callee_gvar = Downcast<GlobalVar>(call->op );
931+ auto relax_func = Downcast<Function>(orig_mod_->Lookup (callee_gvar));
932+
933+ if (!relax_func->HasNonzeroAttr (attr::kPrimitive )) {
934+ // The callee is not a primitive function, no need to fuse.
935+ return call;
936+ }
937+
938+ tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR (orig_mod_, callee_gvar);
939+
940+ // Case 1.1. It calls a primitive relax function, update the call into a call_tir
941+ // Reuse callee_gvar for the fused PrimFunc instead of adding a new GlobalVar.
942+ this ->builder_ ->UpdateFunction (callee_gvar, fused_tir);
943+
944+ // Step a. Flatten all args since call_tir does not support Tuple value.
945+ Array<Expr> arg_list;
946+ Array<PrimExpr> tir_vars;
947+ for (size_t i = 0 ; i < call->args .size (); ++i) {
948+ auto arg = call->args [i];
949+ auto sinfo = GetStructInfo (arg);
950+
951+ ICHECK (!relax_func->params [i]->struct_info_ ->IsInstance <TupleStructInfoNode>() &&
952+ !sinfo.as <TupleStructInfoNode>())
953+ << " InternalError: "
954+ << " All tuple parameters should be expanded before this point in FuseTIR. "
955+ << " However, argument " << arg << " with struct info " << arg->struct_info_
956+ << " is passed as argument " << i << " to Primitive Relax function " << callee_gvar
957+ << " , which expects parameter " << relax_func->params [i] << " to have struct info "
958+ << relax_func->params [i]->struct_info_ ;
959+
960+ if (const auto * shape = sinfo.as <ShapeStructInfoNode>()) {
961+ CHECK (shape->values .defined ())
962+ << " FuseTIR requires all shape input has struct_info value." ;
963+ for (const PrimExpr& prim_value : shape->values .value ()) {
964+ CHECK (prim_value->IsInstance <tir::VarNode>())
965+ << " All shape inputs are expected to be single tir var." ;
966+ tir_vars.push_back (prim_value);
994967 }
968+ } else if (const auto * prim_value = sinfo.as <PrimStructInfoNode>()) {
969+ CHECK (prim_value->value .defined ())
970+ << " FuseTIR requires all R.Prim arguments to have a known value." ;
971+ PrimExpr expr = prim_value->value .value ();
972+ CHECK (expr->IsInstance <tir::VarNode>())
973+ << " FuseTIR currently requires all R.Prim arguments to provide a single tir::Var." ;
974+ tir_vars.push_back (expr);
975+
976+ } else {
977+ arg_list.push_back (arg);
995978 }
996- // Step b. Create call_tir
997- Array<Expr> call_args = {fused_tir_gv, Tuple (arg_list)};
998- if (!tir_vars.empty ()) {
999- call_args.push_back (ShapeExpr (tir_vars));
1000- }
1001- return Call (call_tir_op_, call_args, call->attrs , {GetStructInfo (call)});
1002- } else {
1003- // Case 1.2. The callee function is not primitive, nothing to do.
1004- return call;
1005979 }
980+ // Step b. Create call_tir
981+ Array<Expr> call_args = {callee_gvar, Tuple (arg_list)};
982+ if (!tir_vars.empty ()) {
983+ call_args.push_back (ShapeExpr (tir_vars));
984+ }
985+ return Call (call_tir_op_, call_args, call->attrs , {GetStructInfo (call)});
986+
1006987 } else if (call->op == call_tir_op_) {
1007988 // Case 2. It is a call_tir, re-emit the PrimFunc.
1008989 if (const auto * gv = call->args [0 ].as <GlobalVarNode>()) {
1009- tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_ ->Lookup (GetRef<GlobalVar>(gv)));
990+ tir::PrimFunc func = Downcast<tir::PrimFunc>(orig_mod_ ->Lookup (GetRef<GlobalVar>(gv)));
1010991 GlobalVar new_gv = this ->builder_ ->AddFunction (func, gv->name_hint );
1011992 Array<Expr> new_args = call->args ;
1012993 new_args.Set (0 , new_gv);
@@ -1020,9 +1001,7 @@ class TIRFuseMutator : public ExprMutator {
10201001
10211002 private:
10221003 /* ! \brief The IRModule */
1023- const IRModule& mod_;
1024- /* ! \brief The map from global var of primitive relax function to generated prim func. */
1025- Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
1004+ const IRModule& orig_mod_;
10261005};
10271006
10281007IRModule FuseTIR (IRModule mod) {
@@ -1044,6 +1023,15 @@ Pass FuseTIR() {
10441023 ExpandTupleArguments (),
10451024 RemoveUnusedParameters (),
10461025 inner_pass,
1026+ // At each call site, TIRFuseMutator replaces calls to
1027+ // attr::kPrimitive relax functions with calls to the fused
1028+ // TIR functions. However, until we reach the end of the
1029+ // module, we don't know whether the unfused TIR functions
1030+ // were called from contexts other than the attr::kPrimitive
1031+ // relax function, and cannot remove them entirely.
1032+ // Post-processing with `DeadCodeElimination` ensures that
1033+ // we do not keep any unreachable pre-fused TIR PrimFuncs.
1034+ DeadCodeElimination ({}),
10471035 },
10481036 " FuseTIR" );
10491037}
0 commit comments