Skip to content

Commit cb3d7e2

Browse files
authored
[CMSIS-NN] Convert scalar constants to tensor constants (#10100)
1 parent 565e6b4 commit cb3d7e2

File tree

6 files changed

+671
-94
lines changed

6 files changed

+671
-94
lines changed

python/tvm/relay/op/contrib/cmsisnn.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def partition_for_cmsisnn(mod, params=None, **opts):
5757
transform.AnnotateTarget("cmsis-nn"),
5858
transform.PartitionGraph(),
5959
GenerateCMSISNNConstants(),
60+
ScalarToTensorConstants(),
6061
ExtractConstantsFromPartitionedFunction(),
6162
transform.InferType(),
6263
]
@@ -223,11 +224,23 @@ def binary_op_pattern(op):
223224
is_constant(),
224225
)
225226

226-
def check_qnn_binary_op(extract):
227+
def check_qnn_binary_op(pattern):
227228
"""Check if multiply is supported by CMSIS-NN."""
229+
arg0 = pattern.args[0]
230+
arg1 = pattern.args[1]
231+
both_args_scalar = False
232+
if (
233+
isinstance(arg0, tvm.relay.expr.Constant)
234+
and len(arg0.checked_type.shape) == 0
235+
and isinstance(arg1, tvm.relay.expr.Constant)
236+
and len(arg1.checked_type.shape) == 0
237+
):
238+
both_args_scalar = True
239+
228240
return (
229-
extract.args[0].checked_type.dtype == "int8"
230-
and extract.args[1].checked_type.dtype == "int8"
241+
arg0.checked_type.dtype == "int8"
242+
and arg1.checked_type.dtype == "int8"
243+
and not both_args_scalar
231244
)
232245

233246
return [

src/relay/backend/contrib/cmsisnn/extract_constants.cc

Lines changed: 93 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,88 @@ class ExtractConstantsMutator : public MixedModeMutator {
6262
return func;
6363
}
6464

65-
function_to_constants_.Set(func, Array<Constant>{});
65+
function_to_arguments_.Set(func, Array<Expr>{});
6666
functions_.push_back(func);
6767
auto new_body = VisitExpr(func->body);
6868
functions_.pop_back();
69-
if (function_to_constants_[func].size()) {
69+
if (function_to_arguments_[func].size()) {
7070
func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
7171
FreeTypeVars(new_body, mod_), func->attrs);
7272
}
7373
return std::move(func);
7474
}
7575

76+
// Creates new arguments from current call's arguments
77+
// Updates constants into the caller arguments: here caller signifies caller that comprises call
78+
// to func
79+
Array<Expr> CreateNewCallArgsFromExtractedConstants(Call call, Function func) {
80+
ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
81+
Array<Expr> function_signature(function_to_arguments_[func]);
82+
83+
// Is func a global_function?
84+
// main() is not registered for extracting constants
85+
bool is_global_function = functions_.empty() ? true : false;
86+
87+
bool new_constants_added = false;
88+
// This tracks arguments traversed inside function_signature
89+
uint32_t function_signature_id = 0;
90+
// This contains arguments including constants for the caller of this function inside which
91+
// post_call resides.
92+
Array<Expr> new_caller_args;
93+
// New arguments to post_call that includes new variables representing constants extracted from
94+
// the function
95+
Array<Expr> new_call_args;
96+
for (auto& arg : call->args) {
97+
if (auto* constant = arg.as<ConstantNode>()) {
98+
new_caller_args.push_back(arg);
99+
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
100+
++function_signature_id;
101+
new_constants_added = true;
102+
continue;
103+
}
104+
105+
// Push all constants from the function_signature until a variable corresponding to the
106+
// current argument is hit
107+
while (function_signature_id < function_signature.size()) {
108+
auto* constant = function_signature[function_signature_id].as<ConstantNode>();
109+
if (constant == nullptr) {
110+
break;
111+
}
112+
new_caller_args.push_back(function_signature[function_signature_id++]);
113+
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
114+
new_constants_added = true;
115+
}
116+
117+
new_call_args.push_back(arg);
118+
if (is_global_function || arg.as<VarNode>()) {
119+
new_caller_args.push_back(arg);
120+
}
121+
++function_signature_id;
122+
}
123+
124+
// Push remaining constants as new arguments
125+
for (uint32_t i = function_signature_id; i < function_signature.size(); ++i) {
126+
auto* constant = function_signature[i].as<ConstantNode>();
127+
ICHECK(constant)
128+
<< "Rest of the collected arguments should be constant in the partitioned function.";
129+
new_caller_args.push_back(GetRef<Constant>(constant));
130+
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
131+
new_constants_added = true;
132+
}
133+
134+
// Update the arguments of caller of local function
135+
if (new_constants_added && !is_global_function) {
136+
const Function& last_func = functions_.back();
137+
Array<Expr> function_constants(function_to_arguments_[last_func]);
138+
function_to_arguments_.Set(last_func,
139+
tvm::runtime::Concat(function_constants, new_caller_args));
140+
} else {
141+
new_call_args = new_caller_args;
142+
}
143+
144+
return new_call_args;
145+
}
146+
76147
Expr Rewrite_(const CallNode* call, const Expr& post) final {
77148
Expr final_call = post;
78149
auto* post_call = post.as<CallNode>();
@@ -81,58 +152,47 @@ class ExtractConstantsMutator : public MixedModeMutator {
81152
// Perform this for non-main Call Nodes only
82153
if (!functions_.empty() && call->op.as<OpNode>()) {
83154
Array<Expr> new_args;
155+
const Function& last_func = functions_.back();
156+
Array<Expr> function_signature(function_to_arguments_[last_func]);
84157
for (auto& arg : post_call->args) {
158+
// Push all arguments including constants to maintain correct order of
159+
// variables and constants
85160
auto* const_arg = arg.as<ConstantNode>();
86161
if (const_arg && !const_arg->is_scalar()) {
87162
Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
88163
new_args.push_back(var_arg);
89-
const Function& last_func = functions_.back();
90-
Array<Constant> fconstants(function_to_constants_[last_func]);
91-
fconstants.push_back(GetRef<Constant>(const_arg));
92-
function_to_constants_.Set(last_func, fconstants);
164+
function_signature.push_back(arg);
93165
} else {
166+
if (arg.as<VarNode>()) {
167+
function_signature.push_back(arg);
168+
}
94169
new_args.push_back(arg);
95170
}
96171
}
172+
function_to_arguments_.Set(last_func, function_signature);
97173
final_call = Call(call->op, new_args, call->attrs, {});
98174
}
99175

100-
// Since the constants are kicked out of partitioned functions
176+
// Since the constants are extracted from partitioned functions
101177
// a new call to global function is needed
102178
if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
103179
auto glob_var = GetRef<GlobalVar>(glob_var_node);
104180
auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
105181
auto new_glob_func = VisitExpr(glob_func);
106182
if (!new_glob_func.same_as(glob_func)) {
107183
mod_->Update(glob_var, Downcast<Function>(new_glob_func));
108-
Array<Expr> new_args = post_call->args;
109-
ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
110-
for (auto constant : function_to_constants_.at(glob_func)) {
111-
new_args.push_back(constant);
112-
}
184+
auto new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), glob_func);
113185
final_call = Call(glob_var, new_args);
114186
}
115187
}
116188

117-
// Since the constants are kicked out of the local partitioned functions
189+
// Since the constants are extracted from the local partitioned functions
118190
// a new call to local function is needed
119-
// Also, pass on the constants to the callee of this function to support nested functions
120191
if (auto* func_node = call->op.as<FunctionNode>()) {
121192
Function func = GetRef<Function>(func_node);
122193
auto new_func = VisitExpr(func);
123-
if (!new_func.same_as(func)) {
124-
Array<Expr> new_args = post_call->args;
125-
ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
126-
const Function& last_func = functions_.back();
127-
Array<Constant> fconstants(function_to_constants_[last_func]);
128-
for (auto constant : function_to_constants_.at(func)) {
129-
fconstants.push_back(constant);
130-
Var var_arg = Var(gen_var_name(), constant->tensor_type());
131-
new_args.push_back(var_arg);
132-
}
133-
function_to_constants_.Set(last_func, fconstants);
134-
final_call = Call(new_func, new_args);
135-
}
194+
Array<Expr> new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), func);
195+
final_call = Call(new_func, new_args);
136196
}
137197

138198
return final_call;
@@ -141,15 +201,16 @@ class ExtractConstantsMutator : public MixedModeMutator {
141201
private:
142202
/* \brief Updated module where all calls have replaced constants with new variables */
143203
IRModule mod_;
144-
/* \brief Maintains mapping of original function to the replaced constants */
145-
Map<Function, Array<Constant>> function_to_constants_;
146-
/* \brief Stack of functions to determine scope while filling up function_to_constants_ */
204+
/* \brief Maintains mapping of original function to the replaced constants along with other
205+
* arguments to retain the order in which variables are used within the function */
206+
Map<Function, Array<Expr>> function_to_arguments_;
207+
/* \brief Stack of functions to determine scope while filling up function_to_arguments_ */
147208
Array<Function> functions_;
148209
/* \brief Keeps track of variables being created */
149210
int var_count_ = 0;
150211
};
151212

152-
/*! * \brief Kicks out all constants out of the partitioned function into main() */
213+
/*! * \brief Extracts all constants out of the partitioned function into main() */
153214
IRModule ExtractConstants(const IRModule& mod) {
154215
String func_name;
155216
Function func;
@@ -169,7 +230,7 @@ transform::Pass ExtractConstantsFromPartitionedFunction() {
169230
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
170231
[=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
171232
return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
172-
{});
233+
{"InferType"});
173234
}
174235

175236
TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")

0 commit comments

Comments
 (0)