Skip to content

Commit ab76f30

Browse files
committed
- Fix bug with pattern var visit
1 parent 2ca953f commit ab76f30

File tree

3 files changed

+35
-45
lines changed

3 files changed

+35
-45
lines changed

src/relay/ir/indexed_graph.cc

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
/*!
2121
* \file src/relay/ir/indexed_graph.cc
22-
* \brief Utilities for creating Indexed Graphs.
22+
* \brief A graph representation of the dataflow in a Relay expression or Relay dataflow
23+
* pattern. Nodes in this graph capture generic dataflow inputs, outputs, dominators and control
24+
* flow scopes.
2325
*/
2426
#include "indexed_graph.h"
2527

@@ -34,33 +36,30 @@ namespace tvm {
3436
namespace relay {
3537

3638
std::string RefToSummary(const Expr& expr) {
37-
if (const auto* var_node = expr.as<VarNode>()) {
38-
return "%" + var_node->name_hint();
39-
} else if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
40-
return "@" + global_var_node->name_hint;
41-
} else if (const auto* op_node = expr.as<OpNode>()) {
42-
return op_node->name;
43-
} else if (expr.as<ConstantNode>()) {
44-
return "const";
45-
} else if (expr.as<FunctionNode>()) {
46-
return "fn";
47-
} else if (const auto* tuple_node = expr.as<TupleNode>()) {
48-
return "tuple(" + std::to_string(tuple_node->fields.size()) + ")";
49-
} else if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
50-
return "." + std::to_string(tuple_get_item_node->index);
51-
} else if (const auto* call_node = expr.as<CallNode>()) {
52-
return RefToSummary(call_node->op) + "(" + std::to_string(call_node->args.size()) + ")";
53-
} else if (expr.as<IfNode>()) {
54-
return "if";
55-
} else if (expr.as<ConstructorNode>()) {
56-
return "ctor";
57-
} else if (expr.as<MatchNode>()) {
58-
return "match";
59-
} else if (expr.as<LetNode>()) {
60-
return "let";
61-
} else {
62-
return "";
63-
}
39+
class Visitor : public ExprFunctor<std::string(const Expr&)> {
40+
std::string VisitExpr_(const VarNode* op) final { return "%" + op->name_hint(); }
41+
std::string VisitExpr_(const GlobalVarNode* op) final { return "@" + op->name_hint; }
42+
std::string VisitExpr_(const ConstantNode* op) final { return "const"; }
43+
std::string VisitExpr_(const TupleNode* op) final {
44+
return "tuple(" + std::to_string(op->fields.size()) + ")";
45+
}
46+
std::string VisitExpr_(const FunctionNode* op) final { return "fn"; }
47+
std::string VisitExpr_(const CallNode* op) final {
48+
return VisitExpr(op->op) + "(" + std::to_string(op->args.size()) + ")";
49+
}
50+
std::string VisitExpr_(const LetNode* op) final { return "let"; }
51+
std::string VisitExpr_(const IfNode* op) final { return "if"; }
52+
std::string VisitExpr_(const OpNode* op) final { return op->name; }
53+
std::string VisitExpr_(const TupleGetItemNode* op) final {
54+
return "." + std::to_string(op->index);
55+
}
56+
std::string VisitExpr_(const RefCreateNode* op) final { return "ref_create"; }
57+
std::string VisitExpr_(const RefReadNode* op) final { return "ref_read"; }
58+
std::string VisitExpr_(const RefWriteNode* op) final { return "ref_write"; }
59+
std::string VisitExpr_(const ConstructorNode* op) final { return "ctor"; }
60+
std::string VisitExpr_(const MatchNode* op) final { return "match"; }
61+
};
62+
return Visitor().VisitExpr(expr);
6463
}
6564

6665
std::string RefToSummary(const DFPattern& pattern) { return ""; }
@@ -140,7 +139,7 @@ std::unique_ptr<IndexedGraph<Expr>> CreateIndexedGraph(const Expr& expr) {
140139

141140
private:
142141
void VisitPattern_(const PatternVarNode* pattern_var_node) final {
143-
creator_->graph_->AddNode(pattern_var_node->var);
142+
creator_->VisitLeaf(pattern_var_node->var);
144143
}
145144

146145
Creator* creator_;
@@ -173,9 +172,9 @@ std::unique_ptr<IndexedGraph<Expr>> CreateIndexedGraph(const Expr& expr) {
173172
*/
174173
class Annotator : public ExprFunctor<void(const Expr&)> {
175174
public:
176-
Annotator(std::pair<std::unique_ptr<IndexedGraph<Expr>>,
177-
std::unique_ptr<std::unordered_set<const CallNode*>>>
178-
args)
175+
explicit Annotator(std::pair<std::unique_ptr<IndexedGraph<Expr>>,
176+
std::unique_ptr<std::unordered_set<const CallNode*>>>
177+
args)
179178
: graph_(std::move(args.first)), rec_calls_(std::move(args.second)) {}
180179

181180
std::unique_ptr<IndexedGraph<Expr>> Annotate() {

src/relay/ir/indexed_graph.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
* \brief A graph representation of the dataflow in a Relay expression or Relay dataflow
2323
* pattern. Nodes in this graph capture generic dataflow inputs, outputs, dominators and control
2424
* flow scopes.
25+
*
26+
* TODO(mbs): Rename to 'DataflowGraph'.
2527
*/
2628
#ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_
2729
#define TVM_RELAY_IR_INDEXED_GRAPH_H_
@@ -264,7 +266,7 @@ class IndexedGraph {
264266
ICHECK(node->node_ref_);
265267
auto itr = node_map_.find(node->node_ref_);
266268
ICHECK(itr != node_map_.end());
267-
ICHECK_EQ(itr->second, node);
269+
ICHECK_EQ(itr->second, node) << "at index " << index << " in:" << std::endl << ToString();
268270
// Inputs come before.
269271
for (size_t i = 0; i < node->inputs_.size(); ++i) {
270272
const Node* input = node->inputs_[i];

tests/python/relay/test_dataflow_pattern.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,18 +1458,6 @@ def test_partition_overused():
14581458
assert pattern.partition(out) == out
14591459

14601460

1461-
def test_partition_not_overused():
1462-
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
1463-
1464-
x = relay.var("input")
1465-
w = relay.var("weight")
1466-
conv2d = relay.op.nn.conv2d(x, w)
1467-
relu = relay.op.nn.relu(conv2d)
1468-
out = relu + relu
1469-
1470-
assert pattern.partition(out) == out
1471-
1472-
14731461
def test_partition_fuzzy_tuple():
14741462
x = relay.var("x")
14751463
y = relay.var("y")
@@ -1496,6 +1484,7 @@ def concat(*args):
14961484

14971485

14981486
def test_partition_fuzzy_function_args():
1487+
14991488
func_pattern = FunctionPattern(None, wildcard() + wildcard())(None) + wildcard()
15001489
x = relay.var("x")
15011490
y = relay.var("y")

0 commit comments

Comments
 (0)