[Transform] Improve symbolic variable handling in FuseOps#16450
[Transform] Improve symbolic variable handling in FuseOps#16450Lunderberg wants to merge 1 commit intoapache:mainfrom
Conversation
|
Ideally we don't want to change FuseOps behavior, since in cases where expressions are intermediate (e.g. intermediate compute include values that contains exprs like n * 4). This is because we should get maybe we should look into compose them? FuseOps first then rewrite signatures |
|
I could see having a post-processing pass to update the signature, maybe as an extension of Though, could you expand on what you mean by intermediate expressions? In either case, whether implemented in |
b34ffe9 to
b014332
Compare
b014332 to
3556c4f
Compare
|
Rebased onto main to resolve conflicts. For long-term, I think I agree that it would be cleaner and more general-purpose to have the functionality separated out into three distinct passes:
|
3556c4f to
ebb6278
Compare
|
I've separated the first commit of this PR branch into an independent PR (#16637), as the bugfix it provides is independent of the concerns raised, and does not require the not-yet-implemented |
Prior to this commit, `FuseOps` and `FuseOpsByPattern` exposed a symbolic variable to the fused function if it was used within the fused function, but wasn't inferable from other parameter shapes. While this prevents undefined symbolic variables, it can cause issues for downstream use of `CodegenJSON`, which requires all arguments to be tensors, or tuple of tensors. Frequently, all uses of a non-inferable symbolic shape occur within a symbolic expression that can be inferred. For example, a function that takes `arg: R.Tensor([N+1])` and returns `R.add(arg, R.const(1))` cannot infer `N`. However, all occurrences of `N` occur as part of the expression `N+1`, and the value of `N+1` can be inferred. Therefore, if we replace `N+1` with `M`, the additional `ShapeTuple` argument isn't required.
ebb6278 to
db735e3
Compare
|
New functionality implemented in #16450, which would hoist out the common subexpressions. After it lands, this PR can be updated to make use of it. |
Prior to this commit,
FuseOpsandFuseOpsByPatternexposed a symbolic variable to the fused function if it was used within the fused function, but wasn't inferable from other parameter shapes. While this prevents undefined symbolic variables, it can cause issues for downstream use ofCodegenJSON, which requires all arguments to be tensors, or tuple of tensors.Frequently, all uses of a non-inferable symbolic shape occur within a symbolic expression that can be inferred. For example, a function that takes
arg: R.Tensor([N+1])and returnsR.add(arg, R.const(1))cannot inferN. However, all occurrences ofNoccur as part of the expressionN+1, and the value ofN+1can be inferred. Therefore, if we replaceN+1withM, the additionalShapeTupleargument isn't required.In addition, prior to this commit, the
CompositeFunctionAnnotatorvisited the body of functions without the parameters being considered in-scope. As a result,EraseToWellDefinedwould remove known shapes from the function body'sStructInfo.