Commit 89c0235

[Relay] Plumb external codegen target via Target.current() (#11432)
* [Relay] Plumb external codegen target via Target.current() for all external codegen paths

(See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md).)

We want both old-style (via relay.ext.$toolchain) and new-style (via "RelayToTIR" Pass attribute on target kind) external codegen to be able to access the current 'external codegen' Target instance via Target.current().
- For old-style, plumb the true Target through TECompiler and push it on the context stack before calling relay.ext.$toolchain.
- For new-style, pass the CompilationConfig to the RelayToTIRTargetHook pass, make the jump from "Compiler" attribute value to Target via the new CompilationConfig::FindPrimitiveTargetForKind method, and push on the stack before invoking the custom "RelayToTIR" pass.

While working on this, discovered that RelayToTIRTargetHook was incompatible with the VM's compilation flow, since RelayToTIRTargetHook assumes all "Compiler" attributed functions are inlined. Generalize it to support both inline and global function styles.

Extend Target::IsExternalCodegen to recognize target kinds with "RelayToTIR" attributes as external.

Update target hooks unit test to exercise new support for outline-style, picking up the current target, and compiling via the VM.

* - A bit of polishing en passant.

* - Add comment as per Josh's suggestion

Can't repro tests/python/contrib/test_ethosu/cascader/test_scheduler.py::test_compute_cycles_annotation failure, flake?
1 parent 62e449c commit 89c0235
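
The commit message describes pushing the external codegen Target onto a context stack so that the codegen can read it back via Target.current(). The following is a minimal, TVM-free sketch of that general pattern. `FakeTarget`, `TargetScope`, `CurrentTarget` and `RunCodegen` are hypothetical stand-ins invented for illustration; they are not TVM APIs and make no claim about how TVM implements its stack.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a target: only carries a kind name.
struct FakeTarget {
  std::string kind;
};

// Thread-local stack of "current" targets (assumed shape, for illustration only).
inline std::vector<FakeTarget>& TargetStack() {
  thread_local std::vector<FakeTarget> stack;
  return stack;
}

// RAII guard: pushes a target on construction and pops it on destruction.
class TargetScope {
 public:
  explicit TargetScope(FakeTarget target) { TargetStack().push_back(std::move(target)); }
  ~TargetScope() {
    assert(!TargetStack().empty());
    TargetStack().pop_back();
  }
  TargetScope(const TargetScope&) = delete;
  TargetScope& operator=(const TargetScope&) = delete;
};

// Analogue of "give me the current target": the innermost pushed target, if any.
inline std::optional<FakeTarget> CurrentTarget() {
  if (TargetStack().empty()) return std::nullopt;
  return TargetStack().back();
}

// Toy external codegen that consults the current target rather than taking it as an argument.
inline std::string RunCodegen() {
  std::optional<FakeTarget> target = CurrentTarget();
  return target.has_value() ? "codegen for " + target->kind : "codegen without target";
}
```

In this sketch the caller opens a `TargetScope` just before invoking the codegen, so the callee's signature does not need to change to receive the target.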

24 files changed: +512 -162 lines changed

include/tvm/relay/transform.h

Lines changed: 41 additions & 2 deletions
@@ -462,11 +462,50 @@ TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
 TVM_DLL Pass SimplifyExpr();

 /*!
- * \brief Run any registered RelayToTIR passes registered on the functions in a module.
+ * \brief Run any custom passes registered under "RelayToTIR" attributes on TargetKinds.
+ *
+ * This pass looks for inline, let-bound or global functions which have a "Compiler" attribute.
+ * If the attribute value corresponds to a TargetKind with a "RelayToTIR" attribute, then the
+ * 'custom' pass bound to that attribute is run (at most once) on the IRModule as a whole.
+ *
+ * If, in addition, the \p config has a Target with a matching TargetKind, that Target is set
+ * as the 'current' target before the custom pass is executed. In this way it is possible
+ * for custom passes to pick up target options which may guide how they transform the IRModule.
+ * (Those targets are referred to as 'extern codegen targets' elsewhere).
+ *
+ * A typical custom pass will:
+ *  - Find calls to "Compiler" attributed functions with matching compiler name.
+ *  - Lower those functions to TIR PrimFuncs.
+ *  - Bind those functions into the IRModule under the functions' "global_symbol" attribute.
+ *  - Replace all calls to those functions with 'call_lowered' to the matching global.
+ * Care should be taken to handle multiple calls to the same function.
+ * See src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc for an example custom pass.
+ *
+ * It is also possible (despite the pass and attribute names!) for the custom pass to proceed
+ * directly to a runtime::Module, which can be attached to the output IRModule's "external_mods"
+ * attribute (taking care not to clobber any existing modules). In this case the flow is as above,
+ * except:
+ *  - The runtime::Module must contain a binding for each compiled function under their
+ *    "global_symbol" (ie runtime::Module::ImplementsFunction should return true).
+ *  - A Relay Function must be bound (or re-bound) into the result IRModule, again with the same
+ *    "global_symbol", but with only the "Extern" attribute set to Integer(1). The function body
+ *    should be the original function body. In this way we always have a TVM definition matching
+ *    every global function name.
+ *
+ * There are many existing runtime::Modules, ranging from source to object to dynamic libraries to
+ * entirely custom implementations. Some of those may require additional compilation using
+ * 'export_library' on the final build artifact.
+ *
+ * The OutlineCompilerFunctionsWithExistingGlobalSymbols and MarkCompilerFunctionsAsExtern utility
+ * passes can be used by custom passes to take care of some of the boilerplate.
+ *
+ * TODO(mbs): Rename PreLoweringTargetHooks?
+ *
+ * \param config All available targets.
  *
  * \return The pass.
  */
-TVM_DLL Pass RelayToTIRTargetHook();
+TVM_DLL Pass RelayToTIRTargetHook(CompilationConfig config);

 /*!
  * \brief A pass for manifesting explicit memory allocations and rewriting
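
The doc comment in this hunk says the custom pass is run "at most once" per compiler name, with a matching Target from the config made current first. A rough, TVM-free model of that dispatch could look like the sketch below. Every type and function here (`ToyFunction`, `ToyModule`, `ToyTarget`, `ToyPass`, `RunTargetHooks`) is hypothetical, and passing the target as an argument is a simplification of setting it as the 'current' target.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical function record: 'compiler' is empty when there is no "Compiler" attribute.
struct ToyFunction {
  std::string name;
  std::string compiler;
};

struct ToyModule {
  std::vector<ToyFunction> functions;
  std::vector<std::string> log;  // records which custom passes ran
};

struct ToyTarget {
  std::string kind;
  int option;
};

// A custom pass sees the whole module plus the matching target, if the config has one.
using ToyPass = std::function<void(ToyModule&, const std::optional<ToyTarget>&)>;

void RunTargetHooks(ToyModule& mod, const std::map<std::string, ToyPass>& passes,
                    const std::vector<ToyTarget>& config) {
  // Snapshot the compiler names first, since a pass is free to mutate the module.
  std::vector<std::string> compilers;
  for (const ToyFunction& function : mod.functions) compilers.push_back(function.compiler);

  std::set<std::string> already_run;
  for (const std::string& compiler : compilers) {
    if (compiler.empty()) continue;                      // no "Compiler" attribute
    auto pass_it = passes.find(compiler);
    if (pass_it == passes.end()) continue;               // kind has no custom pass
    if (!already_run.insert(compiler).second) continue;  // run at most once per kind

    // Look for a target in the config whose kind matches the compiler name.
    std::optional<ToyTarget> current;
    for (const ToyTarget& target : config) {
      if (target.kind == compiler) {
        current = target;
        break;
      }
    }
    pass_it->second(mod, current);
  }
}
```

With two functions tagged for the same compiler, the pass registered for that name is invoked once and is expected to handle both functions itself.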

include/tvm/target/target_kind.h

Lines changed: 10 additions & 0 deletions
@@ -402,6 +402,16 @@ namespace attr {
  * See also \p Target::IsExternalCodegenFor
  */
 constexpr const char* kIsExternalCodegen = "is_external_codegen";
+
+/*!
+ * \brief A \p TargetKind attribute of type \p FTVMRelayToTIR. If set, then the target kind name
+ * also corresponds to an external codegen 'compiler' name, and the bound value is a \p Pass
+ * to apply before the TVM lowering.
+ *
+ * See also \p Target::IsExternalCodegenFor
+ */
+constexpr const char* kRelayToTIR = "RelayToTIR";
+
 }  // namespace attr

 /*!
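
The commit message states that target kinds carrying a "RelayToTIR" attribute are now also recognized as external codegen. As a hedged illustration of that rule only, the sketch below models a target kind as a set of attribute keys. `ToyTargetKind` and `ToyIsExternalCodegen` are hypothetical, and treating mere key presence as "true" is a simplifying assumption (the real "is_external_codegen" attribute is a Bool).

```cpp
#include <cassert>
#include <set>
#include <string>

// Attribute keys as spelled in the diff.
constexpr const char* kToyIsExternalCodegen = "is_external_codegen";
constexpr const char* kToyRelayToTIR = "RelayToTIR";

// Hypothetical target kind: a name plus the set of attribute keys set on it.
struct ToyTargetKind {
  std::string name;
  std::set<std::string> attr_keys;
};

// A kind counts as external codegen if either attribute is present (simplified model).
bool ToyIsExternalCodegen(const ToyTargetKind& kind) {
  return kind.attr_keys.count(kToyIsExternalCodegen) > 0 ||
         kind.attr_keys.count(kToyRelayToTIR) > 0;
}
```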

src/relay/backend/aot_executor_codegen.cc

Lines changed: 1 addition & 1 deletion
@@ -1079,7 +1079,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
               // lowering process directly.
               tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
             },
-            config_->host_virtual_device)(mod);
+            config_)(mod);

     auto lowered_main = lowered_mod->Lookup("main");
     auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

src/relay/backend/contrib/cmsisnn/target.cc

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ tvm::transform::Pass RelayToTIR();
 runtime::Module TIRToRuntime(IRModule mod, Target target);

 TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
-    .set_attr<FTVMRelayToTIR>("RelayToTIR", RelayToTIR())
+    .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
     .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);

 }  // namespace cmsisnn

src/relay/backend/contrib/codegen_c/codegen.cc

Lines changed: 12 additions & 0 deletions
@@ -227,6 +227,14 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
     Array<String> variables = std::get<0>(res);
     String func_name = std::get<1>(res);

+    Optional<Target> opt_target = Target::Current();
+    if (opt_target.defined() && opt_target.value()->kind->name == "ccompiler") {
+      Optional<String> header = opt_target.value()->GetAttr<String>("header");
+      if (header.defined() && !header.value().empty()) {
+        code_stream_ << header.value().c_str() << "\n";
+      }
+    }
+
     // Create headers
     code_stream_ << "#include <stdio.h>\n";
     code_stream_ << "#include <stdlib.h>\n";
@@ -293,6 +301,10 @@ runtime::Module CCompiler(const ObjectRef& ref) {

 TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler);

+TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU)
+    .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
+    .add_attr_option<String>("header", String(""));  // value is prepended to every output CModule
+
 }  // namespace contrib
 }  // namespace relay
 }  // namespace tvm
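
The first hunk above prepends the "header" option of the current "ccompiler" target to the generated C source, ahead of the standard includes. A small standalone sketch of the same ordering logic follows. `ToyCTarget` and `EmitCSource` are hypothetical names, and the target is passed explicitly here instead of being read from a context stack.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <sstream>
#include <string>

// Hypothetical target: a kind name plus string-valued options.
struct ToyCTarget {
  std::string kind;
  std::map<std::string, std::string> attrs;
};

// Emits the (optional) custom header, then the fixed includes, then the body.
std::string EmitCSource(const std::optional<ToyCTarget>& current, const std::string& body) {
  std::ostringstream code_stream;
  if (current.has_value() && current->kind == "ccompiler") {
    auto header_it = current->attrs.find("header");
    if (header_it != current->attrs.end() && !header_it->second.empty()) {
      code_stream << header_it->second << "\n";
    }
  }
  code_stream << "#include <stdio.h>\n";
  code_stream << "#include <stdlib.h>\n";
  code_stream << body;
  return code_stream.str();
}
```

A target of a different kind, or no current target at all, leaves the output unchanged, which mirrors the guard on the kind name in the hunk.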

src/relay/backend/contrib/ethosu/codegen.cc

Lines changed: 1 addition & 1 deletion
@@ -320,7 +320,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {

 TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU)
     .set_attr<Bool>("use_device_api", Bool(true))
-    .set_attr<FTVMRelayToTIR>("RelayToTIR", RelayToTIR())
+    .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
     .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);

 }  // namespace ethosu

src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc

Lines changed: 144 additions & 56 deletions
@@ -28,12 +28,37 @@
 #include <tvm/tir/op.h>

 #include "../../../op/call/call.h"
+#include "tvm/tir/function.h"

 namespace tvm {
 namespace relay {
 namespace contrib {
 namespace example_target_hooks {

+namespace {
+
+/*!
+ * \brief An example mutator for a "RelayToTIR" custom pass. Replaces every call to a Relay
+ * Function with "external_symbol" attribute of "replace_add_with_subtract" with a call to a
+ * TIR PrimFunc implementing subtraction.
+ *
+ * Illustrates six aspects a custom 'lowering' style pass may need to account for:
+ *  - Lowerable functions can appear inline as call ops, bound to let-bound variables, or as
+ *    global functions.
+ *  - Let-bound lowerable functions should be inlined on-the-fly since after processing the
+ *    let-binding is no longer required.
+ *  - There may be multiple calls to the same lowerable function. All calls need to be
+ *    rewritten, even though the function itself need be rewritten only once.
+ *  - GlobalVars must be shared between all calls and the new definition itself.
+ *  - Calls to lowered functions must use the "call_lowered" calling convention.
+ *  - The Target::Current() may hold an instance of the TargetKind from which the custom Pass
+ *    was extracted.
+ *
+ * Though not illustrated here, it is also valid for a "RelayToTIR" custom pass to add
+ * runtime::Modules to the output IRModule's "external_mods" attribute. In this case the
+ * IRModule must be left with an 'extern' Function definition with the matching "external_symbol"
+ * name.
+ */
 class ConvertAddToSubtract : public MixedModeMutator {
  public:
   explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
@@ -56,51 +81,102 @@ class ConvertAddToSubtract : public MixedModeMutator {
     return tir::BufferLoad(buffer, {index});
   }

-  void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) {
-    tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x");
-    tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y");
-    tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32));
+  GlobalVar ReplaceAddWithSubtractPrimFunc(const Function& func) {
+    auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
+    ICHECK(func_name.defined());

-    tir::Var x_var("x", DataType::Handle());
-    tir::Var y_var("y", DataType::Handle());
-    tir::Var out_var("out", DataType::Handle());
+    // --------------------------------------------------------------------------------------------
+    // Cases:
+    //  - Inline function:
+    //     - First encounter: create global var, rewrite to PrimFunc, add binding, replace call.
+    //     - Thereafter (via object sharing): discover global var already in module, replace call
+    //  - Global function:
+    //     - Assume func_name == global_var->name_hint
+    //     - First encounter: create global var, rewrite to PrimFunc, update binding, replace call
+    //     - Thereafter (via global var): discover global var already in module, replace call
+    // --------------------------------------------------------------------------------------------

-    Map<String, ObjectRef> dict_attrs;
-    dict_attrs.Set("global_symbol", new_global_var->name_hint);
-    dict_attrs.Set("tir.noalias", Bool(true));
+    // If necessary, introduce a new global var to map the function to and copy the source type
+    // over for InferType.
+    GlobalVar global_var;
+    bool need_rewriting;
+    if (ir_module_->ContainGlobalVar(func_name.value())) {
+      global_var = ir_module_->GetGlobalVar(func_name.value());
+      // Only rewrite to a PrimFunc if the global definition is still a Relay function.
+      need_rewriting = ir_module_->Lookup(global_var)->IsInstance<FunctionNode>();
+    } else {
+      global_var = GlobalVar(func_name.value());
+      global_var->checked_type_ = func->checked_type();
+      need_rewriting = true;
+    }

-    te::Var index("index", DataType::Int(32));
-    tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
-    tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index});
-    tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);
+    // For illustration only, check if the current target matches the example_target_hook kind,
+    // and if so extract the example attribute value.
+    int64_t example_attribute_value = 0;
+    Optional<Target> opt_current_target = Target::Current();
+    if (opt_current_target.defined() &&
+        opt_current_target.value()->kind->name == "example_target_hook") {
+      example_attribute_value =
+          opt_current_target.value()->GetAttr<Integer>("example_attribute").value()->value;
+    }

-    Map<tir::Var, tir::Buffer> buffer_map = {
-        {x_var, x_buffer},
-        {y_var, y_buffer},
-        {out_var, out_buffer},
-    };
+    if (need_rewriting) {
+      // The called function is still in Relay form. Convert to TIR.
+      tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x");
+      tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y");
+      tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32));

-    tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
-                                                   buffer_map, {}, DictAttrs(dict_attrs));
+      tir::Var x_var("x", DataType::Handle());
+      tir::Var y_var("y", DataType::Handle());
+      tir::Var out_var("out", DataType::Handle());

-    // Switch to TIRToRuntime hook for testing
-    Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
-    if (tir_to_runtime) {
-      replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_);
-    } else {
-      replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
+      Map<String, ObjectRef> dict_attrs;
+      dict_attrs.Set("global_symbol", global_var->name_hint);
+      dict_attrs.Set("tir.noalias", Bool(true));
+
+      te::Var index("index", DataType::Int(32));
+      tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
+      if (example_attribute_value > 0) {
+        // For illustration only, fold the example attribute into the result.
+        indexed_sub = tir::Sub(indexed_sub, FloatImm(DataType::Float(32),
+                                                     static_cast<double>(example_attribute_value)));
+      }
+
+      tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index});
+      tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);
+
+      Map<tir::Var, tir::Buffer> buffer_map = {
+          {x_var, x_buffer},
+          {y_var, y_buffer},
+          {out_var, out_buffer},
+      };
+
+      tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
+                                                     buffer_map, {}, DictAttrs(dict_attrs));
+
+      // Switch to TIRToRuntime hook for testing
+      Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
+      if (tir_to_runtime) {
+        replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_);
+      } else {
+        replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
+      }
+
+      ir_module_->Update(global_var, replacement_func);  // Will Add if global_var is new.
     }

-    ir_module_->Add(new_global_var, replacement_func);
+    return global_var;
   }

+  using MixedModeMutator::VisitExpr_;
+
   Expr VisitExpr_(const LetNode* op) final {
     auto pre_visit = [this](const LetNode* op) {
       Expr var = this->VisitExpr(op->var);
       Expr value = this->VisitExpr(op->value);

-      // Outlineable function no longer needs let binding
-      if (this->CanLowerExpr(value)) {
+      if (AsLowerableFunction(value)) {
+        // Inline on-the-fly if the let-bound value is lowerable.
         this->memo_[var] = value;
       }
     };
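
For orientation, the PrimFunc built in the hunk above describes a serial loop over 8 float32 elements that stores x[index] - y[index], with the example attribute value subtracted as well when it is positive. The plain C++ below is only a reading aid for that arithmetic, not generated code. `SubtractKernel` is a hypothetical name, and the check on the attribute is done per element here, whereas the hunk decides it once while building the expression.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr int kNumElements = 8;  // matches the fixed buffer shape {8} in the hunk

std::array<float, kNumElements> SubtractKernel(const std::array<float, kNumElements>& x,
                                               const std::array<float, kNumElements>& y,
                                               int64_t example_attribute_value) {
  std::array<float, kNumElements> out{};
  for (int index = 0; index < kNumElements; ++index) {
    float result = x[index] - y[index];
    if (example_attribute_value > 0) {
      // Mirrors the extra tir::Sub that folds the attribute value into the result.
      result -= static_cast<float>(example_attribute_value);
    }
    out[index] = result;
  }
  return out;
}
```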
@@ -110,8 +186,8 @@ class ConvertAddToSubtract : public MixedModeMutator {
       Expr body = this->VisitExpr(op->body);
       auto expr = GetRef<Expr>(op);

-      // Drop the let binding
-      if (this->CanLowerExpr(value)) {
+      if (AsLowerableFunction(value)) {
+        // The let binding is no longer needed since inlined on-the-fly above.
         this->memo_[expr] = this->VisitExpr(op->body);
       } else {
         Var var = Downcast<Var>(this->VisitExpr(op->var));
@@ -126,39 +202,49 @@ class ConvertAddToSubtract : public MixedModeMutator {
     return memo_[GetRef<Expr>(op)];
   }

-  bool CanLowerExpr(const Expr& expr) {
-    const auto* func = expr.as<FunctionNode>();
-    if (func == nullptr) {
-      return false;
-    }
-    auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
-    if (!func_name.defined()) {
-      return false;
+  const FunctionNode* AsLowerableFunction(const Expr& expr) {
+    if (const auto* function_node = expr.as<FunctionNode>()) {
+      auto func_name = function_node->GetAttr<String>(::tvm::attr::kGlobalSymbol);
+      if (!func_name.defined()) {
+        return nullptr;
+      }
+      if (func_name != "replace_add_with_subtract") {
+        return nullptr;
+      }
+      return function_node;
+    } else if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
+      return AsLowerableFunction(ir_module_->Lookup(GetRef<GlobalVar>(global_var_node)));
+    } else {
+      return nullptr;
     }
-    if (func_name != "replace_add_with_subtract") {
-      return false;
+  }
+
+  const GlobalVarNode* AsAlreadyLoweredFunction(const Expr& expr) {
+    if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
+      if (ir_module_->Lookup(GetRef<GlobalVar>(global_var_node)).as<tir::PrimFuncNode>()) {
+        return global_var_node;
+      }
     }
-    return true;
+    return nullptr;
   }

   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
-    if (const CallNode* call = post.as<CallNode>()) {
-      if (CanLowerExpr(call->op)) {
-        auto* func = call->op.as<FunctionNode>();
-        auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
-
-        // Introduce a new global var to map the function to and copy the source type
-        // over for InferType
-        GlobalVar new_global_var(func_name.value());
-        new_global_var->checked_type_ = func->checked_type();
-        ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
-
+    if (const auto* call = post.as<CallNode>()) {
+      GlobalVar new_op;
+      if (const auto* function_node = AsLowerableFunction(call->op)) {
+        // Add or replace the function with a PrimFunc.
+        new_op = ReplaceAddWithSubtractPrimFunc(GetRef<Function>(function_node));
+      } else if (const auto* global_var_node = AsAlreadyLoweredFunction(call->op)) {
+        // The function has already been rewritten, so we just need to update the call.
+        new_op = GetRef<GlobalVar>(global_var_node);
+      }
+      if (new_op.defined()) {
         // Since we are replacing the Relay function with a call to a TIR function, we must use
         // the call_lowered op.
         CallLoweredAttrs attrs;
         attrs.metadata.Set("relay_attrs", call->attrs);
         ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic";
-        return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span);
+        return CallLowered(std::move(new_op), call->args, std::move(attrs), call->span);
       }
     }

@@ -171,10 +257,12 @@ class ConvertAddToSubtract : public MixedModeMutator {
   Target custom_target_;
 };

+}  // namespace
+
 transform::Pass RelayToTIR() {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
       [=](IRModule ir_module, transform::PassContext pass_context) {
-        auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c"));
+        ConvertAddToSubtract relay_to_tir(std::move(ir_module), Target("c"));
         return relay_to_tir.Mutate();
       };
   return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {});
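
The "Cases" comment in this file distinguishes the first encounter of a lowerable function (lower it and bind it under its global symbol) from later encounters (reuse the existing binding and only rewrite the call). A compact, TVM-free model of that bookkeeping might look as follows. `ToyDefKind`, `ToyIRModule` and `LowerOnce` are hypothetical names, and the module is reduced to a map from symbol to definition kind.

```cpp
#include <cassert>
#include <map>
#include <string>

// Whether a global definition is still in Relay form or already a lowered PrimFunc.
enum class ToyDefKind { kRelayFunction, kPrimFunc };

struct ToyIRModule {
  std::map<std::string, ToyDefKind> definitions;
  int lowerings = 0;  // counts how many times a function body was actually lowered
};

// Returns the global symbol every call site should target. The function is lowered only
// if it is unknown to the module (inline style) or still bound as a Relay function
// (global style); otherwise the existing lowered definition is reused.
std::string LowerOnce(ToyIRModule& mod, const std::string& global_symbol) {
  auto it = mod.definitions.find(global_symbol);
  bool need_rewriting =
      (it == mod.definitions.end()) || (it->second == ToyDefKind::kRelayFunction);
  if (need_rewriting) {
    mod.definitions[global_symbol] = ToyDefKind::kPrimFunc;  // add or update the binding
    ++mod.lowerings;
  }
  return global_symbol;
}
```

Calling `LowerOnce` for every call site keeps the rewrite of calls unconditional while the rewrite of the definition happens a single time per symbol.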
