[RELAY][PASS] Common subexpression elimination#2639
Conversation
|
@kazum @ZihengJiang please help to take a look |
|
@jroesch can you manage this PR as per https://docs.tvm.ai/contribute/committer_guide.html, a good chance to test out your committer rights. |
| * \return Whether two expressions are equal scalars. | ||
| */ | ||
| inline bool IsEqualScalar(const Expr& a, const Expr& b) { | ||
| const auto* constant_a = a.as<ConstantNode>(); |
There was a problem hiding this comment.
constant_a is nullptr with relay.var("x", shape=(1, 16)) in your test script. Is this what you expect?
There was a problem hiding this comment.
yes, this function is intended to enable combining different Constant instance with the same value
|
With the below code, the result is as expected. from tvm import relay
x = relay.var("x", shape=(1, 16))
y1 = relay.add(x, relay.const(1.0, "float32"))
y2 = relay.add(x, relay.const(1.0, "float32"))
y = relay.add(y1, y2)
f = relay.Function([x], y)
f = relay.ir_pass.eliminate_common_subexpr(f)
print(f)However, when I changed the code a bit as follows, elimination did not work. Is this not a scope of this PR? from tvm import relay
x = relay.var("x", shape=(1, 16))
y1 = relay.add(relay.const(1.0, "float32"), x)
y2 = relay.add(relay.const(1.0, "float32"), x)
y = relay.add(y1, y2)
f = relay.Function([x], y)
f = relay.ir_pass.eliminate_common_subexpr(f)
print(f) |
|
@kazum yes, it is a limitation of current implementation |
| return GetRef<Call>(candidate); | ||
| } | ||
| } | ||
| expr_map_[new_call->args[0]].push_back(new_call); |
There was a problem hiding this comment.
Let me ask one more question. expr_map_ is a map from new_call->args[0] to new_call. Can we change it to a map from new_call->op to new_call? Then, this PR also handles the case of
#2639 (comment), doesn't it?
What I mean is like as follows:
auto it = expr_map_.find(new_call->op);
if (it != expr_map_.end()) {
for (const CallNode* candidate : it->second) {
bool is_equivalent = true;
if (!attrs_equal(new_call->attrs, candidate->attrs)) {
continue;
}
for (size_t i = 0; i < new_call->args.size(); i++) {
if (!new_call->args[i].same_as(candidate->args[i]) &&
!IsEqualScalar(new_call->args[i], candidate->args[i])) {
is_equivalent = false;
break;
}
}
if (!is_equivalent) continue;
return GetRef<Call>(candidate);
}
}
expr_map_[new_call->op].push_back(new_call);
return new_expr;There was a problem hiding this comment.
The reason I chose to map from new_call->args[0] is to avoid searching a long list of candidates. But yes you are right, on a second thought I think it is okay to map from op.
This is an optimization pass that eliminates common subexpressions. During the pass, it tries to replace an expression with a previously appeared expression with the same input and attributes. The fskip callback argument allows us to skip specific expressions.
cc @tqchen