[Transform][Bugfix] Handle non-composite lambda functions in FuseOps#16598
Conversation
Prior to this commit, calling `FuseOpsByPattern` with `annotate_codegen=True` would cause an error when encountering a lambda function. This was caused by the `CompositeFunctionAnnotator` asserting that all `relax::Function` encountered must have the `kComposite` attribute. While this is true for all lambda functions produced by `FuseOpsByPattern`, the user may have defined other lambda functions as well. This commit updates `CompositeFunctionAnnotator` to ignore lambda functions that do not have a `kComposite` attribute.
| Function f_inner = Downcast<Function>(ExprMutator::VisitExpr_(func_node)); | ||
| auto composite_name = func_node->GetAttr<String>(attr::kComposite); | ||
|
|
||
| if (!func_node->GetAttr<String>(attr::kComposite)) { |
There was a problem hiding this comment.
Are non-composite functions visited?
tvm/src/relax/transform/fuse_ops.cc
Line 1224 in 76c1708
There was a problem hiding this comment.
The PatternBasedPartitioner only visits non-composite functions, produces a composite function for each pattern match, and updates the non-composite function to call the newly-generated composite function. Afterwards, the call to CompositeFunctionAnnotator is called. This visits only non-composite functions, finds any relax-to-relax function calls, and asserts that the callee is composite.
The callee will be composite for every function call generated by PatternBasedPartitioner, but that doesn't guarantee that all relax-to-relax function calls have a composite callee. If the IRModule contains a relax-to-relax call prior to PatternBasedPartitioner, that callee may be non-composite. This IRModule would be entirely legal, but would trigger the assert in CompositeFunctionAnnotator.
There was a problem hiding this comment.
So, the problem isn't with calls to inner functions as on line 1224, but with calls to other functions within the IRModule.
There was a problem hiding this comment.
I think the problem is if the callee is not a global var, the callee function will still be visited, so the fix makes sense to me
There was a problem hiding this comment.
Whoops, you're right on that one. It's if there is a inner function in the input IRModule. (Apologies, trying to track too many PRs at one time.)
Prior to this commit, calling
FuseOpsByPatternwithannotate_codegen=Truewould cause an error when encountering a lambda function. This was caused by theCompositeFunctionAnnotatorasserting that allrelax::Functionencountered must have thekCompositeattribute. While this is true for all lambda functions produced byFuseOpsByPattern, the user may have defined other lambda functions as well.This commit updates
CompositeFunctionAnnotatorto ignore lambda functions that do not have akCompositeattribute.