Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ class Var : public LeafExpr {

TVM_DLL explicit Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);

VarNode* CopyOnWrite();
};

/*! \brief A sub-type of the variable node used to mark dataflow variables from
Expand Down Expand Up @@ -784,10 +785,10 @@ class BindingBlock : public ObjectRef {
public:
TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode);

BindingBlockNode* CopyOnWrite();
};

class DataflowBlock;
class DataflowBlockNode : public BindingBlockNode {
public:
bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const {
Expand Down
20 changes: 12 additions & 8 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -823,14 +823,18 @@ struct ObjectPtrEqual {
*
* \endcode
*/
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
ObjectName* CopyOnWrite() { \
ICHECK(data_ != nullptr); \
if (!data_.unique()) { \
auto n = make_object<ObjectName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<ObjectName*>(data_.get()); \
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
static_assert(ObjectName::_type_final, \
"TVM's CopyOnWrite may only be used for " \
"Object types that are declared as final, " \
"using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \
ObjectName* CopyOnWrite() { \
ICHECK(data_ != nullptr); \
if (!data_.unique()) { \
auto n = make_object<ObjectName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<ObjectName*>(data_.get()); \
}

// Implementations details below
Expand Down
38 changes: 38 additions & 0 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,25 @@ Var::Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span) {
data_ = std::move(n);
}

VarNode* Var::CopyOnWrite() {
// The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for
// Var, because it is the base class for `DataflowBlock`.
// If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the
// automatic implementation would erroneously convert from a
// `DataflowBlock` to a `Var`.
ICHECK(data_ != nullptr);
if (!data_.unique()) {
ObjectPtr<VarNode> node;
if (auto dataflow_var = as<DataflowVarNode>()) {
node = make_object<DataflowVarNode>(*dataflow_var);
} else {
node = make_object<VarNode>(*(operator->()));
}
ObjectPtr<Object>(std::move(node)).swap(data_);
}
return static_cast<VarNode*>(data_.get());
}

TVM_REGISTER_GLOBAL("relax.Var")
.set_body_typed([](String name_hint, Optional<StructInfo> struct_info_annotation, Span span) {
return Var(name_hint, struct_info_annotation, span);
Expand Down Expand Up @@ -473,6 +492,25 @@ BindingBlock::BindingBlock(Array<Binding> bindings, Span span) {
data_ = std::move(n);
}

BindingBlockNode* BindingBlock::CopyOnWrite() {
// The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for
// BindingBlock, because it is the base class for `DataflowBlock`.
// If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the
// automatic implementation would erroneously convert from a
// `DataflowBlock` to a `BindingBlock`.
ICHECK(data_ != nullptr);
if (!data_.unique()) {
ObjectPtr<BindingBlockNode> node;
if (auto dataflow_block = as<DataflowBlockNode>()) {
node = make_object<DataflowBlockNode>(*dataflow_block);
} else {
node = make_object<BindingBlockNode>(*(operator->()));
}
ObjectPtr<Object>(std::move(node)).swap(data_);
}
return static_cast<BindingBlockNode*>(data_.get());
}

TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array<Binding> bindings, Span span) {
return BindingBlock(bindings, span);
});
Expand Down