Skip to content

Commit f88b5a8

Browse files
committed
- Andrew's comments
1 parent 123fcc9 commit f88b5a8

File tree

4 files changed

+33
-12
lines changed

4 files changed

+33
-12
lines changed

include/tvm/relay/function.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,23 +172,29 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func);
172172
namespace attr {
173173

174174
/*!
175-
* \brief Mark the function as a primitive function. Should be bound to \p Integer(1).
175+
* \brief Mark the function as representing a sub-graph which is to be lowered or compiled as
176+
* a unit. For example, the function may represent a kernel which TVM will lower to a PrimFunc.
177+
* If present should be bound to \p Integer(1). May be accompanied by "Compiler", see below.
178+
* The function body should be considered opaque by Relay, and many passes simply ignore these
179+
* functions.
176180
*
177181
* Type: Integer
178182
*/
179183
constexpr const char* kPrimitive = "Primitive";
180184

181185
/*!
182-
* \brief Mark the function as being 'extern', ie implemented in a runtime::Module. Should be bound
183-
* to \p Integer(1). Typically accompanied by "Primitive".
186+
* \brief Mark the function as externally implemented, ie bound in a runtime::Module within the
187+
* IRModule's "external_mods" attribute. If present should be bound to \p Integer(1). Generally
188+
* the only attribute when present.
184189
*
185190
* Type: Integer
186191
*/
187192
constexpr const char* kExtern = "Extern";
188193

189194
/*!
190-
* \brief Indicate the external codegen 'compiler' that should be used for building this function.
191-
* When this is unset or set to "default", the default compilation pipeline will be used.
195+
* \brief Indicates the name of the external codegen 'compiler' that should be used to lower
196+
* or compile the function other than TVM's default lowering pipeline. The name may correspond
197+
* to a TargetKind name. There may be a global function registered under 'relay.ext.{name}'.
192198
*
193199
* Type: String
194200
*/

python/tvm/relay/transform/transform.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,8 +1371,10 @@ def SplitArgs(max_function_args):
13711371

13721372

13731373
def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""):
1374-
"""A pass to outline all literal functions in direct call positions which have a "Compiler"
1375-
attribute. The functions are bound to unique global vars according to their existing
1374+
"""Outlines all literal functions in direct call positions which have a "Compiler"
1375+
attribute.
1376+
1377+
The outlined functions are bound to unique global vars according to their existing
13761378
"global_symbol" attribute. At most one function with the same global symbol is outlined.
13771379
13781380
If compiler_filter is non-empty only functions with that as their attribute value are
@@ -1395,9 +1397,11 @@ def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""):
13951397

13961398

13971399
def MarkCompilerFunctionsAsExtern(compiler_filter=""):
1398-
"""A pass to mark all global functions which have a "Compiler" attribute matching
1399-
compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and
1400-
rewrite all calls to such functions to use the 'call_lowered' calling convention.
1400+
"""Marks all global functions which have a "Compiler" attribute matching
1401+
compiler_filter as 'extern'.
1402+
1403+
The function's attributes are replaced with a single "Extern" attribute, and
1404+
all calls to the function are switched to use the 'call_lowered' calling convention.
14011405
14021406
If compiler_filter is non-empty only functions with that as their attribute value are
14031407
outlined.

src/relay/backend/te_compiler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class TECompilerImpl : public TECompilerNode {
176176
WithFields(GetRef<Function>(function_node), function_node->params,
177177
function_node->body, function_node->ret_type, function_node->type_params,
178178
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
179-
// Mark function as 'extern' using the "ExternalSymbol" attribute.
179+
// Mark function as 'extern'.
180180
function = WithAttr(std::move(function), attr::kExtern, Integer(1));
181181
module->Add(kv2.first, function);
182182
}

src/relay/transforms/compiler_function_utils.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ class Outliner : public MixedModeMutator {
4949
if (opt_compiler.defined() &&
5050
(compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) {
5151
auto function = GetRef<Function>(function_node);
52-
ICHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler
52+
DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler
5353
<< "' attribute should not have free variables";
54+
// Ask the cache to supply a unique global var for this function.
5455
GlobalVar global_symbol = cache_->GetGlobalSymbol(function);
5556
// Depending on the cache's implementation, two structurally equal (but not object equal)
5657
// functions may be assigned the same global symbol. If so we'll lift it just once, but
@@ -60,15 +61,23 @@ class Outliner : public MixedModeMutator {
6061
WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint);
6162
mod_->Add(global_symbol, function);
6263
}
64+
// Update the call.
6365
return WithFields(new_call, global_symbol);
6466
}
6567
}
6668
return post;
6769
}
6870

6971
private:
72+
/*!
73+
* \brief A cached mapping from functions to global variables. Depending on the implementation
74+
* the cache may generate fresh symbols or require the function to already have a "global_symbol"
75+
* attribute, and may share symbols between structurally equal functions.
76+
*/
7077
GlobalSymbolCache* cache_;
78+
/*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */
7179
std::string compiler_filter_;
80+
/*! \brief Module being rewritten. */
7281
IRModule mod_;
7382
};
7483

@@ -102,7 +111,9 @@ class CallRewriter : public MixedModeMutator {
102111
}
103112

104113
private:
114+
/*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */
105115
std::string compiler_filter_;
116+
/*! \brief Module being rewritten. */
106117
IRModule mod_;
107118
};
108119

0 commit comments

Comments
 (0)