[BYOC][TRT] Add DFPattern support for TRT backend #10759
Force-pushed from a4f84ce to 53b0c48.
mbaret left a comment:
I think this pass needs its own unit tests so it can be tested outside of the TRT partitioning flow.
| @_register_external_dynamic_check_func("nn.batch_matmul") | ||
| def batch_matmul_annotate_fn(expr): | ||
| def batch_matmul_annotate_fn(expr): # pylint: disable=unused-variable |
| for tup in pattern_table: | ||
| if len(tup) == 2: | ||
| pattern_name, pattern = tup | ||
| check = lambda extract: True |
Could you explain this change?
A black formatting mismatch; I autoformatted it.
| namespace relay { | ||
| class Unmerger : ExprMutator { |
MixedModeMutator is now preferred where possible.
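For context on this suggestion: as far as I understand it, MixedModeMutator walks dataflow regions with an explicit work stack instead of native recursion, so very deep graphs do not overflow the call stack. The toy below is not TVM code. It is a minimal sketch of that idea on nested tuples, and the names `rewrite_iterative` and `rewrite_leaf` are hypothetical.

```python
def rewrite_iterative(expr, rewrite_leaf):
    """Post-order rewrite of a nested-tuple expression without recursion.

    Expressions are toy stand-ins: a leaf is a string, a call is a tuple
    (op_name, child, child, ...). This is an illustration only, not TVM's API.
    """
    done = []                  # rewritten sub-expressions, in post-order
    stack = [(expr, False)]    # (node, have_children_been_scheduled)
    while stack:
        node, children_scheduled = stack.pop()
        if not isinstance(node, tuple):
            done.append(rewrite_leaf(node))
        elif not children_scheduled:
            # Revisit this node after all of its children are rewritten.
            stack.append((node, True))
            for child in reversed(node[1:]):
                stack.append((child, False))
        else:
            n_children = len(node) - 1
            split = len(done) - n_children
            args = done[split:]
            del done[split:]
            done.append((node[0], *args))
    return done[0]
```

A recursive version of the same rewrite would hit Python's recursion limit on a chain a few thousand calls deep, while the loop above only grows a heap-allocated list.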
| Function gv = GetRef<Function>(function_var_node); | ||
| const auto* fn = gv.as<FunctionNode>(); |
Is this needed - it looks like we already start with the FunctionNode?
| Expr VisitExpr_(const CallNode* call_node) final { | ||
| Call vanilla_call = GetAnyCall(call_node); | ||
| const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>(); |
| // Attrs need to be empty at this point to avoid propagating Composite and | ||
| // PartitionedFromPattern that fiddling TRT code gen for registered ops. | ||
| auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, {}); | ||
| return Bind(func->body, bind_map); |
Not sure I understand this, why can't we just do
return Bind(fn->body, bind_map);
| if (!base_func->GetAttr<String>(attr::kCompiler).defined() && | ||
| base_func->GetAttr<String>(attr::kCompiler) != target) { | ||
| return module; |
I think it'd be better to continue; here rather than return, otherwise it seems if any partitioning for a different target has taken place, this will bail out.
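The difference between the two control-flow choices can be sketched in plain Python. This is a toy model, not the pass itself: the module is a hypothetical dict from function name to an attribute dict, and both helper names are made up for illustration.

```python
def select_with_return(functions, target):
    """Stops at the first function that belongs to another target."""
    selected = []
    for name, attrs in functions.items():
        if attrs.get("Compiler") != target:
            return selected  # bails out, later functions are never seen
        selected.append(name)
    return selected


def select_with_continue(functions, target):
    """Skips functions for other targets and keeps scanning."""
    selected = []
    for name, attrs in functions.items():
        if attrs.get("Compiler") != target:
            continue  # ignore this function only
        selected.append(name)
    return selected


# Toy module: one function partitioned for another target comes first.
toy_module = {
    "fn_other": {"Compiler": "dnnl"},
    "fn_trt": {"Compiler": "tensorrt"},
}
```

With `toy_module`, the early-return variant selects nothing for `"tensorrt"` because it gives up at `fn_other`, while the `continue` variant still finds `fn_trt`.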
| */ | ||
| /*! | ||
| * \file src/relay/transforms/unmerge_composites.cc |
Personal preference for the name here would be either InlineComposite or RemoveComposite, not a huge deal though, if no one else agrees we can keep it as Unmerge.
InlineComposite makes sense, I will rename it
| /*! | ||
| * \file src/relay/transforms/unmerge_composites.cc | ||
| * \brief Undo the partioned graphs originate from merge composite. |
I think 'Inline composite functions for a given target' describes this a bit better.
Force-pushed from 53b0c48 to a4d95c7.
PTAL
Force-pushed from a961ef8 to a3dc554.
I still think this needs a couple of simple unit tests to confirm the behaviour. Also ping @mbs-octoml if you want to take a quick look.
@mbaret PTAL. Under
| print("merge composite reusult") | ||
| print(result) | ||
| print("---------------------") |
We should probably omit the prints.
| def expected(): | ||
| a = relay.var("a", shape=(10, 10)) | ||
| b = relay.var("b", shape=(10, 10)) | ||
| # add_relu function | ||
| in_1 = relay.var("in_1", shape=(10, 10)) | ||
| in_2 = relay.var("in_2", shape=(10, 10)) | ||
| add_node = relay.add(in_1, in_2) | ||
| relu_node = relay.nn.relu(add_node) | ||
| add_relu = relay.Function([in_1, in_2], relu_node) | ||
| return add_relu |
I think this is the same as before() (given a and b aren't used). If all we really want to test is that doing InlineComposites undoes MergeComposite, we can probably just test that the result is equal to the input.
| """Utility function to check inline composites results.""" | ||
| result = run_opt_pass( | ||
| graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude | ||
| ) |
I think we should put some form of check here just to confirm that a composite function has been created (so we know MergeComposite didn't just skip everything if for instance there was a pattern error).
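The two checks suggested for this test helper (the round trip returns the input, and a composite was really created in between) can be sketched on a toy expression tree. Nothing here is TVM code: expressions are nested tuples, and `merge_add_relu`, `inline_composites`, `count_composites`, and `check_round_trip` are hypothetical stand-ins for MergeComposite, the new inlining pass, and the test utility.

```python
def merge_add_relu(expr):
    """Toy MergeComposite: fold relu(add(x, y)) into one composite node."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    args = [merge_add_relu(a) for a in args]
    if op == "relu" and isinstance(args[0], tuple) and args[0][0] == "add":
        _, lhs, rhs = args[0]
        return ("composite", "add_relu", lhs, rhs)
    return (op, *args)


def inline_composites(expr):
    """Toy inlining pass: expand every composite node back into its body."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    if op == "composite":
        _name, lhs, rhs = args
        return ("relu", ("add", inline_composites(lhs), inline_composites(rhs)))
    return (op, *[inline_composites(a) for a in args])


def count_composites(expr):
    """Number of composite nodes in the expression."""
    if not isinstance(expr, tuple):
        return 0
    return int(expr[0] == "composite") + sum(count_composites(a) for a in expr[1:])


def check_round_trip(graph):
    """Merge, confirm something was merged, inline, compare with the input."""
    merged = merge_add_relu(graph)
    # Guard against the merge step silently matching nothing.
    assert count_composites(merged) > 0, "no composite function was created"
    result = inline_composites(merged)
    assert result == graph, "inlining did not restore the original graph"
    return result
```

Without the `count_composites` guard, a pattern that never matches would make the round-trip comparison pass trivially, which is the failure mode the comment above is worried about.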
| relu relu | ||
| """ | ||
| pattern_table = [("add", make_add_pattern()), ("nn.relu", make_relu_pattern())] |
This doesn't seem to match the description above.
| """ | ||
| def make_conv_bias_relu_pattern(): |
This pattern doesn't seem to be used, I think either add a test for it or remove.
This PR adds DFPattern support for the TRT backend without removing the existing predicate registry. It adds and extends the following:
- In tensorrt.py: adds a pattern_table for all the supported ops, which consumes the pre-existing op_registry checks.
- Adds an additional pass, unmerge_composites.cc. This is required for the TRT backend because it expects a single primitive function to work with, while MergeComposite and PartitionGraph produce a single function for each Composite pattern.
- Adds test_inline_composites.py, which tests the newly introduced pass.
Both the pattern-based and predicate-based pass sequences produce syntactically equivalent IRModules. This is to ensure backwards compatibility.
Original pass ordering:
Pass ordering with MergeComposites and UnmergeComposites:
@mbs-octoml @mbaret @masahi