Skip to content

Commit db277fd

Browse files
committed
[Relay][Pass] CanonicalizeExpr
1 parent 1f4ec9e commit db277fd

File tree

4 files changed

+179
-16
lines changed

4 files changed

+179
-16
lines changed

python/tvm/relay/ir_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,6 @@ def partial_evaluate(expr):
652652
The output expression.
653653
"""
654654
return _ir_pass.partial_evaluate(expr)
655+
656+
def canonicalize_expr(expr):
657+
return _ir_pass.canonicalize_expr(expr)

src/relay/backend/build_module.cc

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ const std::unordered_map<std::string, int> OptPassLevel::_data = {
6969
{"FoldScaleAxis", 3},
7070
{"AlterOpLayout", 3},
7171
{"CanonicalizeOps", 3},
72-
{"EliminateCommonSubexpr", 3}
72+
{"EliminateCommonSubexpr", 3},
73+
{"CanonicalizeExpr", 3}
7374
};
7475

7576
/*!
@@ -405,22 +406,8 @@ class RelayBuildModule : public runtime::ModuleNode {
405406
func = CallPackedFunc("relay._ir_pass.simplify_inference", func);
406407
}
407408
if (cfg.pass_enabled("EliminateCommonSubexpr")) {
408-
auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
409-
Expr expr = args[0];
410-
if (expr.as<CallNode>()) {
411-
auto call_node = expr.as<CallNode>();
412-
auto op_node = call_node->op.as<OpNode>();
413-
if (op_node->name == "cast") {
414-
auto attrs = call_node->attrs.as<CastAttrs>();
415-
if (attrs->dtype == HalideIR::Int(32)) {
416-
*rv = true;
417-
}
418-
}
419-
}
420-
*rv = false;
421-
});
422409
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
423-
func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip);
410+
func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, nullptr);
424411
}
425412
if (cfg.pass_enabled("CombineParallelConv2D")) {
426413
const int min_num_branches = 3;
@@ -437,6 +424,10 @@ class RelayBuildModule : public runtime::ModuleNode {
437424
func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func);
438425
func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
439426
}
427+
if (cfg.pass_enabled("CanonicalizeExpr")) {
428+
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
429+
func = CallPackedFunc("relay._ir_pass.canonicalize_expr", func);
430+
}
440431
if (cfg.pass_enabled("CanonicalizeOps")) {
441432
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
442433
func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func);
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file canonicalize_expr.cc
23+
* \brief Canonicalize an expression to make operator fusion more efficient.
24+
*/
25+
#include <tvm/relay/pass.h>
26+
#include <tvm/relay/expr_functor.h>
27+
#include <tvm/relay/attrs/nn.h>
28+
#include "pattern_util.h"
29+
#include "pass_util.h"
30+
31+
namespace tvm {
32+
namespace relay {
33+
34+
// This pass finds upcast that is referred by multiple elemwise/broadcast operators, and creates a
35+
// copy of it in each branch such that after fusion the previous function have output with fewer
36+
// bits.
37+
class ExprCanonicalizer : public ExprMutator {
38+
public:
39+
Expr VisitExpr_(const CallNode* call) {
40+
static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
41+
42+
if (const OpNode* opnode = call->op.as<OpNode>()) {
43+
auto pattern = fpattern[GetRef<Op>(opnode)];
44+
if (pattern <= kBroadcast) {
45+
Array<Expr> call_args = call->args;
46+
bool unchanged = true;
47+
for (size_t i = 0; i < call_args.size(); ++i) {
48+
Expr arg = call_args[i];
49+
Expr new_arg = GetNewCallArg(arg);
50+
if (!arg.same_as(new_arg)) {
51+
call_args.Set(i, new_arg);
52+
unchanged = false;
53+
}
54+
}
55+
if (unchanged) {
56+
return GetRef<Expr>(call);
57+
}
58+
return CallNode::make(call->op, call_args, call->attrs, call->type_args);
59+
}
60+
}
61+
62+
Expr new_expr = ExprMutator::VisitExpr_(call);
63+
return new_expr;
64+
}
65+
66+
private:
67+
std::unordered_map<const Node*, size_t> ref_counter_;
68+
69+
Expr GetNewCallArg(const Expr& e) {
70+
// if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor
71+
72+
static auto& cast = Op::Get("cast");
73+
Expr new_expr = this->VisitExpr(e);
74+
75+
if (const CallNode* call = e.as<CallNode>()) {
76+
if (call->op.same_as(cast)) {
77+
auto attrs = call->attrs.as<CastAttrs>();
78+
const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
79+
CHECK(from_type);
80+
81+
if (from_type->dtype.bits() < attrs->dtype.bits()) {
82+
if (++ref_counter_[call] > 1) {
83+
const CallNode* new_call = new_expr.as<CallNode>();
84+
CHECK(new_call);
85+
CHECK(new_call->op.same_as(cast));
86+
return CallNode::make(new_call->op, new_call->args, new_call->attrs,
87+
new_call->type_args);
88+
}
89+
}
90+
}
91+
}
92+
return new_expr;
93+
}
94+
};
95+
96+
Expr CanonicalizeExpr(const Expr& e) {
97+
return ExprCanonicalizer().Mutate(e);
98+
}
99+
100+
TVM_REGISTER_API("relay._ir_pass.canonicalize_expr")
101+
.set_body_typed(CanonicalizeExpr);
102+
103+
} // namespace relay
104+
} // namespace tvm
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import tvm
19+
import tvm.relay as relay
20+
21+
22+
def test_canonicalize_cast():
23+
def before(data, conv_weight, bias1, bias2):
24+
x = relay.nn.conv2d(data, conv_weight,
25+
channels=16,
26+
kernel_size=(3, 3),
27+
padding=(1, 1),
28+
out_dtype="int8")
29+
x1 = relay.cast(x, dtype="int32")
30+
y1 = relay.add(x1, bias1)
31+
y2 = relay.add(x1, bias2)
32+
y = relay.add(y1, y2)
33+
return relay.Function([data, conv_weight, bias1, bias2], y)
34+
35+
def expected(data, conv_weight, bias1, bias2):
36+
x = relay.nn.conv2d(data, conv_weight,
37+
channels=16,
38+
kernel_size=(3, 3),
39+
padding=(1, 1),
40+
out_dtype="int8")
41+
x1 = relay.cast(x, dtype="int32")
42+
x2 = relay.cast(x, dtype="int32")
43+
y1 = relay.add(x1, bias1)
44+
y2 = relay.add(x2, bias2)
45+
y = relay.add(y1, y2)
46+
return relay.Function([data, conv_weight, bias1, bias2], y)
47+
48+
def check(shape):
49+
data = relay.var("data", shape=shape, dtype="int8")
50+
conv_weight = relay.var("weight")
51+
bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32")
52+
bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
53+
y = before(data, conv_weight, bias1, bias2)
54+
y = relay.ir_pass.infer_type(y)
55+
y = relay.ir_pass.canonicalize_expr(y)
56+
y = relay.ir_pass.infer_type(y)
57+
y_expected = expected(data, conv_weight, bias1, bias2)
58+
y_expected = relay.ir_pass.infer_type(y_expected)
59+
assert relay.ir_pass.alpha_equal(y, y_expected)
60+
61+
check((1, 16, 7, 7))
62+
63+
64+
if __name__ == '__main__':
65+
test_canonicalize_cast()

0 commit comments

Comments
 (0)