diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index af2db582d604..8620ad80bda7 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -33,6 +33,7 @@ #include #include +#include #include #include @@ -115,7 +116,7 @@ namespace relax { * use this class or logic of a similar kind. */ template -class NestedMsg : public ObjectRef { +class NestedMsg { public: // default constructors. NestedMsg() = default; @@ -123,12 +124,6 @@ class NestedMsg : public ObjectRef { NestedMsg(NestedMsg&&) = default; NestedMsg& operator=(const NestedMsg&) = default; NestedMsg& operator=(NestedMsg&&) = default; - /*! - * \brief Construct from an ObjectPtr - * whose type already satisfies the constraint - * \param ptr - */ - explicit NestedMsg(ObjectPtr ptr) : ObjectRef(ptr) {} /*! \brief Nullopt handling */ NestedMsg(std::nullopt_t) {} // NOLINT(*) // nullptr handling. @@ -140,16 +135,17 @@ class NestedMsg : public ObjectRef { } // normal value handling. NestedMsg(T other) // NOLINT(*) - : ObjectRef(std::move(other)) {} + : data_(std::move(other)) {} NestedMsg& operator=(T other) { - ObjectRef::operator=(std::move(other)); + data_ = std::move(other); return *this; } // Array> handling NestedMsg(Array, void> other) // NOLINT(*) - : ObjectRef(std::move(other)) {} + : data_(other) {} + NestedMsg& operator=(Array, void> other) { - ObjectRef::operator=(std::move(other)); + data_ = std::move(other); return *this; } @@ -170,13 +166,16 @@ class NestedMsg : public ObjectRef { bool operator!=(std::nullptr_t) const { return data_ != nullptr; } /*! \return Whether the nested message is not-null leaf value */ - bool IsLeaf() const { return data_ != nullptr && data_->IsInstance(); } + bool IsLeaf() const { + return data_.type_index() != ffi::TypeIndex::kTVMFFINone && + data_.type_index() != ffi::TypeIndex::kTVMFFIArray; + } /*! \return Whether the nested message is null */ - bool IsNull() const { return data_ == nullptr; } + bool IsNull() const { return data_.type_index() == ffi::TypeIndex::kTVMFFINone; } /*! \return Whether the nested message is nested */ - bool IsNested() const { return data_ != nullptr && data_->IsInstance(); } + bool IsNested() const { return data_.type_index() == ffi::TypeIndex::kTVMFFIArray; } /*! * \return The underlying leaf value. @@ -184,7 +183,7 @@ class NestedMsg : public ObjectRef { */ T LeafValue() const { ICHECK(IsLeaf()); - return T(data_); + return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck(data_); } /*! @@ -192,16 +191,15 @@ class NestedMsg : public ObjectRef { * \note This checks if the underlying data type is array. */ Array, void> NestedArray() const { - ICHECK(IsNested()); - return Array, void>(data_); + return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck, void>>(data_); } - using ContainerType = Object; - using LeafContainerType = typename T::ContainerType; - - static_assert(std::is_base_of::value, "NestedMsg is only defined for ObjectRef."); - - static constexpr bool _type_is_nullable = true; + private: + ffi::Any data_; + // private constructor + explicit NestedMsg(ffi::Any data) : data_(data) {} + template + friend struct ffi::TypeTraits; }; /*! @@ -598,5 +596,83 @@ StructInfo TransformTupleLeaf(StructInfo sinfo, std::array, N> msgs } } // namespace relax + +namespace ffi { + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + TVM_FFI_INLINE static void CopyToAnyView(const relax::NestedMsg& src, TVMFFIAny* result) { + *result = ffi::AnyView(src.data_).CopyToTVMFFIAny(); + } + + TVM_FFI_INLINE static void MoveToAny(relax::NestedMsg src, TVMFFIAny* result) { + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_)); + } + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + + static bool CheckAnyStrict(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return true; + } + if (TypeTraits::CheckAnyStrict(src)) { + return true; + } + if (src->type_index == TypeIndex::kTVMFFIArray) { + const ffi::ArrayObj* array = reinterpret_cast(src->v_obj); + for (size_t i = 0; i < array->size(); ++i) { + const Any& any_v = (*array)[i]; + if (!details::AnyUnsafe::CheckAnyStrict>(any_v)) return false; + } + } + return true; + } + + TVM_FFI_INLINE static relax::NestedMsg CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return relax::NestedMsg(Any(AnyView::CopyFromTVMFFIAny(*src))); + } + + TVM_FFI_INLINE static relax::NestedMsg MoveFromAnyAfterCheck(TVMFFIAny* src) { + return relax::NestedMsg(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src))); + } + + static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { + if (CheckAnyStrict(src)) { + return CopyFromAnyViewAfterCheck(src); + } + // slow path run conversion + if (src->type_index == TypeIndex::kTVMFFINone) { + return relax::NestedMsg(std::nullopt); + } + if (auto opt_value = TypeTraits::TryCastFromAnyView(src)) { + return relax::NestedMsg(*std::move(opt_value)); + } + if (src->type_index == TypeIndex::kTVMFFIArray) { + const ArrayObj* n = reinterpret_cast(src->v_obj); + Array> result; + result.reserve(n->size()); + for (size_t i = 0; i < n->size(); i++) { + const Any& any_v = (*n)[i]; + if (auto opt_v = any_v.try_cast>()) { + result.push_back(*std::move(opt_v)); + } else { + return std::nullopt; + } + } + return relax::NestedMsg(result); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { + return "NestedMsg<" + details::Type2Str::v() + ">"; + } +}; +} // namespace ffi } // namespace tvm #endif // TVM_RELAX_NESTED_MSG_H_ diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index d552dae8f754..644a80664fe1 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -53,7 +53,7 @@ TEST(NestedMsg, Basic) { EXPECT_ANY_THROW(msg.LeafValue()); auto arr = msg.NestedArray(); - EXPECT_TRUE(arr[0].same_as(x)); + EXPECT_TRUE(arr[0].LeafValue().same_as(x)); EXPECT_TRUE(arr[1] == nullptr); EXPECT_TRUE(arr[1].IsNull()); @@ -72,13 +72,24 @@ TEST(NestedMsg, Basic) { EXPECT_TRUE(a0.IsNested()); auto t0 = a0.NestedArray()[1]; EXPECT_TRUE(t0.IsNested()); - EXPECT_TRUE(t0.NestedArray()[2].same_as(y)); + EXPECT_TRUE(t0.NestedArray()[2].LeafValue().same_as(y)); // assign leaf a0 = x; EXPECT_TRUE(a0.IsLeaf()); - EXPECT_TRUE(a0.same_as(x)); + EXPECT_TRUE(a0.LeafValue().same_as(x)); +} + +TEST(NestedMsg, IntAndAny) { + NestedMsg msg({1, std::nullopt, 2}); + Any any_msg = msg; + NestedMsg msg2 = any_msg.cast>(); + + EXPECT_TRUE(msg2.IsNested()); + EXPECT_EQ(msg2.NestedArray()[0].LeafValue(), 1); + EXPECT_TRUE(msg2.NestedArray()[1].IsNull()); + EXPECT_EQ(msg2.NestedArray()[2].LeafValue(), 2); } TEST(NestedMsg, ForEachLeaf) { @@ -174,13 +185,13 @@ TEST(NestedMsg, MapAndDecompose) { DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg msg) { if (value.same_as(x)) { - EXPECT_TRUE(msg.same_as(c0)); + EXPECT_TRUE(msg.LeafValue().same_as(c0)); ++x_count; } else if (value.same_as(y)) { - EXPECT_TRUE(msg.same_as(c1)); + EXPECT_TRUE(msg.LeafValue().same_as(c1)); ++y_count; } else { - EXPECT_TRUE(msg.same_as(c2)); + EXPECT_TRUE(msg.LeafValue().same_as(c2)); ++z_count; } });