Skip to content

Commit d3d7213

Browse files
committed
[Relax][Transform] Compose preproc functions in LiftTransformParams
The `LiftTransformParams` pass produces additional functions, either named `$FOO_transform_params` when generating one transformation function per inference function, or `transform_params` when generating a single shared transformation function. Prior to this commit, if the `IRModule` already contained a function with that name, an error would be raised. After this commit, the `LiftTransformParams` pass will instead check for existing functions, and compose the previous transformation function with the newly-lifted transformation. This allows `LiftTransformParams` to be used alongside a hand-written parameter transformation. Closes #17200
1 parent 98de9ba commit d3d7213

File tree

4 files changed

+187
-52
lines changed

4 files changed

+187
-52
lines changed

src/relax/transform/lift_transform_params.cc

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,10 @@ struct BaseCollectInfo {
119119
Function func(params, body, GetStructInfo(tuple_var));
120120
func = WithAttr(func, attr::kNumInput, Integer(0));
121121
func = CopyWithNewVars(func);
122+
func = BundleModelParams(func);
122123
func = Downcast<Function>(CanonicalizeBindings(func));
124+
func = Downcast<Function>(RemoveAllUnused(func));
125+
123126
return func;
124127
}
125128
};
@@ -725,11 +728,12 @@ std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
725728
target_functions.push_back({gvar.value(), func.value()});
726729
}
727730
} else {
728-
// Get all the functions that have the `num_input` attribute.
731+
// Get all the functions that have the `num_input` attribute, and
732+
// are not already the result of `LiftTransformParams`.
729733
for (const auto& [gvar, func] : mod->functions) {
730734
if (func->IsInstance<FunctionNode>()) {
731735
auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput);
732-
if (opt_num_input) {
736+
if (opt_num_input && !ends_with(gvar->name_hint, "transform_params")) {
733737
target_functions.emplace_back(gvar, Downcast<Function>(func));
734738
}
735739
}
@@ -748,7 +752,6 @@ namespace transform {
748752

749753
Pass PartitionTransformParams(Variant<Bool, Array<String>> shared_transform) {
750754
auto pass_func = [=](IRModule mod, PassContext pc) {
751-
IRModule updates;
752755
std::optional<GlobalCollectInfo> global_collect_info;
753756

754757
CHECK(shared_transform.defined()) << "shared_transform is not defined";
@@ -772,24 +775,41 @@ Pass PartitionTransformParams(Variant<Bool, Array<String>> shared_transform) {
772775
local_collect_info[gvar] = info;
773776
}
774777

778+
IRModule updated_runtime_functions;
779+
775780
for (const auto& [gvar, info] : local_collect_info) {
776781
auto new_runtime_func = info.MakeRuntimeFunction();
777-
updates->Add(gvar, new_runtime_func);
782+
updated_runtime_functions->Add(gvar, new_runtime_func);
778783
}
779784

785+
Map<String, Function> lifted_transform_functions;
780786
if (global_collect_info.has_value()) {
781787
auto global_transform = global_collect_info.value().MakeCompileTimeFunc();
782-
updates->Add(GlobalVar("transform_params"), global_transform);
788+
lifted_transform_functions.Set("transform_params", global_transform);
783789
} else {
784790
for (const auto& [gvar, info] : local_collect_info) {
785791
// transform_params is emitted for each function if global lifting is not enabled
786-
updates->Add(GlobalVar(gvar->name_hint + "_transform_params"),
787-
info.MakeCompileTimeFunction());
792+
lifted_transform_functions.Set(gvar->name_hint + "_transform_params",
793+
info.MakeCompileTimeFunction());
788794
}
789795
}
790796

791-
if (updates->functions.size()) {
792-
mod.CopyOnWrite()->Update(updates);
797+
if (updated_runtime_functions->functions.size() || lifted_transform_functions.size()) {
798+
auto write_ptr = mod.CopyOnWrite();
799+
write_ptr->Update(updated_runtime_functions);
800+
801+
for (auto [name, transform] : lifted_transform_functions) {
802+
if (auto opt = write_ptr->global_var_map_.Get(name)) {
803+
auto old_gvar = opt.value();
804+
auto old_transform = Downcast<Function>(write_ptr->Lookup(old_gvar));
805+
write_ptr->Remove(old_gvar);
806+
807+
transform = ComposeFunctions(old_transform, transform);
808+
}
809+
GlobalVar new_gvar(name);
810+
UpdateStructInfo(new_gvar, GetStructInfo(transform));
811+
write_ptr->Add(new_gvar, transform);
812+
}
793813
}
794814

795815
return mod;
@@ -817,7 +837,6 @@ Pass LiftTransformParams(Variant<Bool, Array<String>> shared_transform) {
817837
std::string func_name = gvar->name_hint;
818838
if (ends_with(func_name, "transform_params")) {
819839
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
820-
func = BundleModelParams(func);
821840
if (pc->GetConfig<Bool>(kLiftTransformConsumeParams).value_or(Bool(false))) {
822841
func = Downcast<Function>(ConsumeBundledParams()(func));
823842
}

src/relax/transform/utils.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
#include "utils.h"
2121

22+
#include <tvm/relax/analysis.h>
23+
2224
namespace tvm {
2325
namespace relax {
2426

@@ -41,5 +43,54 @@ bool IsNestedTensor(const StructInfo& sinfo) {
4143

4244
bool IsNestedTensor(const Expr& expr) { return IsNestedTensor(GetStructInfo(expr)); }
4345

46+
Function ComposeFunctions(Function func_a, Function func_b) {
47+
Array<Binding> bindings;
48+
49+
Var func_a_output("func_a_output", func_a->ret_struct_info);
50+
51+
bindings.push_back(VarBinding(func_a_output, func_a->body));
52+
53+
auto func_a_outputs = [&]() -> Array<Expr> {
54+
if (auto func_a_output_tuple = func_a->ret_struct_info.as<TupleStructInfoNode>()) {
55+
Array<Expr> outputs;
56+
for (size_t i = 0; i < func_a_output_tuple->fields.size(); i++) {
57+
outputs.push_back(TupleGetItem(func_a_output, i));
58+
}
59+
return outputs;
60+
} else {
61+
return {func_a_output};
62+
}
63+
}();
64+
65+
if (func_b->params.size() == 1 && func_b->params[0]->struct_info_.as<TupleStructInfoNode>()) {
66+
// Special case where the output of the first function is a tuple
67+
// that should be provided as-is to the second function, and
68+
// should not be unpacked into individual elements.
69+
auto param = func_b->params[0];
70+
bindings.push_back(MatchCast(param, func_a_output, GetStructInfo(param)));
71+
} else {
72+
CHECK_EQ(func_a_outputs.size(), func_b->params.size())
73+
<< "ValueError: "
74+
<< "Cannot compose functions together. "
75+
<< "First function produces " << func_a_outputs.size() << " values, "
76+
<< "but second function expects " << func_b->params.size() << " parameters as input";
77+
for (size_t i = 0; i < func_a_outputs.size(); i++) {
78+
auto param = func_b->params[i];
79+
bindings.push_back(MatchCast(param, func_a_outputs[i], GetStructInfo(param)));
80+
}
81+
}
82+
83+
auto new_body = SeqExpr({BindingBlock(bindings)}, func_b->body);
84+
85+
auto new_function = Function(func_a->params, new_body, func_b->ret_struct_info,
86+
func_a->is_pure && func_b->is_pure, func_a->attrs);
87+
88+
new_function = CopyWithNewVars(new_function);
89+
new_function = Downcast<Function>(CanonicalizeBindings(new_function));
90+
new_function = Downcast<Function>(RemoveAllUnused(new_function));
91+
92+
return new_function;
93+
}
94+
4495
} // namespace relax
4596
} // namespace tvm

src/relax/transform/utils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,20 @@ Expr CanonicalizeBindings(Expr expr);
437437
*/
438438
Function BundleModelParams(const Function& func, Optional<String> param_tuple_name = NullOpt);
439439

440+
/*! \brief Compose two functions
441+
*
442+
* Given two functions `func_a` and `func_b`, produce `func_c` such
443+
* that `func_c(x)` is equivalent to `func_b(func_a(x))`.
444+
*
445+
* If the output if `func_a` is not usable as the input of `func_b`,
446+
* an error will be raised.
447+
*
448+
* \param func_a The first function to be composed.
449+
* \param func_b The second function to be composed.
450+
* \return The composed function
451+
*/
452+
TVM_DLL Function ComposeFunctions(Function func_a, Function func_b);
453+
440454
} // namespace relax
441455
} // namespace tvm
442456

0 commit comments

Comments
 (0)