-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relay] Algebraic data types #2442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
61 commits
Select commit
Hold shift + click to select a range
f3b58c3
First pass on ADTs
MarisaKirisame 485889d
Add doc string for tag field
slyubomirsky 54e5196
Visit constructors in TypeVisitor for TypeData
slyubomirsky 671f615
Add to description of type call
slyubomirsky f9e48e8
Add type call to type solving and unification
slyubomirsky 94748c1
Make type mutator for typecall consistent with others (only create ne…
slyubomirsky 7b1126f
Ensure kindchecking can handle type calls and typedata
slyubomirsky f2d36d6
Fix bad nesting in module constructor
slyubomirsky 7dc26ef
Correctly construct call in typecall test
slyubomirsky 0c949c7
Add call override for ordinary vars (do we want this?)
slyubomirsky 9443bdc
Remove generalization hack from type inference because it was breakin…
slyubomirsky 33fc344
Check that there are no free type vars in exprs after inferring type
slyubomirsky c208071
Free var checks need module because of ADT constructors
slyubomirsky af7da11
Typecall test can't have unbound type var, make it global
slyubomirsky 4968b48
Uncomment tmap test and remove comments about failing to infer ret ty…
slyubomirsky 87f82b7
Put in dummy visits for ADTs in graph runtime codegen to placate pylint
slyubomirsky e0f4c08
Fix Relay type infer test module constructor
slyubomirsky b646740
Mark override for TypeCallNode in type solver
slyubomirsky e11f58e
Ensure free vars check treats patern vars as bound
slyubomirsky ec56e4a
Run interpreter in more ADT test cases
slyubomirsky 4f74545
Refactor kind check to return the kind, like typechecking
slyubomirsky 276e028
Fix invalid typecall in test
slyubomirsky 7963e7e
Add kind check to type inference, do not use nulls in func_type_annot…
slyubomirsky e2d6219
Redundant whitespace
slyubomirsky 40c7410
Make TypeData a separate kind
slyubomirsky 673bcd6
Make ADT handles a separate kind too, document calling convention better
slyubomirsky 5e61378
Remove nats and tree from prelude, move to test, document prelude
slyubomirsky 3db3c64
Restore and document nat and tree to prelude, add more tree tests
slyubomirsky 0041b46
Add alpha equality tests for match cases, fix variable binding bug
slyubomirsky d232beb
Add more kind check tests for ADTs
slyubomirsky 7c6d737
Add more tests for finding free or bound vars in match exprs
slyubomirsky 7322866
Add unification tests for type call
slyubomirsky 5f3a2f4
Update main() for alpha equality tests
slyubomirsky 90ee405
Add simple type inference test cases for match exprs and ADT construc…
slyubomirsky d4a54a1
Add more ADT interpreter tests
slyubomirsky 609f56e
Allow incomplete types when typechecking match cases
slyubomirsky 089813a
Type inference for pattern vars should use the type annotation if it'…
slyubomirsky ebec99c
Two more specific test cases for ADT matching
slyubomirsky 00963de
Add option ADT to prelude
slyubomirsky 47babdb
Fix broken reference to kind enum
slyubomirsky a37d927
Fix rebase snags
slyubomirsky 0c660aa
Do not attach checked types to constructors
slyubomirsky f5cec3e
More docstrings for module fields
slyubomirsky b56bc36
Use proper wrapper for indexing into module type data
slyubomirsky 6b8dbb8
checked_type for constructors is not populated
slyubomirsky 07ea915
Expand type call docstring
slyubomirsky b7cfc59
Rename PatternConstructor con field
slyubomirsky 8cd15f2
Use error reporter for pattern constructor case
slyubomirsky 1d9ae48
Condense error reporting in kind check, use error reporter
slyubomirsky acc2ec0
Expand docstrings and rename ADT fields
slyubomirsky 737514b
Rename 'option' ADT to 'optional' for consistency with Python
slyubomirsky 1a6e48a
Add various list iterators and utility functions to prelude
slyubomirsky 511a931
Add smoke tests for new iterators in prelude
slyubomirsky 3eeca4c
Add concat to prelude
slyubomirsky 868f76e
Add smoke test for concat
slyubomirsky 5a85aa8
Correct docstrings in prelude
slyubomirsky bd9bfc7
Ensure that type defs are written in module initialization
slyubomirsky de7ea6e
Various requested renamings
slyubomirsky 8909e18
Correct rebase snags
slyubomirsky ffd5c80
Add kind check tests for ref types
slyubomirsky 7a6eff0
Update the main() for kind checking tests
slyubomirsky File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,244 @@ | ||
| /*! | ||
| * Copyright (c) 2018 by Contributors | ||
| * \file tvm/relay/adt.h | ||
| * \brief Algebraic data types for Relay | ||
| */ | ||
| #ifndef TVM_RELAY_ADT_H_ | ||
| #define TVM_RELAY_ADT_H_ | ||
|
|
||
| #include <tvm/attrs.h> | ||
| #include <string> | ||
| #include <functional> | ||
| #include "./base.h" | ||
| #include "./type.h" | ||
| #include "./expr.h" | ||
|
|
||
| namespace tvm { | ||
| namespace relay { | ||
|
|
||
| /*! \brief Base type for declaring relay pattern. */ | ||
| class PatternNode : public RelayNode { | ||
| public: | ||
| static constexpr const char* _type_key = "relay.Pattern"; | ||
| TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node); | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Pattern is the base type for an ADT match pattern in Relay. | ||
| * | ||
| * Given an ADT value, a pattern might accept it and bind the pattern variable to some value | ||
| * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value. | ||
| * | ||
| * ADT pattern matching thus takes a list of values and binds to the first that accepts the value. | ||
| */ | ||
| class Pattern : public NodeRef { | ||
| public: | ||
| Pattern() {} | ||
| explicit Pattern(NodePtr<tvm::Node> p) : NodeRef(p) {} | ||
|
|
||
| using ContainerType = PatternNode; | ||
| }; | ||
|
|
||
| /*! \brief A wildcard pattern: Accepts all input and binds nothing. */ | ||
| class PatternWildcard; | ||
| /*! \brief PatternWildcard container node */ | ||
| class PatternWildcardNode : public PatternNode { | ||
| public: | ||
| PatternWildcardNode() {} | ||
|
|
||
| TVM_DLL static PatternWildcard make(); | ||
|
|
||
| void VisitAttrs(tvm::AttrVisitor* v) final { | ||
| v->Visit("span", &span); | ||
| } | ||
|
|
||
| static constexpr const char* _type_key = "relay.PatternWildcard"; | ||
| TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode); | ||
| }; | ||
|
|
||
| RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern); | ||
|
|
||
| /*! \brief A var pattern. Accept all input and bind to a var. */ | ||
| class PatternVar; | ||
| /*! \brief PatternVar container node */ | ||
| class PatternVarNode : public PatternNode { | ||
| public: | ||
| PatternVarNode() {} | ||
|
|
||
| /*! \brief Variable that stores the matched value. */ | ||
| tvm::relay::Var var; | ||
|
|
||
| TVM_DLL static PatternVar make(tvm::relay::Var var); | ||
|
|
||
| void VisitAttrs(tvm::AttrVisitor* v) final { | ||
| v->Visit("var", &var); | ||
| v->Visit("span", &span); | ||
| } | ||
|
|
||
| static constexpr const char* _type_key = "relay.PatternVar"; | ||
| TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode); | ||
| }; | ||
|
|
||
| RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern); | ||
|
|
||
| /*! | ||
| * \brief ADT constructor. | ||
| * Constructors compare by pointer equality. | ||
| */ | ||
| class Constructor; | ||
| /*! \brief Constructor container node. */ | ||
| class ConstructorNode : public ExprNode { | ||
| public: | ||
| /*! \brief The name (only a hint) */ | ||
| std::string name_hint; | ||
| /*! \brief Input to the constructor. */ | ||
| tvm::Array<Type> inputs; | ||
| /*! \brief The datatype the constructor will construct. */ | ||
| GlobalTypeVar belong_to; | ||
| /*! \brief Index in the table of constructors (set when the type is registered). */ | ||
| mutable int tag = -1; | ||
|
|
||
| ConstructorNode() {} | ||
|
|
||
| TVM_DLL static Constructor make(std::string name_hint, | ||
| tvm::Array<Type> inputs, | ||
| GlobalTypeVar belong_to); | ||
|
|
||
| void VisitAttrs(tvm::AttrVisitor* v) final { | ||
| v->Visit("name_hint", &name_hint); | ||
| v->Visit("inputs", &inputs); | ||
| v->Visit("belong_to", &belong_to); | ||
| v->Visit("tag", &tag); | ||
| v->Visit("span", &span); | ||
| v->Visit("_checked_type_", &checked_type_); | ||
| } | ||
|
|
||
| static constexpr const char* _type_key = "relay.Constructor"; | ||
| TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode); | ||
| }; | ||
|
|
||
| RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr); | ||
|
|
||
| /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ | ||
| class PatternConstructor; | ||
| /*! \brief PatternVar container node */ | ||
| class PatternConstructorNode : public PatternNode { | ||
| public: | ||
| /*! Constructor matched by the pattern. */ | ||
| Constructor constructor; | ||
| /*! Sub-patterns to match against each input to the constructor. */ | ||
| tvm::Array<Pattern> patterns; | ||
|
|
||
| PatternConstructorNode() {} | ||
|
|
||
| TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var); | ||
|
|
||
| void VisitAttrs(tvm::AttrVisitor* v) final { | ||
| v->Visit("constructor", &constructor); | ||
| v->Visit("patterns", &patterns); | ||
| v->Visit("span", &span); | ||
| } | ||
|
|
||
| static constexpr const char* _type_key = "relay.PatternConstructor"; | ||
| TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode); | ||
| }; | ||
|
|
||
| RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); | ||
|
|
||
| /*! | ||
| * \brief Stores all data for an Algebraic Data Type (ADT). | ||
| * | ||
| * In particular, it stores the handle (global type var) for an ADT | ||
| * and the constructors used to build it and is kept in the module. Note | ||
| * that type parameters are also indicated in the type data: this means that | ||
| * for any instance of an ADT, the type parameters must be indicated. That is, | ||
| * an ADT definition is treated as a type-level function, so an ADT handle | ||
| * must be wrapped in a TypeCall node that instantiates the type-level arguments. | ||
| * The kind checker enforces this. | ||
| */ | ||
| class TypeData; | ||
slyubomirsky marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| /*! \brief TypeData container node */ | ||
| class TypeDataNode : public TypeNode { | ||
| public: | ||
| /*! | ||
| * \brief The header is simply the name of the ADT. | ||
| * We adopt nominal typing for ADT definitions; | ||
| * that is, differently-named ADT definitions with same constructors | ||
| * have different types. | ||
| */ | ||
| GlobalTypeVar header; | ||
| /*! \brief The type variables (to allow for polymorphism). */ | ||
| tvm::Array<TypeVar> type_vars; | ||
| /*! \brief The constructors. */ | ||
| tvm::Array<Constructor> constructors; | ||
|
|
||
| void VisitAttrs(tvm::AttrVisitor* v) final { | ||
| v->Visit("header", &header); | ||
| v->Visit("type_vars", &type_vars); | ||
| v->Visit("constructors", &constructors); | ||
| v->Visit("span", &span); | ||
| } | ||
|
|
||
| TVM_DLL static TypeData make(GlobalTypeVar header, | ||
| tvm::Array<TypeVar> type_vars, | ||
| tvm::Array<Constructor> constructors); | ||
|
|
||
| static constexpr const char* _type_key = "relay.TypeData"; | ||
| TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode); | ||
| }; | ||
|
|
||
| RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type); | ||
|
|
||
| /*! \brief A clause in a match expression. */ | ||
| class Clause; | ||
| /*! \brief Clause container node. */ | ||
| class ClauseNode : public Node { | ||
| public: | ||
| /*! \brief The pattern the clause matches. */ | ||
| Pattern lhs; | ||
| /*! \brief The resulting value. */ | ||
| Expr rhs; | ||
|
|
||
| void VisitAttrs(tvm::AttrVisitor* v) final { | ||
| v->Visit("lhs", &lhs); | ||
| v->Visit("rhs", &rhs); | ||
| } | ||
|
|
||
| TVM_DLL static Clause make(Pattern lhs, Expr rhs); | ||
|
|
||
| static constexpr const char* _type_key = "relay.Clause"; | ||
| TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node); | ||
| }; | ||
|
|
||
| RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef); | ||
|
|
||
| /*! \brief ADT pattern matching exression. */ | ||
| class Match; | ||
| /*! \brief Match container node. */ | ||
| class MatchNode : public ExprNode { | ||
| public: | ||
| /*! \brief The input being deconstructed. */ | ||
| Expr data; | ||
|
|
||
| /*! \brief The match node clauses. */ | ||
| tvm::Array<Clause> clauses; | ||
|
|
||
| void VisitAttrs(tvm::AttrVisitor* v) final { | ||
| v->Visit("data", &data); | ||
| v->Visit("clause", &clauses); | ||
| v->Visit("span", &span); | ||
| v->Visit("_checked_type_", &checked_type_); | ||
| } | ||
|
|
||
| TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern); | ||
|
|
||
| static constexpr const char* _type_key = "relay.Match"; | ||
| TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); | ||
| }; | ||
|
|
||
| RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr); | ||
|
|
||
| } // namespace relay | ||
| } // namespace tvm | ||
|
|
||
| #endif // TVM_RELAY_ADT_H_ | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.