diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index a2a3e96dd567..feab261e3076 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -38,9 +38,11 @@
 #include <optional>
 #include <unordered_map>
+#include <unordered_set>
 
 #include "../../relay/analysis/graph_partitioner.h"
 #include "../../support/arena.h"
+#include "../../support/ordered_set.h"
 #include "tvm/relax/expr.h"
 #include "utils.h"
@@ -360,6 +362,169 @@ class GraphCreator : public ExprVisitor {
   std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
 };
 
+class InferredCommonSubexpressionCollector : relax::ExprVisitor,
+                                             StructInfoVisitor,
+                                             tir::ExprVisitor {
+ public:
+  struct InferResult {
+    // A list of additional symbolic variables that must be provided
+    // to the function.  These variables cannot be inferred from the
+    // StructInfo of the existing parameters.
+    Array<tir::Var> symbolic_vars;
+
+    // A list of expressions, each of which must be remapped to a new
+    // symbolic variable.  These expressions can be inferred from the
+    // StructInfo of the existing parameters, but may contain
+    // sub-expressions that cannot.
+    Array<PrimExpr> symbolic_expressions;
+  };
+
+  static InferResult Infer(Array<Var> params, Expr body) {
+    InferredCommonSubexpressionCollector collector;
+    collector.VisitStructInfo(TupleStructInfo(params.Map(GetStructInfo)));
+    collector.phase_ = Phase::CollectRequiredExpressions;
+    collector.VisitExpr(body);
+
+    return InferResult{
+        Array<tir::Var>(collector.required_symbolic_vars_.begin(),
+                        collector.required_symbolic_vars_.end()),
+        Array<PrimExpr>(collector.required_symbolic_exprs_.begin(),
+                        collector.required_symbolic_exprs_.end()),
+    };
+  }
+
+ private:
+  using relax::ExprVisitor::VisitExpr;
+  using relax::ExprVisitor::VisitExpr_;
+  using tir::ExprVisitor::VisitExpr;
+  using tir::ExprVisitor::VisitExpr_;
+
+  void VisitExprDepStructInfoField(const StructInfo& struct_info) override {
+    VisitStructInfo(struct_info);
+  }
+  void VisitStructInfoExprField(const Expr& expr) override { VisitStructInfo(GetStructInfo(expr)); }
+  void VisitStructInfoExprField(const PrimExpr& expr) override {
+    if (expr->IsInstance<IntImmNode>()) {
+      return;
+    }
+
+    switch (phase_) {
+      case Phase::CollectInferableExpressions:
+        inferable_expressions_.insert(expr);
+        break;
+
+      case Phase::CollectRequiredExpressions:
+        VisitExpr(expr);
+        break;
+
+      default:
+        LOG(FATAL) << "Invalid value for Phase: " << static_cast<int>(phase_);
+        break;
+    }
+  }
+
+  void VisitExpr(const PrimExpr& expr) override {
+    if (inferable_expressions_.count(expr)) {
+      required_symbolic_exprs_.insert(expr);
+    } else {
+      tir::ExprVisitor::VisitExpr(expr);
+    }
+  }
+
+  void VisitExpr_(const tir::VarNode* op) override {
+    required_symbolic_vars_.push_back(GetRef<tir::Var>(op));
+  }
+
+  enum class Phase {
+    CollectInferableExpressions,
+    CollectRequiredExpressions,
+  };
+  Phase phase_ = Phase::CollectInferableExpressions;
+  std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> inferable_expressions_;
+  support::OrderedSet<tir::Var> required_symbolic_vars_;
+  support::OrderedSet<PrimExpr, StructuralHash, StructuralEqual> required_symbolic_exprs_;
+};
+
+/* \brief Replace occurrences of a PrimExpr within symbolic shape expressions
+ *
+ * In most cases, the `tvm::relax::Bind` utility should be used
+ * instead.  Here, though, we are replacing a `PrimExpr` with a
+ * `tir::Var`, whereas `tvm::relax::Bind` supports the more standard
+ * case of replacing a `tir::Var` with a `PrimExpr`.
+ */
+class SymbolicSubexprReplacer : relax::ExprMutator, StructInfoMutator, tir::ExprMutator {
+ public:
+  /* \brief Replace occurrences of a PrimExpr within symbolic shape expressions
+   *
+   * In most cases, the `tvm::relax::Bind` utility should be used
+   * instead.  Here, though, we are replacing a `PrimExpr` with a
+   * `tir::Var`, rather than the other way around.
+   *
+   * \param relax_expr The expression in which to replace symbolic expressions
+   *
+   * \param symbolic_exprs A list of expressions, each of which should
+   * be replaced with a new symbolic variable.  This is provided as a
+   * list, rather than as a replacement map, to allow context-dependent
+   * names to be generated for these expressions.
+   *
+   * \returns The updated relax expression.
+   */
+  static Expr Replace(const Expr& relax_expr, Array<PrimExpr> symbolic_exprs) {
+    std::unordered_map<PrimExpr, Optional<tir::Var>, StructuralHash, StructuralEqual> replacements;
+    for (const auto& expr : symbolic_exprs) {
+      replacements.insert({expr, NullOpt});
+    }
+
+    SymbolicSubexprReplacer mutator(replacements);
+    return mutator(relax_expr);
+  }
+
+ private:
+  using relax::ExprMutator::operator();
+  using relax::ExprMutator::VisitExpr;
+  using tir::ExprMutator::operator();
+  using tir::ExprMutator::VisitExpr;
+
+  SymbolicSubexprReplacer(
+      std::unordered_map<PrimExpr, Optional<tir::Var>, StructuralHash, StructuralEqual>
+          replacements)
+      : replacements_(replacements) {}
+
+  StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
+    return VisitStructInfo(struct_info);
+  }
+  Expr VisitStructInfoExprField(const Expr& expr) override { return VisitExpr(expr); }
+  PrimExpr VisitStructInfoExprField(const PrimExpr& expr) override { return VisitExpr(expr); }
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) override { return VisitExpr(expr); }
+
+  PrimExpr VisitExpr(const PrimExpr& expr) override {
+    if (auto replacement = GetReplacement(expr)) {
+      return replacement.value();
+    } else {
+      return tir::ExprMutator::VisitExpr(expr);
+    }
+  }
+
+  Optional<tir::Var> GetReplacement(const PrimExpr& expr) {
+    auto it = replacements_.find(expr);
+    if (it == replacements_.end()) {
+      return NullOpt;
+    }
+
+    Optional<tir::Var>& opt_var = it->second;
+    if (!opt_var.defined()) {
+      // Ideally, this path would never be reached, as it doesn't
+      // provide as much context in the variable name.  However, it's
+      // useful as a fallback.
+      opt_var = tir::Var("fused_expr", expr->dtype);
+    }
+
+    return opt_var.value();
+  }
+
+  std::unordered_map<PrimExpr, Optional<tir::Var>, StructuralHash, StructuralEqual> replacements_;
+};
+
 /*!
  * \brief The ExprMutator used to create a new grouped function
  * \details The workflow of this ExprMutator is:
@@ -533,25 +698,44 @@ class FunctionCreator : public ExprMutator {
       function_ = NullOpt;
     } else {
       Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
+      body = SeqExpr({new_block}, body);
       body = builder_->Normalize(body);
-      body = builder_->Normalize(SeqExpr({new_block}, body));
+
+      // Any symbolic variables that are required within the body of
+      // the function, but cannot be inferred from the parameters of
+      // the function, must be exposed using an additional argument.
+      auto [symbolic_vars, symbolic_expressions] =
+          InferredCommonSubexpressionCollector::Infer(params_, body);
+      if (symbolic_vars.size()) {
+        auto symbolic_vars_as_expr =
+            symbolic_vars.Map([](tir::Var var) -> PrimExpr { return var; });
+        params_.push_back(Var("tir_vars", ShapeStructInfo(symbolic_vars_as_expr)));
+        arguments_.push_back(ShapeExpr(symbolic_vars_as_expr));
+      }
+
       group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
       Function function = Function(/*params=*/params_,            //
                                    /*body=*/body,                 //
                                    /*ret_struct_info=*/NullOpt,   //
                                    /*is_pure=*/true,              //
                                    /*attrs=*/DictAttrs(group_attrs));
-      Array<PrimExpr> free_vars =
-          FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; });
-      if (!free_vars.empty()) {
-        params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars)));
-        arguments_.push_back(ShapeExpr(free_vars));
-        function = Function(/*params=*/params_,            //
-                            /*body=*/body,                 //
-                            /*ret_struct_info=*/NullOpt,   //
-                            /*is_pure=*/true,              //
-                            /*attrs=*/DictAttrs(group_attrs));
+
+      // If the function contains symbolic expressions that can be
+      // inferred from the parameters, but that contain subexpressions
+      // that cannot be inferred from the parameters, those expressions
+      // should be replaced with symbolic variables.
+      //
+      // For example, suppose a fused function maps a tensor of shape
+      // `[batch_size+1, 1024]` to a tensor of shape `[batch_size+1,
+      // 1024]`.  It cannot infer `batch_size`, but can infer the
+      // value of `batch_size+1`.  By introducing `batch_size_plus_one
+      // = batch_size+1`, we can rely on just the inferable symbolic
+      // vars.
+      if (symbolic_expressions.size()) {
+        function =
+            Downcast<Function>(SymbolicSubexprReplacer::Replace(function, symbolic_expressions));
       }
+
       function_ = SymbolicVarRenewMutator::Renew(function);
     }
   }
diff --git a/src/support/ordered_set.h b/src/support/ordered_set.h
index 741f0b18e6b9..50536f2310e6 100644
--- a/src/support/ordered_set.h
+++ b/src/support/ordered_set.h
@@ -26,6 +26,7 @@
 #include <tvm/runtime/object.h>
 
+#include <functional>
 #include <list>
 #include <unordered_map>
 
@@ -39,17 +40,31 @@ namespace detail {
  */
 template <typename T, typename = void>
 struct OrderedSetLookupType {
-  using MapType = std::unordered_map<T, typename std::list<T>::iterator>;
+  using Hash = std::hash<T>;
+  using Equal = std::equal_to<T>;
 };
 
 template <typename T>
 struct OrderedSetLookupType<T, std::enable_if_t<std::is_base_of_v<runtime::ObjectRef, T>>> {
-  using MapType = std::unordered_map<T, typename std::list<T>::iterator, runtime::ObjectPtrHash,
-                                     runtime::ObjectPtrEqual>;
+  using Hash = runtime::ObjectPtrHash;
+  using Equal = runtime::ObjectPtrEqual;
 };
 }  // namespace detail
 
-template <typename T>
+/* \brief Utility to hold an ordered set
+ *
+ * \tparam T The type held by the OrderedSet
+ *
+ * \tparam LookupHash The hash implementation to use for detecting
+ * duplicate entries.  If unspecified, defaults to `ObjectPtrHash`
+ * for TVM types, and `std::hash<T>` otherwise.
+ *
+ * \tparam LookupEqual The equality-checker to use for detecting
+ * duplicate entries.  If unspecified, defaults to `ObjectPtrEqual`
+ * for TVM types, and `std::equal_to<T>` otherwise.
+ */
+template <typename T, typename LookupHash = typename detail::OrderedSetLookupType<T>::Hash,
+          typename LookupEqual = typename detail::OrderedSetLookupType<T>::Equal>
 class OrderedSet {
  public:
   OrderedSet() = default;
@@ -91,7 +106,7 @@ class OrderedSet {
 
  private:
   std::list<T> elements_;
-  typename detail::OrderedSetLookupType<T>::MapType elem_to_iter_;
+  std::unordered_map<T, typename std::list<T>::iterator, LookupHash, LookupEqual> elem_to_iter_;
 };
 
 }  // namespace support
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 5e700b277f32..b331570d622f 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1217,5 +1217,82 @@ def inner_func(
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_matmul_symbolic_expr():
+    """Like `test_matmul_symbolic_var`, but with a PrimExpr shape
+
+    The shape of the weights used in the matmul is `[1024, M + 1024]`,
+    which can result from `CombineParallelMatmul`.  If the fused
+    function were written in terms of `M`, then `M` would need to be
+    provided as an additional `ShapeExpr`, as it cannot be inferred
+    from the tensor shapes.  This can cause issues for downstream
+    passes, as CodeGenJSON, used by TVM's runtime for cublas and
+    cutlass, only supports `R.Tensor` and tuples of `R.Tensor`.
+
+    If a symbolic variable is used only within expressions that are
+    themselves inferable from the tensor shapes, the fused function
+    can instead be written in terms of those expressions, removing
+    the need for the `ShapeExpr`.  Here, the expression `M + 1024`
+    is replaced by the variable `w2_size`.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1024], dtype="float16"),
+            w1: R.Tensor([1024, 1024], dtype="float16"),
+            w2: R.Tensor([1024, "M"], dtype="float16"),
+        ) -> R.Tensor(["batch_size", "M + 1024"], "float16"):
+            with R.dataflow():
+                concat = R.concat([w1, w2], axis=1)
+                out = R.matmul(x, concat)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1024], dtype="float16"),
+            w1: R.Tensor([1024, 1024], dtype="float16"),
+            w2: R.Tensor([1024, "M"], dtype="float16"),
+        ) -> R.Tensor(["batch_size", "M + 1024"], "float16"):
+            cls = Expected
+            with R.dataflow():
+                concat = R.concat([w1, w2], axis=1)
+                out = cls.fused_relax_matmul_cublas(x, concat)
+                R.output(out)
+            return out
+
+        @R.function
+        def fused_relax_matmul_cublas(
+            x: R.Tensor(["batch_size", 1024], dtype="float16"),
+            w2: R.Tensor([1024, "w2_size"], dtype="float16"),
+        ) -> R.Tensor(["batch_size", "w2_size"], dtype="float16"):
+            batch_size = T.int64()
+            w2_size = T.int64()
+            R.func_attr({"Codegen": "cublas"})
+
+            @R.function
+            def inner_func(
+                x: R.Tensor([batch_size, 1024], dtype="float16"),
+                w2: R.Tensor((1024, w2_size), dtype="float16"),
+            ) -> R.Tensor([batch_size, w2_size], dtype="float16"):
+                R.func_attr({"Composite": "cublas.matmul"})
+                with R.dataflow():
+                    out = R.matmul(x, w2)
+                    R.output(out)
+                return out
+
+            out = inner_func(x, w2)
+            return out
+
+    patterns = relax.backend.pattern_registry.get_patterns_with_prefix("cublas.matmul")
+    After = relax.transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(
+        Before
+    )
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
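
Note (illustrative, not part of the patch): the two-phase classification in `InferredCommonSubexpressionCollector` is compact but subtle, so the standalone Python sketch below models the idea using plain `tvm.tir` expressions. The helper name `classify` and the list-based bookkeeping are hypothetical simplifications; the C++ walks full StructInfo trees and uses `OrderedSet`, and this sketch only handles the binary-op shapes exercised by `test_matmul_symbolic_expr`.

import tvm
from tvm import tir


def classify(param_dims, body_dims):
    """Split the body's shape dimensions into two groups.

    Returns (required_vars, required_exprs), mirroring InferResult:
    required_vars cannot be inferred from any parameter and would need
    an extra "tir_vars" ShapeExpr argument, while required_exprs are
    inferable from argument shapes and are rebound to fresh variables.
    """
    # Phase 1 (CollectInferableExpressions): any non-constant dimension
    # of a parameter can be read off from the corresponding argument's
    # runtime shape, so it is "inferable".
    inferable = [d for d in param_dims if not isinstance(d, tir.IntImm)]

    def contains(pool, expr):
        return any(tvm.ir.structural_equal(expr, e) for e in pool)

    required_vars, required_exprs = [], []

    # Phase 2 (CollectRequiredExpressions): walk each dimension used in
    # the body.  An inferable (sub)expression is recorded for rebinding;
    # otherwise the walk descends until it reaches free variables.
    def visit(expr):
        if isinstance(expr, tir.IntImm):
            return  # constants never need to be provided
        if contains(inferable, expr):
            if not contains(required_exprs, expr):
                required_exprs.append(expr)
        elif isinstance(expr, tir.Var):
            if not contains(required_vars, expr):
                required_vars.append(expr)
        else:
            visit(expr.a)  # binary ops are enough for this sketch
            visit(expr.b)

    for dim in body_dims:
        visit(dim)
    return required_vars, required_exprs


batch_size = tir.Var("batch_size", "int64")
M = tir.Var("M", "int64")

# Mirrors the test case: the fused function's parameters have shapes
# [batch_size, 1024] and [1024, M + 1024], and its output has shape
# [batch_size, M + 1024].
required_vars, required_exprs = classify(
    param_dims=[batch_size, M + 1024],
    body_dims=[batch_size, M + 1024],
)
assert not required_vars  # no extra ShapeExpr argument is needed
# `M + 1024` is flagged for rebinding to a fresh variable (`w2_size` in
# the test above); bare vars like batch_size may also be recorded, and
# rebinding them to fresh variables is harmless.
assert any(tvm.ir.structural_equal(e, M + 1024) for e in required_exprs)

# Had the body needed `M` itself, no argument shape would provide it,
# and it would land in required_vars (the "tir_vars" fallback).
required_vars, _ = classify([batch_size, M + 1024], [M])
assert any(v.same_as(M) for v in required_vars)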
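
Also illustrative only: the ordered_set.h generalization appears to exist so that `required_symbolic_exprs_` can deduplicate `PrimExpr`s by structure rather than by pointer identity. A quick demonstration of why the pointer-based defaults are not enough:

import tvm
from tvm import tir

M = tir.Var("M", "int64")
e1 = M + 1024
e2 = M + 1024

# Distinct objects: ObjectPtrHash/ObjectPtrEqual (the previous
# OrderedSet default for TVM types) would keep both copies as
# separate entries.
assert not e1.same_as(e2)

# Structurally identical: supplying StructuralHash/StructuralEqual via
# the new LookupHash/LookupEqual parameters collapses them into one.
assert tvm.ir.structural_equal(e1, e2)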