Skip to content

Commit 8e5b822

Browse files
committed
[RELAY][PASS] Common subexpression elimination
1 parent c59a78e commit 8e5b822

4 files changed

Lines changed: 169 additions & 0 deletions

File tree

python/tvm/relay/ir_pass.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,3 +533,23 @@ def gradient(expr, mod=None):
533533
The output expression.
534534
"""
535535
return _ir_pass.first_order_gradient(expr, mod)
536+
537+
538+
def eliminate_common_subexpr(expr, fskip=None):
539+
"""
540+
Eliminate common subexpressions.
541+
542+
Parameters
543+
----------
544+
expr : tvm.relay.Expr
545+
The input expression.
546+
547+
fskip: function
548+
The callback function that decides whether an expression should be skipped.
549+
550+
Returns
551+
-------
552+
expr : tvm.relay.Expr
553+
The output expression.
554+
"""
555+
return _ir_pass.eliminate_common_subexpr(expr, fskip)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*!
2+
* Copyright (c) 2019 by Contributors
3+
*
4+
* \file eliminate_common_subexpr.cc
5+
* \brief Combine common subexpressions.
6+
*
7+
* This is an optimization pass that eliminates common subexpressions. During the pass, it tries
8+
* to replace an expression with a previously appeared expression with the same input and
9+
* attributes. The fskip callback argument allows us to skip specific expressions.
10+
*/
11+
#include <tvm/relay/pass.h>
12+
#include <tvm/relay/expr_functor.h>
13+
#include <unordered_map>
14+
#include "./pattern_util.h"
15+
16+
namespace tvm {
17+
namespace relay {
18+
19+
class CommonSubexprEliminator : public ExprMutator {
20+
public:
21+
explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip): fskip_(fskip) {}
22+
23+
Expr VisitExpr_(const CallNode* call) final {
24+
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
25+
Expr new_expr = ExprMutator::VisitExpr_(call);
26+
const CallNode* new_call = new_expr.as<CallNode>();
27+
CHECK(new_call);
28+
const OpNode* op = new_call->op.as<OpNode>();
29+
AttrsEqual attrs_equal;
30+
31+
if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
32+
return new_expr;
33+
}
34+
if (fskip_ != nullptr && fskip_(new_expr)) {
35+
return new_expr;
36+
}
37+
38+
auto it = expr_map_.find(new_call->args[0]);
39+
if (it != expr_map_.end()) {
40+
for (const CallNode* candidate : it->second) {
41+
bool is_equivalent = true;
42+
if (!new_call->op.same_as(candidate->op)) continue;
43+
for (size_t i = 0; i < new_call->args.size(); i++) {
44+
if (!new_call->args[i].same_as(candidate->args[i]) &&
45+
!IsEqualScalar(new_call->args[i], candidate->args[i]) &&
46+
!attrs_equal(new_call->attrs, candidate->attrs)) {
47+
is_equivalent = false;
48+
break;
49+
}
50+
}
51+
if (!is_equivalent) continue;
52+
return GetRef<Call>(candidate);
53+
}
54+
}
55+
expr_map_[new_call->args[0]].push_back(new_call);
56+
return new_expr;
57+
}
58+
59+
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> expr_map_;
60+
runtime::TypedPackedFunc<bool(Expr)> fskip_;
61+
};
62+
63+
Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
64+
return CommonSubexprEliminator(callback)(expr);
65+
}
66+
67+
TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr")
68+
.set_body_typed<Expr(Expr, PackedFunc)>(EliminateCommonSubexpr);
69+
70+
} // namespace relay
71+
} // namespace tvm

src/relay/pass/pattern_util.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,21 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
192192
return ConstantNode::make(arr);
193193
}
194194

195+
/*!
196+
* \brief Check if two expressions are equal scalars.
197+
* \param a The expression to be checked.
198+
* \param b The expression to be checked
199+
* \return Whether two expressions are equal scalars.
200+
*/
201+
inline bool IsEqualScalar(const Expr& a, const Expr& b) {
202+
const auto* constant_a = a.as<ConstantNode>();
203+
const auto* constant_b = b.as<ConstantNode>();
204+
if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
205+
return false;
206+
}
207+
return AlphaEqual(a, b);
208+
}
209+
195210
inline Expr GetField(Expr t, size_t i) {
196211
return TupleGetItemNode::make(t, i);
197212
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Test eliminate common subexpr pass"""
2+
from tvm import relay
3+
from tvm.relay.op import register_alter_op_layout
4+
from tvm.relay import ir_pass
5+
6+
7+
def test_simple():
8+
def before():
9+
x = relay.var("x", shape=(1, 16))
10+
y1 = relay.nn.relu(x)
11+
y2 = relay.nn.relu(x)
12+
y1 = relay.add(y1, relay.const(1.0, "float32"))
13+
y2 = relay.add(y2, relay.const(1.0, "float32"))
14+
y = relay.add(y1, y2)
15+
f = relay.Function([x], y)
16+
return f
17+
18+
def expected():
19+
x = relay.var("x", shape=(1, 16))
20+
y = relay.nn.relu(x)
21+
y = relay.add(y, relay.const(1.0, "float32"))
22+
y = relay.add(y, y)
23+
f = relay.Function([x], y)
24+
return f
25+
26+
z = before()
27+
z = ir_pass.eliminate_common_subexpr(z)
28+
assert ir_pass.alpha_equal(z, expected())
29+
30+
31+
def test_callback():
32+
def before():
33+
x = relay.var("x", shape=(1, 16))
34+
y1 = relay.nn.relu(x)
35+
y2 = relay.nn.relu(x)
36+
y1 = relay.add(y1, relay.const(1.0, "float32"))
37+
y2 = relay.add(y2, relay.const(1.0, "float32"))
38+
y = relay.add(y1, y2)
39+
f = relay.Function([x], y)
40+
return f
41+
42+
def expected():
43+
x = relay.var("x", shape=(1, 16))
44+
y = relay.nn.relu(x)
45+
y1 = relay.add(y, relay.const(1.0, "float32"))
46+
y2 = relay.add(y, relay.const(1.0, "float32"))
47+
y = relay.add(y1, y2)
48+
f = relay.Function([x], y)
49+
return f
50+
51+
def fskip(expr):
52+
if isinstance(expr, relay.expr.Call) and expr.op.name == 'add':
53+
return True
54+
return False
55+
56+
z = before()
57+
z = ir_pass.eliminate_common_subexpr(z, fskip)
58+
assert ir_pass.alpha_equal(z, expected())
59+
60+
61+
if __name__ == "__main__":
62+
test_simple()
63+
test_callback()

0 commit comments

Comments
 (0)