Skip to content

Commit 78ca6fc

Browse files
authored
[NODE][REFACTOR] Refactor reflection system in node. (#4189)
* [NODE][REFACTOR] Refactor reflection system in node. - Removed the old Node, Node is now just an alias of runtime::Object - Introduce ReflectionVTable, a new columnar dispatcher to support reflection - This allows us to remove vtable from most node objects - The VisitAttrs are registered via TVM_RESGITER_NODE_TYPE, they are no longer virtual. - Consolidated serialization and reflection features into node. * Explicit type qualification when calling destructor. * Fix SPIRV, more comments
1 parent 324a960 commit 78ca6fc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+1105
-941
lines changed

include/tvm/api_registry.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class EnvFuncNode : public Node {
5858
/*! \brief constructor */
5959
EnvFuncNode() {}
6060

61-
void VisitAttrs(AttrVisitor* v) final {
61+
void VisitAttrs(AttrVisitor* v) {
6262
v->Visit("name", &name);
6363
}
6464

include/tvm/arithmetic.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class ConstIntBoundNode : public Node {
6060
int64_t min_value;
6161
int64_t max_value;
6262

63-
void VisitAttrs(tvm::AttrVisitor* v) final {
63+
void VisitAttrs(tvm::AttrVisitor* v) {
6464
v->Visit("min_value", &min_value);
6565
v->Visit("max_value", &max_value);
6666
}
@@ -162,7 +162,7 @@ class ModularSetNode : public Node {
162162
/*! \brief The base */
163163
int64_t base;
164164

165-
void VisitAttrs(tvm::AttrVisitor* v) final {
165+
void VisitAttrs(tvm::AttrVisitor* v) {
166166
v->Visit("coeff", &coeff);
167167
v->Visit("base", &base);
168168
}
@@ -351,7 +351,7 @@ enum SignType {
351351
*/
352352
struct IntSetNode : public Node {
353353
static constexpr const char* _type_key = "IntSet";
354-
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
354+
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object);
355355
};
356356

357357
/*!

include/tvm/attrs.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class AttrFieldInfoNode : public Node {
115115
/*! \brief detailed description of the type */
116116
std::string description;
117117

118-
void VisitAttrs(AttrVisitor* v) final {
118+
void VisitAttrs(AttrVisitor* v) {
119119
v->Visit("name", &name);
120120
v->Visit("type_info", &type_info);
121121
v->Visit("description", &description);
@@ -197,7 +197,7 @@ class AttrsHash {
197197
size_t operator()(const std::string& value) const {
198198
return std::hash<std::string>()(value);
199199
}
200-
size_t operator()(const Type& value) const {
200+
size_t operator()(const DataType& value) const {
201201
return std::hash<int>()(
202202
static_cast<int>(value.code()) |
203203
(static_cast<int>(value.bits()) << 8) |
@@ -221,6 +221,8 @@ class BaseAttrsNode : public Node {
221221
public:
222222
using TVMArgs = runtime::TVMArgs;
223223
using TVMRetValue = runtime::TVMRetValue;
224+
// visit function
225+
virtual void VisitAttrs(AttrVisitor* v) {}
224226
/*!
225227
* \brief Initialize the attributes by sequence of arguments
226228
* \param args The postional arguments in the form
@@ -753,12 +755,12 @@ class AttrNonDefaultVisitor {
753755
template<typename DerivedType>
754756
class AttrsNode : public BaseAttrsNode {
755757
public:
756-
void VisitAttrs(AttrVisitor* v) final {
758+
void VisitAttrs(AttrVisitor* v) {
757759
::tvm::detail::AttrNormalVisitor vis(v);
758760
self()->__VisitAttrs__(vis);
759761
}
760762

761-
void VisitNonDefaultAttrs(AttrVisitor* v) final {
763+
void VisitNonDefaultAttrs(AttrVisitor* v) {
762764
::tvm::detail::AttrNonDefaultVisitor vis(v);
763765
self()->__VisitAttrs__(vis);
764766
}

include/tvm/base.h

Lines changed: 1 addition & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -19,89 +19,16 @@
1919

2020
/*!
2121
* \file tvm/base.h
22-
* \brief Defines the base data structure
22+
* \brief Base utilities
2323
*/
2424
#ifndef TVM_BASE_H_
2525
#define TVM_BASE_H_
2626

2727
#include <dmlc/logging.h>
28-
#include <dmlc/registry.h>
29-
#include <tvm/node/node.h>
30-
#include <string>
31-
#include <memory>
32-
#include <functional>
3328
#include <utility>
34-
#include "runtime/registry.h"
3529

3630
namespace tvm {
3731

38-
using ::tvm::Node;
39-
using ::tvm::NodeRef;
40-
using ::tvm::AttrVisitor;
41-
42-
/*!
43-
* \brief Macro to define common node ref methods.
44-
* \param TypeName The name of the NodeRef.
45-
* \param BaseTypeName The Base type.
46-
* \param NodeName The node container type.
47-
*/
48-
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
49-
TypeName() {} \
50-
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
51-
: BaseTypeName(n) {} \
52-
const NodeName* operator->() const { \
53-
return static_cast<const NodeName*>(data_.get()); \
54-
} \
55-
operator bool() const { return this->defined(); } \
56-
using ContainerType = NodeName;
57-
58-
/*!
59-
* \brief Macro to define CopyOnWrite function in a NodeRef.
60-
* \param NodeName The Type of the Node.
61-
*
62-
* CopyOnWrite will generate a unique copy of the internal node.
63-
* The node will be copied if it is referenced by multiple places.
64-
* The function returns the raw pointer to the node to allow modification
65-
* of the content.
66-
*
67-
* \code
68-
*
69-
* MyCOWNodeRef ref, ref2;
70-
* ref2 = ref;
71-
* ref.CopyOnWrite()->value = new_value;
72-
* assert(ref2->value == old_value);
73-
* assert(ref->value == new_value);
74-
*
75-
* \endcode
76-
*/
77-
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
78-
NodeName* CopyOnWrite() { \
79-
CHECK(data_ != nullptr); \
80-
if (!data_.unique()) { \
81-
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
82-
ObjectPtr<Object>(std::move(n)).swap(data_); \
83-
} \
84-
return static_cast<NodeName*>(data_.get()); \
85-
}
86-
87-
/*! \brief Macro to make it easy to define node ref type given node */
88-
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
89-
class TypeName : public ::tvm::NodeRef { \
90-
public: \
91-
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
92-
}; \
93-
94-
/*!
95-
* \brief Macro to make it easy to define node ref type that
96-
* has a CopyOnWrite member function.
97-
*/
98-
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
99-
class TypeName : public BaseType { \
100-
public: \
101-
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
102-
TVM_DEFINE_NODE_REF_COW(NodeName); \
103-
};
104-
10532
/*!
10633
* \brief RAII wrapper function to enter and exit a context object
10734
* similar to python's with syntax.
@@ -146,100 +73,6 @@ class With {
14673
ContextType ctx_;
14774
};
14875

149-
/*!
150-
* \brief save the node as well as all the node it depends on as json.
151-
* This can be used to serialize any TVM object
152-
*
153-
* \return the string representation of the node.
154-
*/
155-
std::string SaveJSON(const NodeRef& node);
156-
157-
/*!
158-
* \brief Internal implementation of LoadJSON
159-
* Load tvm Node object from json and return a shared_ptr of Node.
160-
* \param json_str The json string to load from.
161-
*
162-
* \return The shared_ptr of the Node.
163-
*/
164-
ObjectPtr<Object> LoadJSON_(std::string json_str);
165-
166-
/*!
167-
* \brief Load the node from json string.
168-
* This can be used to deserialize any TVM object.
169-
*
170-
* \param json_str The json string to load from.
171-
*
172-
* \tparam NodeType the nodetype
173-
*
174-
* \code
175-
* Expr e = LoadJSON<Expr>(json_str);
176-
* \endcode
177-
*/
178-
template<typename NodeType,
179-
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
180-
inline NodeType LoadJSON(const std::string& json_str) {
181-
return NodeType(LoadJSON_(json_str));
182-
}
183-
184-
/*!
185-
* \brief Registry entry for NodeFactory.
186-
*
187-
* There are two types of Nodes that can be serialized.
188-
* The normal node requires a registration a creator function that
189-
* constructs an empty Node of the corresponding type.
190-
*
191-
* The global singleton(e.g. global operator) where only global_key need to be serialized,
192-
* in this case, FGlobalKey need to be defined.
193-
*/
194-
struct NodeFactoryReg {
195-
/*!
196-
* \brief creator function.
197-
* \param global_key Key that identifies a global single object.
198-
* If this is not empty then FGlobalKey
199-
* \return The created function.
200-
*/
201-
using FCreate = std::function<NodePtr<Node>(const std::string& global_key)>;
202-
/*!
203-
* \brief Global key function, only needed by global objects.
204-
* \param node The node pointer.
205-
* \return node The global key to the node.
206-
*/
207-
using FGlobalKey = std::function<std::string(const Node* node)>;
208-
/*! \brief registered name */
209-
std::string name;
210-
/*!
211-
* \brief The creator function
212-
*/
213-
FCreate fcreator = nullptr;
214-
/*!
215-
* \brief The global key function.
216-
*/
217-
FGlobalKey fglobal_key = nullptr;
218-
// setter of creator
219-
NodeFactoryReg& set_creator(FCreate f) { // NOLINT(*)
220-
this->fcreator = f;
221-
return *this;
222-
}
223-
// setter of creator
224-
NodeFactoryReg& set_global_key(FGlobalKey f) { // NOLINT(*)
225-
this->fglobal_key = f;
226-
return *this;
227-
}
228-
// global registry singleton
229-
TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry();
230-
};
231-
232-
/*!
233-
* \brief Register a Node type
234-
* \note This is necessary to enable serialization of the Node.
235-
*/
236-
#define TVM_REGISTER_NODE_TYPE(TypeName) \
237-
TVM_REGISTER_OBJECT_TYPE(TypeName); \
238-
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
239-
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
240-
.set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })
241-
242-
24376
#define TVM_STRINGIZE_DETAIL(x) #x
24477
#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x)
24578
#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))

include/tvm/buffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class BufferNode : public Node {
135135
/*! \brief constructor */
136136
BufferNode() {}
137137

138-
void VisitAttrs(AttrVisitor* v) final {
138+
void VisitAttrs(AttrVisitor* v) {
139139
v->Visit("data", &data);
140140
v->Visit("dtype", &dtype);
141141
v->Visit("shape", &shape);

include/tvm/build_module.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class TargetNode : public Node {
6161
/*! \return the full device string to pass to codegen::Build */
6262
TVM_DLL const std::string& str() const;
6363

64-
void VisitAttrs(AttrVisitor* v) final {
64+
void VisitAttrs(AttrVisitor* v) {
6565
v->Visit("target_name", &target_name);
6666
v->Visit("device_name", &device_name);
6767
v->Visit("device_type", &device_type);
@@ -229,7 +229,7 @@ class BuildConfigNode : public Node {
229229
/*! \brief Whether to disable loop vectorization. */
230230
bool disable_vectorize = false;
231231

232-
void VisitAttrs(AttrVisitor* v) final {
232+
void VisitAttrs(AttrVisitor* v) {
233233
v->Visit("data_alignment", &data_alignment);
234234
v->Visit("offset_factor", &offset_factor);
235235
v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
@@ -473,6 +473,8 @@ class GenericFuncNode : public Node {
473473
/* \brief map from keys to registered functions */
474474
std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
475475

476+
void VisitAttrs(AttrVisitor* v) {}
477+
476478
static constexpr const char* _type_key = "GenericFunc";
477479
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
478480
};

include/tvm/channel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct ChannelNode : public Node {
5454
/*! \brief default data type in read/write */
5555
Type dtype;
5656
// visit all attributes
57-
void VisitAttrs(AttrVisitor* v) final {
57+
void VisitAttrs(AttrVisitor* v) {
5858
v->Visit("handle_var", &handle_var);
5959
v->Visit("dtype", &dtype);
6060
}

include/tvm/data_layout.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class LayoutNode : public Node {
104104
*/
105105
Array<IterVar> axes;
106106

107-
void VisitAttrs(AttrVisitor* v) final {
107+
void VisitAttrs(AttrVisitor* v) {
108108
v->Visit("name", &name);
109109
v->Visit("axes", &axes);
110110
}
@@ -325,7 +325,7 @@ class BijectiveLayoutNode : public Node {
325325
/*! \brief The destination layout */
326326
Layout dst_layout;
327327

328-
void VisitAttrs(AttrVisitor* v) final {
328+
void VisitAttrs(AttrVisitor* v) {
329329
v->Visit("src_layout", &src_layout);
330330
v->Visit("dst_layout", &dst_layout);
331331
v->Visit("forward_rule", &forward_rule);

include/tvm/expr.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
#include <string>
2828
#include <algorithm>
2929
#include <unordered_map>
30+
#include <iostream>
3031
#include "base.h"
3132
#include "dtype.h"
33+
#include "node/node.h"
3234
#include "node/container.h"
3335
#include "node/ir_functor.h"
3436
#include "runtime/c_runtime_api.h"
@@ -110,7 +112,7 @@ class Variable : public ExprNode {
110112

111113
static Var make(DataType dtype, std::string name_hint);
112114

113-
void VisitAttrs(AttrVisitor* v) final {
115+
void VisitAttrs(AttrVisitor* v) {
114116
v->Visit("dtype", &type);
115117
v->Visit("name", &name_hint);
116118
}
@@ -164,7 +166,7 @@ class IntImm : public ExprNode {
164166
/*! \brief the Internal value. */
165167
int64_t value;
166168

167-
void VisitAttrs(AttrVisitor* v) final {
169+
void VisitAttrs(AttrVisitor* v) {
168170
v->Visit("dtype", &type);
169171
v->Visit("value", &value);
170172
}
@@ -230,7 +232,7 @@ class RangeNode : public Node {
230232
RangeNode() {}
231233
RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
232234

233-
void VisitAttrs(AttrVisitor* v) final {
235+
void VisitAttrs(AttrVisitor* v) {
234236
v->Visit("min", &min);
235237
v->Visit("extent", &extent);
236238
}
@@ -406,7 +408,7 @@ class IterVarNode : public Node {
406408
*/
407409
std::string thread_tag;
408410

409-
void VisitAttrs(AttrVisitor* v) final {
411+
void VisitAttrs(AttrVisitor* v) {
410412
v->Visit("dom", &dom);
411413
v->Visit("var", &var);
412414
v->Visit("iter_type", &iter_type);
@@ -490,7 +492,7 @@ class IRPrinter {
490492
};
491493

492494
// default print function for all nodes
493-
inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
495+
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
494496
IRPrinter(os).Print(n);
495497
return os;
496498
}

0 commit comments

Comments
 (0)