Skip to content

Commit a3dc554

Browse files
author
Michalis Papadimitriou
committed
Address PR comments
1 parent a4d95c7 commit a3dc554

File tree

2 files changed

+74
-30
lines changed

2 files changed

+74
-30
lines changed

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 67 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def partition_for_tensorrt(
106106
max_workspace_size=1 << 30,
107107
use_fp16=False,
108108
use_uint8=False,
109+
use_patterns=False,
109110
):
110111
"""Partition the graph greedily offloading supported operators to TensorRT.
111112
@@ -136,6 +137,9 @@ def partition_for_tensorrt(
136137
lower runtime, or if no low-precision implementation exists.
137138
use_uint8: Optional[bool]
138139
Allows TRT to automatically convert FP32 inputs to UINT8.
140+
use_patterns: Optional[bool]
141+
Switches to use pattern-based op support by applying MergeComposite and InlineComposites
142+
passes.
139143
Returns
140144
-------
141145
mod_and_config : Tuple[Module, Dict[str, Any]]
@@ -164,34 +168,74 @@ def partition_for_tensorrt(
164168

165169
if params:
166170
mod["main"] = bind_params_by_name(mod["main"], params)
167-
seq = tvm.transform.Sequential(
168-
[
169-
transform.InferType(),
170-
RemoveDropoutPass(),
171-
transform.RemoveUnusedFunctions(),
172-
transform.ConvertLayout(
173-
{
174-
"nn.conv1d": ["NCW", "default"],
175-
"nn.conv2d": ["NCHW", "default"],
176-
"nn.conv3d": ["NCDHW", "default"],
177-
"nn.conv2d_transpose": ["NCHW", "default"],
178-
}
179-
),
180-
transform.FoldConstant(),
181-
transform.MergeComposite(pattern_table()),
182-
transform.AnnotateTarget("tensorrt"),
183-
transform.MergeCompilerRegions(),
184-
transform.PartitionGraph(),
185-
transform.InlineComposites("tensorrt"),
186-
transform.InferType(),
187-
]
188-
)
171+
172+
seq = get_pass_order(use_patterns)
189173
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
190174
mod = seq(mod)
191175
mod = prune_tensorrt_subgraphs(mod)
192176
return mod, config
193177

194178

179+
def get_pass_order(use_patterns):
180+
"""
181+
Get the pass ordering based on using predicates or patterns.
182+
183+
Parameters
184+
----------
185+
use_patterns: Bool
186+
True if pass needs to work with op patterns
187+
Returns
188+
-------
189+
ret : Sequential
190+
Pass object
191+
"""
192+
return (
193+
tvm.transform.Sequential(
194+
[
195+
transform.InferType(),
196+
RemoveDropoutPass(),
197+
transform.RemoveUnusedFunctions(),
198+
transform.ConvertLayout(
199+
{
200+
"nn.conv1d": ["NCW", "default"],
201+
"nn.conv2d": ["NCHW", "default"],
202+
"nn.conv3d": ["NCDHW", "default"],
203+
"nn.conv2d_transpose": ["NCHW", "default"],
204+
}
205+
),
206+
transform.FoldConstant(),
207+
transform.MergeComposite(pattern_table()),
208+
transform.AnnotateTarget("tensorrt"),
209+
transform.MergeCompilerRegions(),
210+
transform.PartitionGraph(),
211+
transform.InlineComposites("tensorrt"),
212+
transform.InferType(),
213+
]
214+
)
215+
if use_patterns
216+
else tvm.transform.Sequential(
217+
[
218+
transform.InferType(),
219+
RemoveDropoutPass(),
220+
transform.RemoveUnusedFunctions(),
221+
transform.ConvertLayout(
222+
{
223+
"nn.conv1d": ["NCW", "default"],
224+
"nn.conv2d": ["NCHW", "default"],
225+
"nn.conv3d": ["NCDHW", "default"],
226+
"nn.conv2d_transpose": ["NCHW", "default"],
227+
}
228+
),
229+
transform.FoldConstant(),
230+
transform.AnnotateTarget("tensorrt"),
231+
transform.MergeCompilerRegions(),
232+
transform.PartitionGraph(),
233+
transform.InferType(),
234+
]
235+
)
236+
)
237+
238+
195239
def check_dynamism(args, op_name):
196240
"""
197241
Check for dynamism inside any of the args in the op.
@@ -451,7 +495,7 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
451495

452496

453497
@_register_external_dynamic_check_func("nn.batch_matmul")
454-
def batch_matmul_annotate_fn(expr): # pylint: disable=unused-variable
498+
def batch_matmul_annotate_fn(expr):
455499
"""Check if batch_matmul is supported by TensorRT."""
456500

457501
args = expr.args

src/relay/transforms/inline_composites.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ namespace tvm {
3434

3535
namespace relay {
3636

37-
class Unmerger : ExprMutator {
37+
class CompositeInliner : public MixedModeMutator {
3838
public:
39-
explicit Unmerger(CallGraphEntry* cur_node, CallGraphNode* call_graph)
39+
explicit CompositeInliner(CallGraphEntry* cur_node, CallGraphNode* call_graph)
4040
: cur_node_(cur_node), call_graph_(call_graph) {}
4141

42-
Expr VisitExpr_(const CallNode* call_node) final {
42+
Expr Rewrite_(const CallNode* call_node) {
4343
Call vanilla_call = GetAnyCall(call_node);
4444
const auto* function_node = vanilla_call->op.as<FunctionNode>();
4545

@@ -60,10 +60,10 @@ class Unmerger : ExprMutator {
6060
return Bind(function_node->body, bind_map);
6161
}
6262

63-
return ExprMutator::VisitExpr_(call_node);
63+
return MixedModeMutator::VisitExpr_(call_node);
6464
}
6565

66-
Function Unmerge(const Function& func) {
66+
Function Inline(const Function& func) {
6767
return WithFields(func, func->params, VisitExpr(func->body));
6868
}
6969

@@ -88,13 +88,13 @@ IRModule InlineComposites(const IRModule& module, runtime::String target) {
8888

8989
if (!base_func->GetAttr<String>(attr::kCompiler).defined() &&
9090
base_func->GetAttr<String>(attr::kCompiler) != target) {
91-
return module;
91+
continue;
9292
}
9393

9494
if (it->GetNameHint() != "main") {
9595
if (const auto* fn = base_func.as<FunctionNode>()) {
9696
auto func = GetRef<Function>(fn);
97-
auto new_func = Unmerger(it, cg.operator->()).Unmerge(func);
97+
auto new_func = CompositeInliner(it, cg.operator->()).Inline(func);
9898
cg->module->Update(it->GetGlobalVar(), new_func);
9999
}
100100
}

0 commit comments

Comments
 (0)