@@ -62,17 +62,88 @@ class ExtractConstantsMutator : public MixedModeMutator {
6262 return func;
6363 }
6464
65- function_to_constants_ .Set (func, Array<Constant >{});
65+ function_to_arguments_ .Set (func, Array<Expr >{});
6666 functions_.push_back (func);
6767 auto new_body = VisitExpr (func->body );
6868 functions_.pop_back ();
69- if (function_to_constants_ [func].size ()) {
69+ if (function_to_arguments_ [func].size ()) {
7070 func = WithFields (func, FreeVars (new_body), new_body, func->ret_type ,
7171 FreeTypeVars (new_body, mod_), func->attrs );
7272 }
7373 return std::move (func);
7474 }
7575
76+ // Creates new arguments from current call's arguments
77+ // Updates constants into the caller arguments: here caller signifies caller that comprises call
78+ // to func
79+ Array<Expr> CreateNewCallArgsFromExtractedConstants (Call call, Function func) {
80+ ICHECK (function_to_arguments_.find (func) != function_to_arguments_.end ());
81+ Array<Expr> function_signature (function_to_arguments_[func]);
82+
83+ // Is func a global_function?
84+ // main() is not registered for extracting constants
85+ bool is_global_function = functions_.empty () ? true : false ;
86+
87+ bool new_constants_added = false ;
88+ // This tracks arguments traversed inside function_signature
89+ uint32_t function_signature_id = 0 ;
90+ // This contains arguments including constants for the caller of this function inside which
91+ // post_call resides.
92+ Array<Expr> new_caller_args;
93+ // New arguments to post_call that includes new variables representing constants extracted from
94+ // the function
95+ Array<Expr> new_call_args;
96+ for (auto & arg : call->args ) {
97+ if (auto * constant = arg.as <ConstantNode>()) {
98+ new_caller_args.push_back (arg);
99+ new_call_args.push_back (Var (gen_var_name (), constant->tensor_type ()));
100+ ++function_signature_id;
101+ new_constants_added = true ;
102+ continue ;
103+ }
104+
105+ // Push all constants from the function_signature until a variable corresponding to the
106+ // current argument is hit
107+ while (function_signature_id < function_signature.size ()) {
108+ auto * constant = function_signature[function_signature_id].as <ConstantNode>();
109+ if (constant == nullptr ) {
110+ break ;
111+ }
112+ new_caller_args.push_back (function_signature[function_signature_id++]);
113+ new_call_args.push_back (Var (gen_var_name (), constant->tensor_type ()));
114+ new_constants_added = true ;
115+ }
116+
117+ new_call_args.push_back (arg);
118+ if (is_global_function || arg.as <VarNode>()) {
119+ new_caller_args.push_back (arg);
120+ }
121+ ++function_signature_id;
122+ }
123+
124+ // Push remaining constants as new arguments
125+ for (uint32_t i = function_signature_id; i < function_signature.size (); ++i) {
126+ auto * constant = function_signature[i].as <ConstantNode>();
127+ ICHECK (constant)
128+ << " Rest of the collected arguments should be constant in the partitioned function." ;
129+ new_caller_args.push_back (GetRef<Constant>(constant));
130+ new_call_args.push_back (Var (gen_var_name (), constant->tensor_type ()));
131+ new_constants_added = true ;
132+ }
133+
134+ // Update the arguments of caller of local function
135+ if (new_constants_added && !is_global_function) {
136+ const Function& last_func = functions_.back ();
137+ Array<Expr> function_constants (function_to_arguments_[last_func]);
138+ function_to_arguments_.Set (last_func,
139+ tvm::runtime::Concat (function_constants, new_caller_args));
140+ } else {
141+ new_call_args = new_caller_args;
142+ }
143+
144+ return new_call_args;
145+ }
146+
76147 Expr Rewrite_ (const CallNode* call, const Expr& post ) final {
77148 Expr final_call = post ;
78149 auto * post_call = post .as <CallNode>();
@@ -81,58 +152,47 @@ class ExtractConstantsMutator : public MixedModeMutator {
81152 // Perform this for non-main Call Nodes only
82153 if (!functions_.empty () && call->op .as <OpNode>()) {
83154 Array<Expr> new_args;
155+ const Function& last_func = functions_.back ();
156+ Array<Expr> function_signature (function_to_arguments_[last_func]);
84157 for (auto & arg : post_call->args ) {
158+ // Push all arguments including constants to maintain correct order of
159+ // variables and constants
85160 auto * const_arg = arg.as <ConstantNode>();
86161 if (const_arg && !const_arg->is_scalar ()) {
87162 Var var_arg = Var (gen_var_name (), const_arg->tensor_type ());
88163 new_args.push_back (var_arg);
89- const Function& last_func = functions_.back ();
90- Array<Constant> fconstants (function_to_constants_[last_func]);
91- fconstants.push_back (GetRef<Constant>(const_arg));
92- function_to_constants_.Set (last_func, fconstants);
164+ function_signature.push_back (arg);
93165 } else {
166+ if (arg.as <VarNode>()) {
167+ function_signature.push_back (arg);
168+ }
94169 new_args.push_back (arg);
95170 }
96171 }
172+ function_to_arguments_.Set (last_func, function_signature);
97173 final_call = Call (call->op , new_args, call->attrs , {});
98174 }
99175
100- // Since the constants are kicked out of partitioned functions
176+ // Since the constants are extracted from partitioned functions
101177 // a new call to global function is needed
102178 if (auto * glob_var_node = post_call->op .as <GlobalVarNode>()) {
103179 auto glob_var = GetRef<GlobalVar>(glob_var_node);
104180 auto glob_func = Downcast<Function>(mod_->Lookup (glob_var));
105181 auto new_glob_func = VisitExpr (glob_func);
106182 if (!new_glob_func.same_as (glob_func)) {
107183 mod_->Update (glob_var, Downcast<Function>(new_glob_func));
108- Array<Expr> new_args = post_call->args ;
109- ICHECK (function_to_constants_.find (glob_func) != function_to_constants_.end ());
110- for (auto constant : function_to_constants_.at (glob_func)) {
111- new_args.push_back (constant);
112- }
184+ auto new_args = CreateNewCallArgsFromExtractedConstants (GetRef<Call>(post_call), glob_func);
113185 final_call = Call (glob_var, new_args);
114186 }
115187 }
116188
117- // Since the constants are kicked out of the local partitioned functions
189+ // Since the constants are extracted from the local partitioned functions
118190 // a new call to local function is needed
119- // Also, pass on the constants to the callee of this function to support nested functions
120191 if (auto * func_node = call->op .as <FunctionNode>()) {
121192 Function func = GetRef<Function>(func_node);
122193 auto new_func = VisitExpr (func);
123- if (!new_func.same_as (func)) {
124- Array<Expr> new_args = post_call->args ;
125- ICHECK (function_to_constants_.find (func) != function_to_constants_.end ());
126- const Function& last_func = functions_.back ();
127- Array<Constant> fconstants (function_to_constants_[last_func]);
128- for (auto constant : function_to_constants_.at (func)) {
129- fconstants.push_back (constant);
130- Var var_arg = Var (gen_var_name (), constant->tensor_type ());
131- new_args.push_back (var_arg);
132- }
133- function_to_constants_.Set (last_func, fconstants);
134- final_call = Call (new_func, new_args);
135- }
194+ Array<Expr> new_args = CreateNewCallArgsFromExtractedConstants (GetRef<Call>(post_call), func);
195+ final_call = Call (new_func, new_args);
136196 }
137197
138198 return final_call;
@@ -141,15 +201,16 @@ class ExtractConstantsMutator : public MixedModeMutator {
141201 private:
142202 /* \brief Updated module where all calls have replaced constants with new variables */
143203 IRModule mod_;
144- /* \brief Maintains mapping of original function to the replaced constants */
145- Map<Function, Array<Constant>> function_to_constants_;
146- /* \brief Stack of functions to determine scope while filling up function_to_constants_ */
204+ /* \brief Maintains mapping of original function to the replaced constants along with other
205+ * arguments to retain the order in which variables are used within the function */
206+ Map<Function, Array<Expr>> function_to_arguments_;
207+ /* \brief Stack of functions to determine scope while filling up function_to_arguments_ */
147208 Array<Function> functions_;
148209 /* \brief Keeps track of variables being created */
149210 int var_count_ = 0 ;
150211};
151212
152- /* ! * \brief Kicks out all constants out of the partitioned function into main() */
213+ /* ! * \brief Extracts all constants out of the partitioned function into main() */
153214IRModule ExtractConstants (const IRModule& mod) {
154215 String func_name;
155216 Function func;
@@ -169,7 +230,7 @@ transform::Pass ExtractConstantsFromPartitionedFunction() {
169230 runtime::TypedPackedFunc<IRModule (IRModule, transform::PassContext)> pass_func =
170231 [=](IRModule m, transform::PassContext pc) { return ExtractConstants (m); };
171232 return tvm::transform::CreateModulePass (pass_func, 0 , " ExtractConstantsFromPartitionedFunction" ,
172- {});
233+ {" InferType " });
173234}
174235
175236TVM_REGISTER_GLOBAL (" relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction" )
0 commit comments