Skip to content

Commit bbd6752

Browse files
committed
Refactor: CanonicalizeShapeExpr
1 parent 99252df commit bbd6752

File tree

7 files changed

+268
-281
lines changed

7 files changed

+268
-281
lines changed

python/tvm/relax/backend/cpu_generic/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume
5252
relax.transform.LowerAllocTensor(),
5353
relax.transform.KillAfterLastUse(),
5454
relax.transform.LowerRuntimeBuiltin(),
55-
relax.transform.ComputePrimValue(),
5655
relax.transform.CanonicalizeShapeExpr(),
56+
relax.transform.ComputePrimValue(),
5757
relax.transform.VMShapeLower(),
5858
relax.transform.AttachGlobalSymbol(),
5959
]

python/tvm/relax/backend/cuda/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume
6464
relax.transform.LowerAllocTensor(),
6565
relax.transform.KillAfterLastUse(),
6666
relax.transform.LowerRuntimeBuiltin(),
67-
relax.transform.ComputePrimValue(),
6867
relax.transform.CanonicalizeShapeExpr(),
68+
relax.transform.ComputePrimValue(),
6969
relax.transform.VMShapeLower(),
7070
relax.transform.AttachGlobalSymbol(),
7171
]

python/tvm/relax/backend/gpu_generic/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume
6363
relax.transform.LowerAllocTensor(),
6464
relax.transform.KillAfterLastUse(),
6565
relax.transform.LowerRuntimeBuiltin(),
66-
relax.transform.ComputePrimValue(),
6766
relax.transform.CanonicalizeShapeExpr(),
67+
relax.transform.ComputePrimValue(),
6868
relax.transform.VMShapeLower(),
6969
relax.transform.AttachGlobalSymbol(),
7070
]

python/tvm/relax/backend/rocm/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume
6363
relax.transform.LowerAllocTensor(),
6464
relax.transform.KillAfterLastUse(),
6565
relax.transform.LowerRuntimeBuiltin(),
66-
relax.transform.ComputePrimValue(),
6766
relax.transform.CanonicalizeShapeExpr(),
67+
relax.transform.ComputePrimValue(),
6868
relax.transform.VMShapeLower(),
6969
relax.transform.AttachGlobalSymbol(),
7070
]

python/tvm/relax/transform/transform.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -736,17 +736,23 @@ def FoldConstant() -> tvm.ir.transform.Pass:
736736

737737

738738
def CanonicalizeShapeExpr() -> tvm.ir.transform.Pass:
739-
"""Canonicalize ShapeExpr by lifting compound PrimExpr into separate bindings.
739+
"""Canonicalize ShapeExpr by replacing compound PrimExpr with fresh symbolic variables.
740740
741741
VMShapeLower can only handle ShapeExpr where each dimension is either:
742742
- IntImm (concrete integer constant)
743-
- tir::Var (symbolic variable)
743+
- tir::Var (symbolic variable from function parameters or match_cast)
744744
745-
This pass lifts compound PrimExpr (e.g., n+1, 4*n*m, etc.) into separate shape bindings
746-
with MatchCast to extract symbolic variables, ensuring VMShapeLower receives only
747-
canonical shape expressions.
745+
This pass transforms compound PrimExpr (e.g., n+1, 4*n*m) by:
746+
1. Creating a fresh tir::Var for each compound expression
747+
2. Emitting a MatchCast that binds the fresh var to a PrimValue computing the expression
748+
3. Replacing the compound expression in ShapeExpr with teh fresh var
748749
749-
This pass should be applied after ComputePrimValue and before VMShapeLower.
750+
Example transformation:
751+
Before: y = R.zeros(R.shape([n + 1]), dtype="float32")
752+
After: _s0_pv: R.Prim(value=_s0) = R.match_cast(R.prim_value(n+1), R.Prim(value=_s0))
753+
y = R.zeros(R.shape([_s0]), dtype="float32")
754+
755+
This pass should be applied before ComputePrimValue and before VMShapeLower.
750756
751757
Returns
752758
-------

0 commit comments

Comments
 (0)