Skip to content

Commit fed1c08

Browse files
joshpolltqchen
authored andcommitted
[Relay][Text Format] Fix Pretty Printing Annotations (#3041)
1 parent cdc9e85 commit fed1c08

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

src/relay/ir/pretty_printer.cc

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,19 @@ class PrettyPrinter :
156156
*/
157157
Doc PrintOptionalInfo(const Expr& expr) {
158158
Doc doc;
159-
// additional information in comment.
160-
if (annotate_ != nullptr) {
161-
return doc << " /* " << annotate_(expr) << " */";
162-
} else if (expr->checked_type_.defined()) {
163-
return doc << " /* ty=" << Print(expr->checked_type()) << " */";
159+
// default annotations
160+
if (annotate_ == nullptr) {
161+
if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
162+
doc << " /* ty=" << Print(expr->checked_type()) << " */";
163+
}
164164
} else {
165-
return doc;
165+
std::string annotated_expr = annotate_(expr);
166+
if (annotated_expr != "") {
167+
doc << annotated_expr;
168+
}
166169
}
170+
171+
return doc;
167172
}
168173

169174
// indent a new body
@@ -361,9 +366,7 @@ class PrettyPrinter :
361366
printed_expr = VisitExpr(expr);
362367
}
363368

364-
if (expr.as<CallNode>()) {
365-
printed_expr << PrintOptionalInfo(expr);
366-
}
369+
printed_expr << PrintOptionalInfo(expr);
367370

368371
// add expr to doc
369372
if (expr.as<VarNode>()) {
@@ -409,8 +412,7 @@ class PrettyPrinter :
409412
}
410413
// default fall-back, record it as meta node.
411414
Doc doc;
412-
return doc << Print(GetRef<NodeRef>(op), true)
413-
<< PrintOptionalInfo(GetRef<Expr>(op));
415+
return doc << Print(GetRef<NodeRef>(op), true);
414416
}
415417

416418
Doc VisitExpr_(const TupleNode* op) final {

tests/python/relay/test_ir_text_printer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_env():
5252
assert "def @myf" in str(env)
5353
assert "add(%0, %0) /* ty=float32 */" in text
5454
assert "add(%0, %0) /* ty=float32 */" in str(env)
55-
show(env.astext(annotate=lambda x: str(x.checked_type.dtype)))
55+
show(env.astext(annotate=lambda x: str(x.checked_type.dtype) if type(x) == relay.Call else ""))
5656
show(text)
5757

5858

0 commit comments

Comments
 (0)