2828#include < tvm/tir/op.h>
2929
3030#include " ../../../op/call/call.h"
31+ #include " tvm/tir/function.h"
3132
3233namespace tvm {
3334namespace relay {
3435namespace contrib {
3536namespace example_target_hooks {
3637
38+ namespace {
39+
40+ /* !
41+ * \brief An example mutator for a "RelayToTIR" custom pass. Replaces every call to a Relay
42+ * Function with "external_symbol" attribute of "replace_add_with_subtract" with a call to a
43+ * TIR PrimFunc implementing subtraction.
44+ *
45+ * Illustrates six aspects a custom 'lowering' style pass may need to account for:
46+ * - Lowerable functions can appear inline as call ops, bound to let-bound variables, or as
47+ * global functions.
48+ * - Let-bound lowerable functions should be inlined on-the-fly since after processing the
49+ * let-binding is no longer required.
50+ * - There may be multiple calls to the same lowerable function. All calls need to be
51+ * rewritten, even though the function itself need be rewritten only once.
52+ * - GlobalVars must be shared between all calls and the new definition itself.
53+ * - Calls to lowered functions must use the "call_lowered" calling convention.
54+ * - The Target::Current() may hold an instance of the TargetKind from which the custom Pass
55+ * was extracted.
56+ *
57+ * Though not illustrated here, it is also valid for a "RelayToTIR" custom pass to add
58+ * runtime::Modules to the output IRModule's "external_mods" attribute. In this case the
59+ * IRModule must be left with an 'extern' Function definition with the matching "external_symbol"
60+ * name.
61+ */
3762class ConvertAddToSubtract : public MixedModeMutator {
3863 public:
3964 explicit ConvertAddToSubtract (IRModule ir_module, Target host_target)
@@ -56,51 +81,102 @@ class ConvertAddToSubtract : public MixedModeMutator {
5681 return tir::BufferLoad (buffer, {index});
5782 }
5883
59- void ReplaceAddWithSubtractPrimFunc (const GlobalVar& new_global_var, const Function& func) {
60- tir::Buffer x_buffer = tir::decl_buffer ({8 }, DataType::Float (32 ), " x" );
61- tir::Buffer y_buffer = tir::decl_buffer ({8 }, DataType::Float (32 ), " y" );
62- tir::Buffer out_buffer = tir::decl_buffer ({8 }, DataType::Float (32 ));
84+ GlobalVar ReplaceAddWithSubtractPrimFunc (const Function& func) {
85+ auto func_name = func->GetAttr <String>(::tvm::attr::kGlobalSymbol );
86+ ICHECK (func_name.defined ());
6387
64- tir::Var x_var (" x" , DataType::Handle ());
65- tir::Var y_var (" y" , DataType::Handle ());
66- tir::Var out_var (" out" , DataType::Handle ());
88+ // --------------------------------------------------------------------------------------------
89+ // Cases:
90+ // - Inline function:
91+ // - First encounter: create global var, rewrite to PrimFunc, add binding, replace call.
92+ // - Thereafter (via object sharing): discover global var already in module, replace call
93+ // - Global function:
94+ // - Assume func_name == global_var->name_hint
95+ // - First encounter: create global var, rewrite to PrimFunc, update binding, replace call
96+ // - Thereafter (via global var): discover global var already in module, replace call
97+ // --------------------------------------------------------------------------------------------
6798
68- Map<String, ObjectRef> dict_attrs;
69- dict_attrs.Set (" global_symbol" , new_global_var->name_hint );
70- dict_attrs.Set (" tir.noalias" , Bool (true ));
99+ // If necessary, introduce a new global var to map the function to and copy the source type
100+ // over for InferType.
101+ GlobalVar global_var;
102+ bool need_rewriting;
103+ if (ir_module_->ContainGlobalVar (func_name.value ())) {
104+ global_var = ir_module_->GetGlobalVar (func_name.value ());
105+ // Only rewrite to a PrimFunc if the global definition is still a Relay function.
106+ need_rewriting = ir_module_->Lookup (global_var)->IsInstance <FunctionNode>();
107+ } else {
108+ global_var = GlobalVar (func_name.value ());
109+ global_var->checked_type_ = func->checked_type ();
110+ need_rewriting = true ;
111+ }
71112
72- te::Var index (" index" , DataType::Int (32 ));
73- tir::Sub indexed_sub = tir::Sub (LoadIndex (x_buffer, index), LoadIndex (y_buffer, index));
74- tir::Stmt math_body = tir::BufferStore (out_buffer, indexed_sub, {index});
75- tir::Stmt math_loop = tir::For (index, 0 , 8 , tir::ForKind::kSerial , math_body);
113+ // For illustration only, check if the current target matches the example_target_hook kind,
114+ // and if so extract the example attribute value.
115+ int64_t example_attribute_value = 0 ;
116+ Optional<Target> opt_current_target = Target::Current ();
117+ if (opt_current_target.defined () &&
118+ opt_current_target.value ()->kind ->name == " example_target_hook" ) {
119+ example_attribute_value =
120+ opt_current_target.value ()->GetAttr <Integer>(" example_attribute" ).value ()->value ;
121+ }
76122
77- Map<tir::Var, tir::Buffer> buffer_map = {
78- {x_var, x_buffer},
79- {y_var, y_buffer},
80- {out_var, out_buffer},
81- } ;
123+ if (need_rewriting) {
124+ // The called function is still in Relay form. Convert to TIR.
125+ tir::Buffer x_buffer = tir::decl_buffer ({ 8 }, DataType::Float ( 32 ), " x " );
126+ tir::Buffer y_buffer = tir::decl_buffer ({ 8 }, DataType::Float ( 32 ), " y " );
127+ tir::Buffer out_buffer = tir::decl_buffer ({ 8 }, DataType::Float ( 32 )) ;
82128
83- tir::PrimFunc replacement_func = tir::PrimFunc ({x_var, y_var, out_var}, math_loop, VoidType (),
84- buffer_map, {}, DictAttrs (dict_attrs));
129+ tir::Var x_var (" x" , DataType::Handle ());
130+ tir::Var y_var (" y" , DataType::Handle ());
131+ tir::Var out_var (" out" , DataType::Handle ());
85132
86- // Switch to TIRToRuntime hook for testing
87- Bool tir_to_runtime = func->GetAttr <Bool>(" tir_to_runtime" ).value_or (Bool (false ));
88- if (tir_to_runtime) {
89- replacement_func = WithAttr (replacement_func, ::tvm::attr::kTarget , custom_target_);
90- } else {
91- replacement_func = WithAttr (replacement_func, ::tvm::attr::kTarget , host_target_);
133+ Map<String, ObjectRef> dict_attrs;
134+ dict_attrs.Set (" global_symbol" , global_var->name_hint );
135+ dict_attrs.Set (" tir.noalias" , Bool (true ));
136+
137+ te::Var index (" index" , DataType::Int (32 ));
138+ tir::Sub indexed_sub = tir::Sub (LoadIndex (x_buffer, index), LoadIndex (y_buffer, index));
139+ if (example_attribute_value > 0 ) {
140+ // For illustration only, fold the example attribute into the result.
141+ indexed_sub = tir::Sub (indexed_sub, FloatImm (DataType::Float (32 ),
142+ static_cast <double >(example_attribute_value)));
143+ }
144+
145+ tir::Stmt math_body = tir::BufferStore (out_buffer, indexed_sub, {index});
146+ tir::Stmt math_loop = tir::For (index, 0 , 8 , tir::ForKind::kSerial , math_body);
147+
148+ Map<tir::Var, tir::Buffer> buffer_map = {
149+ {x_var, x_buffer},
150+ {y_var, y_buffer},
151+ {out_var, out_buffer},
152+ };
153+
154+ tir::PrimFunc replacement_func = tir::PrimFunc ({x_var, y_var, out_var}, math_loop, VoidType (),
155+ buffer_map, {}, DictAttrs (dict_attrs));
156+
157+ // Switch to TIRToRuntime hook for testing
158+ Bool tir_to_runtime = func->GetAttr <Bool>(" tir_to_runtime" ).value_or (Bool (false ));
159+ if (tir_to_runtime) {
160+ replacement_func = WithAttr (replacement_func, ::tvm::attr::kTarget , custom_target_);
161+ } else {
162+ replacement_func = WithAttr (replacement_func, ::tvm::attr::kTarget , host_target_);
163+ }
164+
165+ ir_module_->Update (global_var, replacement_func); // Will Add if global_var is new.
92166 }
93167
94- ir_module_-> Add (new_global_var, replacement_func) ;
168+ return global_var ;
95169 }
96170
171+ using MixedModeMutator::VisitExpr_;
172+
97173 Expr VisitExpr_ (const LetNode* op) final {
98174 auto pre_visit = [this ](const LetNode* op) {
99175 Expr var = this ->VisitExpr (op->var );
100176 Expr value = this ->VisitExpr (op->value );
101177
102- // Outlineable function no longer needs let binding
103- if ( this -> CanLowerExpr ( value)) {
178+ if ( AsLowerableFunction (value)) {
179+ // Inline on-the-fly if the let-bound value is lowerable.
104180 this ->memo_ [var] = value;
105181 }
106182 };
@@ -110,8 +186,8 @@ class ConvertAddToSubtract : public MixedModeMutator {
110186 Expr body = this ->VisitExpr (op->body );
111187 auto expr = GetRef<Expr>(op);
112188
113- // Drop the let binding
114- if ( this -> CanLowerExpr (value)) {
189+ if ( AsLowerableFunction (value)) {
190+ // The let binding is no longer needed since inlined on-the-fly above.
115191 this ->memo_ [expr] = this ->VisitExpr (op->body );
116192 } else {
117193 Var var = Downcast<Var>(this ->VisitExpr (op->var ));
@@ -126,39 +202,49 @@ class ConvertAddToSubtract : public MixedModeMutator {
126202 return memo_[GetRef<Expr>(op)];
127203 }
128204
129- bool CanLowerExpr (const Expr& expr) {
130- const auto * func = expr.as <FunctionNode>();
131- if (func == nullptr ) {
132- return false ;
133- }
134- auto func_name = func->GetAttr <String>(::tvm::attr::kGlobalSymbol );
135- if (!func_name.defined ()) {
136- return false ;
205+ const FunctionNode* AsLowerableFunction (const Expr& expr) {
206+ if (const auto * function_node = expr.as <FunctionNode>()) {
207+ auto func_name = function_node->GetAttr <String>(::tvm::attr::kGlobalSymbol );
208+ if (!func_name.defined ()) {
209+ return nullptr ;
210+ }
211+ if (func_name != " replace_add_with_subtract" ) {
212+ return nullptr ;
213+ }
214+ return function_node;
215+ } else if (const auto * global_var_node = expr.as <GlobalVarNode>()) {
216+ return AsLowerableFunction (ir_module_->Lookup (GetRef<GlobalVar>(global_var_node)));
217+ } else {
218+ return nullptr ;
137219 }
138- if (func_name != " replace_add_with_subtract" ) {
139- return false ;
220+ }
221+
222+ const GlobalVarNode* AsAlreadyLoweredFunction (const Expr& expr) {
223+ if (const auto * global_var_node = expr.as <GlobalVarNode>()) {
224+ if (ir_module_->Lookup (GetRef<GlobalVar>(global_var_node)).as <tir::PrimFuncNode>()) {
225+ return global_var_node;
226+ }
140227 }
141- return true ;
228+ return nullptr ;
142229 }
143230
144231 Expr Rewrite_ (const CallNode* pre , const Expr& post ) override {
145- if (const CallNode* call = post .as <CallNode>()) {
146- if (CanLowerExpr (call->op )) {
147- auto * func = call->op .as <FunctionNode>();
148- auto func_name = func->GetAttr <String>(::tvm::attr::kGlobalSymbol );
149-
150- // Introduce a new global var to map the function to and copy the source type
151- // over for InferType
152- GlobalVar new_global_var (func_name.value ());
153- new_global_var->checked_type_ = func->checked_type ();
154- ReplaceAddWithSubtractPrimFunc (new_global_var, GetRef<Function>(func));
155-
232+ if (const auto * call = post .as <CallNode>()) {
233+ GlobalVar new_op;
234+ if (const auto * function_node = AsLowerableFunction (call->op )) {
235+ // Add or replace the function with a PrimFunc.
236+ new_op = ReplaceAddWithSubtractPrimFunc (GetRef<Function>(function_node));
237+ } else if (const auto * global_var_node = AsAlreadyLoweredFunction (call->op )) {
238+ // The function has already been rewritten, so we just need to update the call.
239+ new_op = GetRef<GlobalVar>(global_var_node);
240+ }
241+ if (new_op.defined ()) {
156242 // Since we are replacing the Relay function with a call to a TIR function, we must use
157243 // the call_lowered op.
158244 CallLoweredAttrs attrs;
159245 attrs.metadata .Set (" relay_attrs" , call->attrs );
160246 ICHECK (call->type_args .empty ()) << " lowered functions cannot be polymorphic" ;
161- return CallLowered (std::move (new_global_var ), call->args , std::move (attrs), call->span );
247+ return CallLowered (std::move (new_op ), call->args , std::move (attrs), call->span );
162248 }
163249 }
164250
@@ -171,10 +257,12 @@ class ConvertAddToSubtract : public MixedModeMutator {
171257 Target custom_target_;
172258};
173259
260+ } // namespace
261+
174262transform::Pass RelayToTIR () {
175263 runtime::TypedPackedFunc<IRModule (IRModule, transform::PassContext)> pass_func =
176264 [=](IRModule ir_module, transform::PassContext pass_context) {
177- auto relay_to_tir = ConvertAddToSubtract ( ir_module, Target (" c" ));
265+ ConvertAddToSubtract relay_to_tir ( std::move ( ir_module) , Target (" c" ));
178266 return relay_to_tir.Mutate ();
179267 };
180268 return tvm::transform::CreateModulePass (pass_func, 0 , " RelayToTIR" , {});
0 commit comments