@@ -62,17 +62,87 @@ class ExtractConstantsMutator : public MixedModeMutator {
6262 return func;
6363 }
6464
65- function_to_constants_ .Set (func, Array<Constant >{});
65+ function_to_arguments_ .Set (func, Array<Expr >{});
6666 functions_.push_back (func);
6767 auto new_body = VisitExpr (func->body );
6868 functions_.pop_back ();
69- if (function_to_constants_ [func].size ()) {
69+ if (function_to_arguments_ [func].size ()) {
7070 func = Function (FreeVars (new_body), new_body, func->ret_type , FreeTypeVars (new_body, mod_),
7171 func->attrs );
7272 }
7373 return std::move (func);
7474 }
7575
76+ // Creates new arguments from current call's arguments
77+ // Updates constants into the caller arguments: here caller signifies caller that comprises call
78+ // to func
79+ Array<Expr> CreateNewCallArgsFromKickedoutConstants (Call call, Function func) {
80+ ICHECK (function_to_arguments_.find (func) != function_to_arguments_.end ());
81+ Array<Expr> fSignature (function_to_arguments_[func]);
82+
83+ // Is func a global_function?
84+ // main() is not registered for kicking out constants
85+ bool is_global_function = functions_.empty () ? true : false ;
86+
87+ bool new_constants_added = false ;
88+ // This tracks arguments traversed inside fSignature
89+ uint32_t fsignature_id = 0 ;
90+ // This contains arguments including constants for the caller of this function inside which
91+ // post_call resides.
92+ Array<Expr> new_caller_args;
93+ // New arguments to post_call that includes new variables representing constants kicked out of
94+ // the function
95+ Array<Expr> new_call_args;
96+ for (auto & arg : call->args ) {
97+ if (auto * constant = arg.as <ConstantNode>()) {
98+ new_caller_args.push_back (arg);
99+ new_call_args.push_back (Var (gen_var_name (), constant->tensor_type ()));
100+ ++fsignature_id;
101+ new_constants_added = true ;
102+ continue ;
103+ }
104+
105+ // Push all constants from the fSignature until a variable corresponding to current argument
106+ // is hit
107+ while (fsignature_id < fSignature .size ()) {
108+ auto * constant = fSignature [fsignature_id].as <ConstantNode>();
109+ if (constant == nullptr ) {
110+ break ;
111+ }
112+ new_caller_args.push_back (fSignature [fsignature_id++]);
113+ new_call_args.push_back (Var (gen_var_name (), constant->tensor_type ()));
114+ new_constants_added = true ;
115+ }
116+
117+ new_call_args.push_back (arg);
118+ if (is_global_function || arg.as <VarNode>()) {
119+ new_caller_args.push_back (arg);
120+ }
121+ ++fsignature_id;
122+ }
123+
124+ // Push remaining constants as new arguments
125+ for (uint32_t i = fsignature_id; i < fSignature .size (); ++i) {
126+ auto * constant = fSignature [i].as <ConstantNode>();
127+ ICHECK (constant)
128+ << " Rest of the collected arguments should be constant in the partitioned function." ;
129+ new_caller_args.push_back (GetRef<Constant>(constant));
130+ new_call_args.push_back (Var (gen_var_name (), constant->tensor_type ()));
131+ new_constants_added = true ;
132+ }
133+
134+ // Update the arguments of caller of local function
135+ if (new_constants_added && !is_global_function) {
136+ const Function& last_func = functions_.back ();
137+ Array<Expr> fconstants (function_to_arguments_[last_func]);
138+ function_to_arguments_.Set (last_func, tvm::runtime::Concat (fconstants, new_caller_args));
139+ } else {
140+ new_call_args = new_caller_args;
141+ }
142+
143+ return new_call_args;
144+ }
145+
76146 Expr Rewrite_ (const CallNode* call, const Expr& post ) final {
77147 Expr final_call = post ;
78148 auto * post_call = post .as <CallNode>();
@@ -81,19 +151,24 @@ class ExtractConstantsMutator : public MixedModeMutator {
81151 // Perform this for non-main Call Nodes only
82152 if (!functions_.empty () && call->op .as <OpNode>()) {
83153 Array<Expr> new_args;
154+ const Function& last_func = functions_.back ();
155+ Array<Expr> fSignature (function_to_arguments_[last_func]);
84156 for (auto & arg : post_call->args ) {
157+ // Push all arguments including constants to maintain correct order of
158+ // variables and constants
85159 auto * const_arg = arg.as <ConstantNode>();
86160 if (const_arg && !const_arg->is_scalar ()) {
87161 Var var_arg = Var (gen_var_name (), const_arg->tensor_type ());
88162 new_args.push_back (var_arg);
89- const Function& last_func = functions_.back ();
90- Array<Constant> fconstants (function_to_constants_[last_func]);
91- fconstants.push_back (GetRef<Constant>(const_arg));
92- function_to_constants_.Set (last_func, fconstants);
163+ fSignature .push_back (arg);
93164 } else {
165+ if (arg.as <VarNode>()) {
166+ fSignature .push_back (arg);
167+ }
94168 new_args.push_back (arg);
95169 }
96170 }
171+ function_to_arguments_.Set (last_func, fSignature );
97172 final_call = Call (call->op , new_args, call->attrs , {});
98173 }
99174
@@ -105,34 +180,18 @@ class ExtractConstantsMutator : public MixedModeMutator {
105180 auto new_glob_func = VisitExpr (glob_func);
106181 if (!new_glob_func.same_as (glob_func)) {
107182 mod_->Update (glob_var, Downcast<Function>(new_glob_func));
108- Array<Expr> new_args = post_call->args ;
109- ICHECK (function_to_constants_.find (glob_func) != function_to_constants_.end ());
110- for (auto constant : function_to_constants_.at (glob_func)) {
111- new_args.push_back (constant);
112- }
183+ auto new_args = CreateNewCallArgsFromKickedoutConstants (GetRef<Call>(post_call), glob_func);
113184 final_call = Call (glob_var, new_args);
114185 }
115186 }
116187
117188 // Since the constants are kicked out of the local partitioned functions
118189 // a new call to local function is needed
119- // Also, pass on the constants to the callee of this function to support nested functions
120190 if (auto * func_node = call->op .as <FunctionNode>()) {
121191 Function func = GetRef<Function>(func_node);
122192 auto new_func = VisitExpr (func);
123- if (!new_func.same_as (func)) {
124- Array<Expr> new_args = post_call->args ;
125- ICHECK (function_to_constants_.find (func) != function_to_constants_.end ());
126- const Function& last_func = functions_.back ();
127- Array<Constant> fconstants (function_to_constants_[last_func]);
128- for (auto constant : function_to_constants_.at (func)) {
129- fconstants.push_back (constant);
130- Var var_arg = Var (gen_var_name (), constant->tensor_type ());
131- new_args.push_back (var_arg);
132- }
133- function_to_constants_.Set (last_func, fconstants);
134- final_call = Call (new_func, new_args);
135- }
193+ Array<Expr> new_args = CreateNewCallArgsFromKickedoutConstants (GetRef<Call>(post_call), func);
194+ final_call = Call (new_func, new_args);
136195 }
137196
138197 return final_call;
@@ -141,9 +200,10 @@ class ExtractConstantsMutator : public MixedModeMutator {
141200 private:
142201 /* \brief Updated module where all calls have replaced constants with new variables */
143202 IRModule mod_;
144- /* \brief Maintains mapping of original function to the replaced constants */
145- Map<Function, Array<Constant>> function_to_constants_;
146- /* \brief Stack of functions to determine scope while filling up function_to_constants_ */
203+ /* \brief Maintains mapping of original function to the replaced constants along with other
204+ * arguments to retain the order in which variables are used within the function */
205+ Map<Function, Array<Expr>> function_to_arguments_;
206+ /* \brief Stack of functions to determine scope while filling up function_to_arguments_ */
147207 Array<Function> functions_;
148208 /* \brief Keeps track of variables being created */
149209 int var_count_ = 0 ;
0 commit comments