Skip to content

Commit b155ff3

Browse files
slyubomirskyWei Chen
authored andcommitted
[Relay] Algebraic data types (apache#2442)
* First pass on ADTs * Add doc string for tag field * Visit constructors in TypeVisitor for TypeData * Add to description of type call * Add type call to type solving and unification * Make type mutator for typecall consistent with others (only create new node if there's a change) * Ensure kindchecking can handle type calls and typedata * Fix bad nesting in module constructor * Correctly construct call in typecall test * Add call override for ordinary vars (do we want this?) * Remove generalization hack from type inference because it was breaking ADT constructors * Check that there are no free type vars in exprs after inferring type * Free var checks need module because of ADT constructors * Typecall test can't have unbound type var, make it global * Uncomment tmap test and remove comments about failing to infer ret type; those work now * Put in dummy visits for ADTs in graph runtime codegen to placate pylint * Fix Relay type infer test module constructor * Mark override for TypeCallNode in type solver * Ensure free vars check treats patern vars as bound * Run interpreter in more ADT test cases * Refactor kind check to return the kind, like typechecking * Fix invalid typecall in test * Add kind check to type inference, do not use nulls in func_type_annotation()! * Redundant whitespace * Make TypeData a separate kind * Make ADT handles a separate kind too, document calling convention better * Remove nats and tree from prelude, move to test, document prelude * Restore and document nat and tree to prelude, add more tree tests * Add alpha equality tests for match cases, fix variable binding bug * Add more kind check tests for ADTs * Add more tests for finding free or bound vars in match exprs * Add unification tests for type call * Update main() for alpha equality tests * Add simple type inference test cases for match exprs and ADT constructors * Add more ADT interpreter tests * Allow incomplete types when typechecking match cases * Type inference for pattern vars should use the type annotation if it's there * Two more specific test cases for ADT matching * Add option ADT to prelude * Fix broken reference to kind enum * Fix rebase snags * Do not attach checked types to constructors * More docstrings for module fields * Use proper wrapper for indexing into module type data * checked_type for constructors is not populated * Expand type call docstring * Rename PatternConstructor con field * Use error reporter for pattern constructor case * Condense error reporting in kind check, use error reporter * Expand docstrings and rename ADT fields * Rename 'option' ADT to 'optional' for consistency with Python * Add various list iterators and utility functions to prelude * Add smoke tests for new iterators in prelude * Add concat to prelude * Add smoke test for concat * Correct docstrings in prelude * Ensure that type defs are written in module initialization * Various requested renamings * Correct rebase snags * Add kind check tests for ref types * Update the main() for kind checking tests
1 parent 6b3f945 commit b155ff3

45 files changed

Lines changed: 3398 additions & 207 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/tvm/relay/adt.h

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
/*!
2+
* Copyright (c) 2018 by Contributors
3+
* \file tvm/relay/adt.h
4+
* \brief Algebraic data types for Relay
5+
*/
6+
#ifndef TVM_RELAY_ADT_H_
7+
#define TVM_RELAY_ADT_H_
8+
9+
#include <tvm/attrs.h>
10+
#include <string>
11+
#include <functional>
12+
#include "./base.h"
13+
#include "./type.h"
14+
#include "./expr.h"
15+
16+
namespace tvm {
17+
namespace relay {
18+
19+
/*! \brief Base type for declaring relay pattern. */
20+
class PatternNode : public RelayNode {
21+
public:
22+
static constexpr const char* _type_key = "relay.Pattern";
23+
TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node);
24+
};
25+
26+
/*!
27+
* \brief Pattern is the base type for an ADT match pattern in Relay.
28+
*
29+
* Given an ADT value, a pattern might accept it and bind the pattern variable to some value
30+
* (typically a subnode of the input or the input). Otherwise, the pattern rejects the value.
31+
*
32+
* ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
33+
*/
34+
class Pattern : public NodeRef {
35+
public:
36+
Pattern() {}
37+
explicit Pattern(NodePtr<tvm::Node> p) : NodeRef(p) {}
38+
39+
using ContainerType = PatternNode;
40+
};
41+
42+
/*! \brief A wildcard pattern: Accepts all input and binds nothing. */
43+
class PatternWildcard;
44+
/*! \brief PatternWildcard container node */
45+
class PatternWildcardNode : public PatternNode {
46+
public:
47+
PatternWildcardNode() {}
48+
49+
TVM_DLL static PatternWildcard make();
50+
51+
void VisitAttrs(tvm::AttrVisitor* v) final {
52+
v->Visit("span", &span);
53+
}
54+
55+
static constexpr const char* _type_key = "relay.PatternWildcard";
56+
TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode);
57+
};
58+
59+
RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern);
60+
61+
/*! \brief A var pattern. Accept all input and bind to a var. */
62+
class PatternVar;
63+
/*! \brief PatternVar container node */
64+
class PatternVarNode : public PatternNode {
65+
public:
66+
PatternVarNode() {}
67+
68+
/*! \brief Variable that stores the matched value. */
69+
tvm::relay::Var var;
70+
71+
TVM_DLL static PatternVar make(tvm::relay::Var var);
72+
73+
void VisitAttrs(tvm::AttrVisitor* v) final {
74+
v->Visit("var", &var);
75+
v->Visit("span", &span);
76+
}
77+
78+
static constexpr const char* _type_key = "relay.PatternVar";
79+
TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode);
80+
};
81+
82+
RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern);
83+
84+
/*!
85+
* \brief ADT constructor.
86+
* Constructors compare by pointer equality.
87+
*/
88+
class Constructor;
89+
/*! \brief Constructor container node. */
90+
class ConstructorNode : public ExprNode {
91+
public:
92+
/*! \brief The name (only a hint) */
93+
std::string name_hint;
94+
/*! \brief Input to the constructor. */
95+
tvm::Array<Type> inputs;
96+
/*! \brief The datatype the constructor will construct. */
97+
GlobalTypeVar belong_to;
98+
/*! \brief Index in the table of constructors (set when the type is registered). */
99+
mutable int tag = -1;
100+
101+
ConstructorNode() {}
102+
103+
TVM_DLL static Constructor make(std::string name_hint,
104+
tvm::Array<Type> inputs,
105+
GlobalTypeVar belong_to);
106+
107+
void VisitAttrs(tvm::AttrVisitor* v) final {
108+
v->Visit("name_hint", &name_hint);
109+
v->Visit("inputs", &inputs);
110+
v->Visit("belong_to", &belong_to);
111+
v->Visit("tag", &tag);
112+
v->Visit("span", &span);
113+
v->Visit("_checked_type_", &checked_type_);
114+
}
115+
116+
static constexpr const char* _type_key = "relay.Constructor";
117+
TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode);
118+
};
119+
120+
RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr);
121+
122+
/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
123+
class PatternConstructor;
124+
/*! \brief PatternVar container node */
125+
class PatternConstructorNode : public PatternNode {
126+
public:
127+
/*! Constructor matched by the pattern. */
128+
Constructor constructor;
129+
/*! Sub-patterns to match against each input to the constructor. */
130+
tvm::Array<Pattern> patterns;
131+
132+
PatternConstructorNode() {}
133+
134+
TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);
135+
136+
void VisitAttrs(tvm::AttrVisitor* v) final {
137+
v->Visit("constructor", &constructor);
138+
v->Visit("patterns", &patterns);
139+
v->Visit("span", &span);
140+
}
141+
142+
static constexpr const char* _type_key = "relay.PatternConstructor";
143+
TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode);
144+
};
145+
146+
RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);
147+
148+
/*!
149+
* \brief Stores all data for an Algebraic Data Type (ADT).
150+
*
151+
* In particular, it stores the handle (global type var) for an ADT
152+
* and the constructors used to build it and is kept in the module. Note
153+
* that type parameters are also indicated in the type data: this means that
154+
* for any instance of an ADT, the type parameters must be indicated. That is,
155+
* an ADT definition is treated as a type-level function, so an ADT handle
156+
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
157+
* The kind checker enforces this.
158+
*/
159+
class TypeData;
160+
/*! \brief TypeData container node */
161+
class TypeDataNode : public TypeNode {
162+
public:
163+
/*!
164+
* \brief The header is simply the name of the ADT.
165+
* We adopt nominal typing for ADT definitions;
166+
* that is, differently-named ADT definitions with same constructors
167+
* have different types.
168+
*/
169+
GlobalTypeVar header;
170+
/*! \brief The type variables (to allow for polymorphism). */
171+
tvm::Array<TypeVar> type_vars;
172+
/*! \brief The constructors. */
173+
tvm::Array<Constructor> constructors;
174+
175+
void VisitAttrs(tvm::AttrVisitor* v) final {
176+
v->Visit("header", &header);
177+
v->Visit("type_vars", &type_vars);
178+
v->Visit("constructors", &constructors);
179+
v->Visit("span", &span);
180+
}
181+
182+
TVM_DLL static TypeData make(GlobalTypeVar header,
183+
tvm::Array<TypeVar> type_vars,
184+
tvm::Array<Constructor> constructors);
185+
186+
static constexpr const char* _type_key = "relay.TypeData";
187+
TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode);
188+
};
189+
190+
RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type);
191+
192+
/*! \brief A clause in a match expression. */
193+
class Clause;
194+
/*! \brief Clause container node. */
195+
class ClauseNode : public Node {
196+
public:
197+
/*! \brief The pattern the clause matches. */
198+
Pattern lhs;
199+
/*! \brief The resulting value. */
200+
Expr rhs;
201+
202+
void VisitAttrs(tvm::AttrVisitor* v) final {
203+
v->Visit("lhs", &lhs);
204+
v->Visit("rhs", &rhs);
205+
}
206+
207+
TVM_DLL static Clause make(Pattern lhs, Expr rhs);
208+
209+
static constexpr const char* _type_key = "relay.Clause";
210+
TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node);
211+
};
212+
213+
RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef);
214+
215+
/*! \brief ADT pattern matching exression. */
216+
class Match;
217+
/*! \brief Match container node. */
218+
class MatchNode : public ExprNode {
219+
public:
220+
/*! \brief The input being deconstructed. */
221+
Expr data;
222+
223+
/*! \brief The match node clauses. */
224+
tvm::Array<Clause> clauses;
225+
226+
void VisitAttrs(tvm::AttrVisitor* v) final {
227+
v->Visit("data", &data);
228+
v->Visit("clause", &clauses);
229+
v->Visit("span", &span);
230+
v->Visit("_checked_type_", &checked_type_);
231+
}
232+
233+
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern);
234+
235+
static constexpr const char* _type_key = "relay.Match";
236+
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
237+
};
238+
239+
RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr);
240+
241+
} // namespace relay
242+
} // namespace tvm
243+
244+
#endif // TVM_RELAY_ADT_H_

include/tvm/relay/expr_functor.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <tvm/node/ir_functor.h>
1111
#include <string>
1212
#include "./expr.h"
13+
#include "./adt.h"
1314
#include "./op.h"
1415
#include "./error.h"
1516

@@ -92,6 +93,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
9293
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
9394
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
9495
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
96+
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
97+
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
9598
virtual R VisitExprDefault_(const Node* op, Args...) {
9699
throw Error(std::string("Do not have a default for ") + op->type_key());
97100
}
@@ -114,6 +117,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
114117
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
115118
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
116119
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
120+
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
121+
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
117122
return vtable;
118123
}
119124
};
@@ -142,7 +147,11 @@ class ExprVisitor
142147
void VisitExpr_(const RefCreateNode* op) override;
143148
void VisitExpr_(const RefReadNode* op) override;
144149
void VisitExpr_(const RefWriteNode* op) override;
150+
void VisitExpr_(const ConstructorNode* op) override;
151+
void VisitExpr_(const MatchNode* op) override;
145152
virtual void VisitType(const Type& t);
153+
virtual void VisitClause(const Clause& c);
154+
virtual void VisitPattern(const Pattern& c);
146155

147156
protected:
148157
// Internal visiting counter
@@ -180,6 +189,9 @@ class ExprMutator
180189
Expr VisitExpr_(const RefCreateNode* op) override;
181190
Expr VisitExpr_(const RefReadNode* op) override;
182191
Expr VisitExpr_(const RefWriteNode* op) override;
192+
Expr VisitExpr_(const ConstructorNode* op) override;
193+
Expr VisitExpr_(const MatchNode* op) override;
194+
183195
/*!
184196
* \brief Used to visit the types inside of expressions.
185197
*
@@ -188,6 +200,8 @@ class ExprMutator
188200
* visitor for types which transform them appropriately.
189201
*/
190202
virtual Type VisitType(const Type& t);
203+
virtual Clause VisitClause(const Clause& c);
204+
virtual Pattern VisitPattern(const Pattern& c);
191205

192206
protected:
193207
/*! \brief Internal map used for memoization. */

include/tvm/relay/interpreter.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,28 @@ struct RefValueNode : ValueNode {
160160

161161
RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
162162

163+
/*! \brief An ADT constructor value. */
164+
class ConstructorValue;
165+
166+
struct ConstructorValueNode : ValueNode {
167+
Constructor constructor;
168+
169+
tvm::Array<Value> fields;
170+
171+
void VisitAttrs(tvm::AttrVisitor* v) final {
172+
v->Visit("constructor", &constructor);
173+
v->Visit("fields", &fields);
174+
}
175+
176+
TVM_DLL static ConstructorValue make(Constructor constructor,
177+
tvm::Array<Value> fields);
178+
179+
static constexpr const char* _type_key = "relay.ConstructorValue";
180+
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
181+
};
182+
183+
RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value);
184+
163185
} // namespace relay
164186
} // namespace tvm
165187
#endif // TVM_RELAY_INTERPRETER_H_

0 commit comments

Comments
 (0)