Skip to content

Commit fb2cd3f

Browse files
vinx13Wei Chen
authored andcommitted
[RELAY][PASS] Common subexpression elimination (apache#2639)
1 parent b095d70 commit fb2cd3f

4 files changed

Lines changed: 170 additions & 0 deletions

File tree

python/tvm/relay/ir_pass.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,3 +564,23 @@ def get_total_mac_number(expr):
564564
The number of MACs (multiply-accumulate) of a model
565565
"""
566566
return _ir_pass.GetTotalMacNumber(expr)
567+
568+
569+
def eliminate_common_subexpr(expr, fskip=None):
570+
"""
571+
Eliminate common subexpressions.
572+
573+
Parameters
574+
----------
575+
expr : tvm.relay.Expr
576+
The input expression.
577+
578+
fskip: function
579+
The callback function that decides whether an expression should be skipped.
580+
581+
Returns
582+
-------
583+
expr : tvm.relay.Expr
584+
The output expression.
585+
"""
586+
return _ir_pass.eliminate_common_subexpr(expr, fskip)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*!
2+
* Copyright (c) 2019 by Contributors
3+
*
4+
* \file eliminate_common_subexpr.cc
5+
* \brief Combine common subexpressions.
6+
*
7+
* This is an optimization pass that eliminates common subexpressions. During the pass, it tries
8+
* to replace an expression with a previously appeared expression with the same input and
9+
* attributes. The fskip callback argument allows us to skip specific expressions.
10+
*/
11+
#include <tvm/relay/pass.h>
12+
#include <tvm/relay/expr_functor.h>
13+
#include <unordered_map>
14+
#include "./pattern_util.h"
15+
16+
namespace tvm {
17+
namespace relay {
18+
19+
class CommonSubexprEliminator : public ExprMutator {
20+
public:
21+
explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip): fskip_(fskip) {}
22+
23+
Expr VisitExpr_(const CallNode* call) final {
24+
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
25+
Expr new_expr = ExprMutator::VisitExpr_(call);
26+
const CallNode* new_call = new_expr.as<CallNode>();
27+
CHECK(new_call);
28+
const OpNode* op = new_call->op.as<OpNode>();
29+
AttrsEqual attrs_equal;
30+
31+
if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
32+
return new_expr;
33+
}
34+
if (fskip_ != nullptr && fskip_(new_expr)) {
35+
return new_expr;
36+
}
37+
38+
auto it = expr_map_.find(new_call->op);
39+
if (it != expr_map_.end()) {
40+
for (const CallNode* candidate : it->second) {
41+
bool is_equivalent = true;
42+
if (!attrs_equal(new_call->attrs, candidate->attrs)) {
43+
continue;
44+
}
45+
for (size_t i = 0; i < new_call->args.size(); i++) {
46+
if (!new_call->args[i].same_as(candidate->args[i]) &&
47+
!IsEqualScalar(new_call->args[i], candidate->args[i])) {
48+
is_equivalent = false;
49+
break;
50+
}
51+
}
52+
if (!is_equivalent) continue;
53+
return GetRef<Call>(candidate);
54+
}
55+
}
56+
expr_map_[new_call->op].push_back(new_call);
57+
return new_expr;
58+
}
59+
60+
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> expr_map_;
61+
runtime::TypedPackedFunc<bool(Expr)> fskip_;
62+
};
63+
64+
Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
65+
return CommonSubexprEliminator(callback)(expr);
66+
}
67+
68+
TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr")
69+
.set_body_typed<Expr(Expr, PackedFunc)>(EliminateCommonSubexpr);
70+
71+
} // namespace relay
72+
} // namespace tvm

src/relay/pass/pattern_util.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,21 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
191191
return ConstantNode::make(arr);
192192
}
193193

194+
/*!
195+
* \brief Check if two expressions are equal scalars.
196+
* \param a The expression to be checked.
197+
* \param b The expression to be checked
198+
* \return Whether two expressions are equal scalars.
199+
*/
200+
inline bool IsEqualScalar(const Expr& a, const Expr& b) {
201+
const auto* constant_a = a.as<ConstantNode>();
202+
const auto* constant_b = b.as<ConstantNode>();
203+
if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
204+
return false;
205+
}
206+
return AlphaEqual(a, b);
207+
}
208+
194209
inline Expr GetField(Expr t, size_t i) {
195210
return TupleGetItemNode::make(t, i);
196211
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Test eliminate common subexpr pass"""
2+
from tvm import relay
3+
from tvm.relay.op import register_alter_op_layout
4+
from tvm.relay import ir_pass
5+
6+
7+
def test_simple():
8+
def before():
9+
x = relay.var("x", shape=(1, 16))
10+
y1 = relay.nn.relu(x)
11+
y2 = relay.nn.relu(x)
12+
y1 = relay.add(y1, relay.const(1.0, "float32"))
13+
y2 = relay.add(y2, relay.const(1.0, "float32"))
14+
y = relay.add(y1, y2)
15+
f = relay.Function([x], y)
16+
return f
17+
18+
def expected():
19+
x = relay.var("x", shape=(1, 16))
20+
y = relay.nn.relu(x)
21+
y = relay.add(y, relay.const(1.0, "float32"))
22+
y = relay.add(y, y)
23+
f = relay.Function([x], y)
24+
return f
25+
26+
z = before()
27+
z = ir_pass.eliminate_common_subexpr(z)
28+
assert ir_pass.alpha_equal(z, expected())
29+
30+
31+
def test_callback():
32+
def before():
33+
x = relay.var("x", shape=(1, 16))
34+
y1 = relay.nn.relu(x)
35+
y2 = relay.nn.relu(x)
36+
y1 = relay.add(y1, relay.const(1.0, "float32"))
37+
y2 = relay.add(y2, relay.const(1.0, "float32"))
38+
y = relay.add(y1, y2)
39+
f = relay.Function([x], y)
40+
return f
41+
42+
def expected():
43+
x = relay.var("x", shape=(1, 16))
44+
y = relay.nn.relu(x)
45+
y1 = relay.add(y, relay.const(1.0, "float32"))
46+
y2 = relay.add(y, relay.const(1.0, "float32"))
47+
y = relay.add(y1, y2)
48+
f = relay.Function([x], y)
49+
return f
50+
51+
def fskip(expr):
52+
if isinstance(expr, relay.expr.Call) and expr.op.name == 'add':
53+
return True
54+
return False
55+
56+
z = before()
57+
z = ir_pass.eliminate_common_subexpr(z, fskip)
58+
assert ir_pass.alpha_equal(z, expected())
59+
60+
61+
if __name__ == "__main__":
62+
test_simple()
63+
test_callback()

0 commit comments

Comments
 (0)