Skip to content

Commit b000893

Browse files
author
Ashutosh Parkhi
committed
[CMSIS-NN] Convert scalar constants to tensor constants
Change-Id: I9ea9c28b1410b4a80a9235af2e84bc80b4dc3a66
1 parent 31de5bc commit b000893

5 files changed

Lines changed: 477 additions & 71 deletions

File tree

python/tvm/relay/op/contrib/cmsisnn.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def partition_for_cmsisnn(mod, params=None, **opts):
5656
transform.MergeComposite(pattern_table()),
5757
transform.AnnotateTarget("cmsis-nn"),
5858
transform.PartitionGraph(),
59+
transform.InferType(),
5960
GenerateCMSISNNConstants(),
61+
ScalarToTensorConstants(),
62+
transform.InferType(),
6063
ExtractConstantsFromPartitionedFunction(),
6164
transform.InferType(),
6265
]
@@ -223,11 +226,25 @@ def binary_op_pattern(op):
223226
is_constant(),
224227
)
225228

226-
def check_qnn_binary_op(extract):
229+
def check_qnn_binary_op(pattern):
227230
"""Check if multiply is supported by CMSIS-NN."""
231+
import numpy as np
232+
233+
arg0 = pattern.args[0]
234+
arg1 = pattern.args[1]
235+
both_args_scalar = False
236+
if (
237+
isinstance(arg0, tvm.relay.expr.Constant)
238+
and len(arg0.checked_type.shape) == 0
239+
and isinstance(arg1, tvm.relay.expr.Constant)
240+
and len(arg1.checked_type.shape) == 0
241+
):
242+
both_args_scalar = True
243+
228244
return (
229-
extract.args[0].checked_type.dtype == "int8"
230-
and extract.args[1].checked_type.dtype == "int8"
245+
arg0.checked_type.dtype == "int8"
246+
and arg1.checked_type.dtype == "int8"
247+
and not both_args_scalar
231248
)
232249

233250
return [

src/relay/backend/contrib/cmsisnn/extract_constants.cc

Lines changed: 88 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,87 @@ class ExtractConstantsMutator : public MixedModeMutator {
6262
return func;
6363
}
6464

65-
function_to_constants_.Set(func, Array<Constant>{});
65+
function_to_arguments_.Set(func, Array<Expr>{});
6666
functions_.push_back(func);
6767
auto new_body = VisitExpr(func->body);
6868
functions_.pop_back();
69-
if (function_to_constants_[func].size()) {
69+
if (function_to_arguments_[func].size()) {
7070
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
7171
func->attrs);
7272
}
7373
return std::move(func);
7474
}
7575

76+
// Creates new arguments from current call's arguments
77+
// Updates constants into the caller arguments: here caller signifies caller that comprises call
78+
// to func
79+
Array<Expr> CreateNewCallArgsFromKickedoutConstants(Call call, Function func) {
80+
ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
81+
Array<Expr> fSignature(function_to_arguments_[func]);
82+
83+
// Is func a global_function?
84+
// main() is not registered for kicking out constants
85+
bool is_global_function = functions_.empty() ? true : false;
86+
87+
bool new_constants_added = false;
88+
// This tracks arguments traversed inside fSignature
89+
uint32_t fsignature_id = 0;
90+
// This contains arguments including constants for the caller of this function inside which
91+
// post_call resides.
92+
Array<Expr> new_caller_args;
93+
// New arguments to post_call that includes new variables representing constants kicked out of
94+
// the function
95+
Array<Expr> new_call_args;
96+
for (auto& arg : call->args) {
97+
if (auto* constant = arg.as<ConstantNode>()) {
98+
new_caller_args.push_back(arg);
99+
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
100+
++fsignature_id;
101+
new_constants_added = true;
102+
continue;
103+
}
104+
105+
// Push all constants from the fSignature until a variable corresponding to current argument
106+
// is hit
107+
while (fsignature_id < fSignature.size()) {
108+
auto* constant = fSignature[fsignature_id].as<ConstantNode>();
109+
if (constant == nullptr) {
110+
break;
111+
}
112+
new_caller_args.push_back(fSignature[fsignature_id++]);
113+
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
114+
new_constants_added = true;
115+
}
116+
117+
new_call_args.push_back(arg);
118+
if (is_global_function || arg.as<VarNode>()) {
119+
new_caller_args.push_back(arg);
120+
}
121+
++fsignature_id;
122+
}
123+
124+
// Push remaining constants as new arguments
125+
for (uint32_t i = fsignature_id; i < fSignature.size(); ++i) {
126+
auto* constant = fSignature[i].as<ConstantNode>();
127+
ICHECK(constant)
128+
<< "Rest of the collected arguments should be constant in the partitioned function.";
129+
new_caller_args.push_back(GetRef<Constant>(constant));
130+
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
131+
new_constants_added = true;
132+
}
133+
134+
// Update the arguments of caller of local function
135+
if (new_constants_added && !is_global_function) {
136+
const Function& last_func = functions_.back();
137+
Array<Expr> fconstants(function_to_arguments_[last_func]);
138+
function_to_arguments_.Set(last_func, tvm::runtime::Concat(fconstants, new_caller_args));
139+
} else {
140+
new_call_args = new_caller_args;
141+
}
142+
143+
return new_call_args;
144+
}
145+
76146
Expr Rewrite_(const CallNode* call, const Expr& post) final {
77147
Expr final_call = post;
78148
auto* post_call = post.as<CallNode>();
@@ -81,19 +151,24 @@ class ExtractConstantsMutator : public MixedModeMutator {
81151
// Perform this for non-main Call Nodes only
82152
if (!functions_.empty() && call->op.as<OpNode>()) {
83153
Array<Expr> new_args;
154+
const Function& last_func = functions_.back();
155+
Array<Expr> fSignature(function_to_arguments_[last_func]);
84156
for (auto& arg : post_call->args) {
157+
// Push all arguments including constants to maintain correct order of
158+
// variables and constants
85159
auto* const_arg = arg.as<ConstantNode>();
86160
if (const_arg && !const_arg->is_scalar()) {
87161
Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
88162
new_args.push_back(var_arg);
89-
const Function& last_func = functions_.back();
90-
Array<Constant> fconstants(function_to_constants_[last_func]);
91-
fconstants.push_back(GetRef<Constant>(const_arg));
92-
function_to_constants_.Set(last_func, fconstants);
163+
fSignature.push_back(arg);
93164
} else {
165+
if (arg.as<VarNode>()) {
166+
fSignature.push_back(arg);
167+
}
94168
new_args.push_back(arg);
95169
}
96170
}
171+
function_to_arguments_.Set(last_func, fSignature);
97172
final_call = Call(call->op, new_args, call->attrs, {});
98173
}
99174

@@ -105,34 +180,18 @@ class ExtractConstantsMutator : public MixedModeMutator {
105180
auto new_glob_func = VisitExpr(glob_func);
106181
if (!new_glob_func.same_as(glob_func)) {
107182
mod_->Update(glob_var, Downcast<Function>(new_glob_func));
108-
Array<Expr> new_args = post_call->args;
109-
ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
110-
for (auto constant : function_to_constants_.at(glob_func)) {
111-
new_args.push_back(constant);
112-
}
183+
auto new_args = CreateNewCallArgsFromKickedoutConstants(GetRef<Call>(post_call), glob_func);
113184
final_call = Call(glob_var, new_args);
114185
}
115186
}
116187

117188
// Since the constants are kicked out of the local partitioned functions
118189
// a new call to local function is needed
119-
// Also, pass on the constants to the callee of this function to support nested functions
120190
if (auto* func_node = call->op.as<FunctionNode>()) {
121191
Function func = GetRef<Function>(func_node);
122192
auto new_func = VisitExpr(func);
123-
if (!new_func.same_as(func)) {
124-
Array<Expr> new_args = post_call->args;
125-
ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
126-
const Function& last_func = functions_.back();
127-
Array<Constant> fconstants(function_to_constants_[last_func]);
128-
for (auto constant : function_to_constants_.at(func)) {
129-
fconstants.push_back(constant);
130-
Var var_arg = Var(gen_var_name(), constant->tensor_type());
131-
new_args.push_back(var_arg);
132-
}
133-
function_to_constants_.Set(last_func, fconstants);
134-
final_call = Call(new_func, new_args);
135-
}
193+
Array<Expr> new_args = CreateNewCallArgsFromKickedoutConstants(GetRef<Call>(post_call), func);
194+
final_call = Call(new_func, new_args);
136195
}
137196

138197
return final_call;
@@ -141,9 +200,10 @@ class ExtractConstantsMutator : public MixedModeMutator {
141200
private:
142201
/* \brief Updated module where all calls have replaced constants with new variables */
143202
IRModule mod_;
144-
/* \brief Maintains mapping of original function to the replaced constants */
145-
Map<Function, Array<Constant>> function_to_constants_;
146-
/* \brief Stack of functions to determine scope while filling up function_to_constants_ */
203+
/* \brief Maintains mapping of original function to the replaced constants along with other
204+
* arguments to retain the order in which variables are used within the function */
205+
Map<Function, Array<Expr>> function_to_arguments_;
206+
/* \brief Stack of functions to determine scope while filling up function_to_arguments_ */
147207
Array<Function> functions_;
148208
/* \brief Keeps track of variables being created */
149209
int var_count_ = 0;

0 commit comments

Comments
 (0)