@@ -106,6 +106,7 @@ def partition_for_tensorrt(
106106 max_workspace_size = 1 << 30 ,
107107 use_fp16 = False ,
108108 use_uint8 = False ,
109+ use_patterns = False ,
109110):
110111 """Partition the graph greedily offloading supported operators to TensorRT.
111112
@@ -136,6 +137,9 @@ def partition_for_tensorrt(
136137 lower runtime, or if no low-precision implementation exists.
137138 use_uint8: Optional[bool]
138139 Allows, TRT to automatically convert FP32 inputs to UINT8.
140+ use_patterns: Optional[bool]
141+ Switches to use pattern-based op suppot by applying MergeCompsite and InlineComposites
142+ passes.
139143 Returns
140144 -------
141145 mod_and_config : Tuple[Module, Dict[str, Any]]
@@ -164,34 +168,74 @@ def partition_for_tensorrt(
164168
165169 if params :
166170 mod ["main" ] = bind_params_by_name (mod ["main" ], params )
167- seq = tvm .transform .Sequential (
168- [
169- transform .InferType (),
170- RemoveDropoutPass (),
171- transform .RemoveUnusedFunctions (),
172- transform .ConvertLayout (
173- {
174- "nn.conv1d" : ["NCW" , "default" ],
175- "nn.conv2d" : ["NCHW" , "default" ],
176- "nn.conv3d" : ["NCDHW" , "default" ],
177- "nn.conv2d_transpose" : ["NCHW" , "default" ],
178- }
179- ),
180- transform .FoldConstant (),
181- transform .MergeComposite (pattern_table ()),
182- transform .AnnotateTarget ("tensorrt" ),
183- transform .MergeCompilerRegions (),
184- transform .PartitionGraph (),
185- transform .InlineComposites ("tensorrt" ),
186- transform .InferType (),
187- ]
188- )
171+
172+ seq = get_pass_order (use_patterns )
189173 with tvm .transform .PassContext (opt_level = 3 , config = {"relay.ext.tensorrt.options" : config }):
190174 mod = seq (mod )
191175 mod = prune_tensorrt_subgraphs (mod )
192176 return mod , config
193177
194178
179+ def get_pass_order (use_patterns ):
180+ """
181+ Get the pass ordering based on using predicates or patterns.
182+
183+ Parameters
184+ ----------
185+ use_patterns: Bool
186+ True if pass needs to work with op patterns
187+ Returns
188+ ----------
189+ ret : Sequential
190+ Pass object
191+ """
192+ return (
193+ tvm .transform .Sequential (
194+ [
195+ transform .InferType (),
196+ RemoveDropoutPass (),
197+ transform .RemoveUnusedFunctions (),
198+ transform .ConvertLayout (
199+ {
200+ "nn.conv1d" : ["NCW" , "default" ],
201+ "nn.conv2d" : ["NCHW" , "default" ],
202+ "nn.conv3d" : ["NCDHW" , "default" ],
203+ "nn.conv2d_transpose" : ["NCHW" , "default" ],
204+ }
205+ ),
206+ transform .FoldConstant (),
207+ transform .MergeComposite (pattern_table ()),
208+ transform .AnnotateTarget ("tensorrt" ),
209+ transform .MergeCompilerRegions (),
210+ transform .PartitionGraph (),
211+ transform .InlineComposites ("tensorrt" ),
212+ transform .InferType (),
213+ ]
214+ )
215+ if use_patterns
216+ else tvm .transform .Sequential (
217+ [
218+ transform .InferType (),
219+ RemoveDropoutPass (),
220+ transform .RemoveUnusedFunctions (),
221+ transform .ConvertLayout (
222+ {
223+ "nn.conv1d" : ["NCW" , "default" ],
224+ "nn.conv2d" : ["NCHW" , "default" ],
225+ "nn.conv3d" : ["NCDHW" , "default" ],
226+ "nn.conv2d_transpose" : ["NCHW" , "default" ],
227+ }
228+ ),
229+ transform .FoldConstant (),
230+ transform .AnnotateTarget ("tensorrt" ),
231+ transform .MergeCompilerRegions (),
232+ transform .PartitionGraph (),
233+ transform .InferType (),
234+ ]
235+ )
236+ )
237+
238+
195239def check_dynamism (args , op_name ):
196240 """
197241 Check for dynamism inside any of the args in the op.
@@ -451,7 +495,7 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
451495
452496
453497@_register_external_dynamic_check_func ("nn.batch_matmul" )
454- def batch_matmul_annotate_fn (expr ): # pylint: disable=unused-variable
498+ def batch_matmul_annotate_fn (expr ):
455499 """Check if dense is supported by TensorRT."""
456500
457501 args = expr .args
0 commit comments