@@ -62,53 +62,153 @@ bool IsCanonicalPrimExpr(const PrimExpr& expr) {
6262 * \brief Mutator to canonicalize ShapeExpr in struct info
6363 *
6464 * This pass handles ShapeExpr canonicalization by:
65- * 1. Detecting compound PrimExpr in ShapeExpr dimensions
66- * 2. Lifting them into separate ShapeExpr bindings
65+ * 1. Detecting compound PrimExpr in variable struct_info
66+ * 2. Emitting ShapeExpr bindings to compute expressions
6767 * 3. Using MatchCast to extract values into fresh symbolic tir::Var
68- * 4. Replacing compound expressions with these canonical vars
68+ * 4. Replacing compound expressions with these canonical vars in struct_info
6969 */
7070class ShapeExprCanonicalizer : public ExprMutator {
7171 public:
7272 using ExprMutator::VisitExpr_;
7373
7474 Expr VisitExpr_ (const FunctionNode* func) override {
7575 // Reset state for each function
76- auto cached_compound_to_var = compound_expr_to_var_;
77- auto cached_counter = symbolic_var_counter_;
76+ symbolic_var_counter_ = 0 ;
77+ compound_expr_to_var_.clear ();
78+ emitted_bindings_.clear ();
7879
79- auto result = ExprMutator::VisitExpr_ (func);
80+ // Process the function body
81+ Expr new_body = VisitExpr (func->body );
8082
81- compound_expr_to_var_ = cached_compound_to_var;
82- symbolic_var_counter_ = cached_counter;
83+ if (new_body.same_as (func->body )) {
84+ return ffi::GetRef<Function>(func);
85+ }
8386
84- return result;
87+ return Function (func->params , new_body, func->ret_struct_info , func->is_pure , func->attrs ,
88+ func->span );
89+ }
90+
91+ void VisitBinding_ (const VarBindingNode* binding) override {
92+ // Emit canonicalization bindings before processing the binding
93+ auto original_sinfo = GetStructInfo (binding->var );
94+ if (NeedsCanonicalization (original_sinfo)) {
95+ CanonicalizeStructInfoAndEmit (original_sinfo);
96+ }
97+
98+ // Compute the updated var and value
99+ Expr new_value = this ->VisitExpr (binding->value );
100+ Var new_var = this ->VisitVarDef (binding->var );
101+
102+ // Update var_remap_ to register the new variable
103+ this ->var_remap_ [binding->var ->vid ] = new_var;
104+ this ->var_remap_ [new_var->vid ] = new_var;
105+
106+ // If the var's struct_info changed and no longer matches the value, emit MatchCast
107+ StructInfo var_sinfo = GetStructInfo (new_var);
108+ StructInfo value_sinfo = GetStructInfo (new_value);
109+ StructuralEqual struct_equal;
110+ if (!struct_equal (var_sinfo, value_sinfo)) {
111+ MatchCast match_cast = MatchCast (new_var, new_value, var_sinfo);
112+ builder_->EmitNormalized (match_cast);
113+ builder_->AddDefinitionToScope (match_cast->var );
114+ } else {
115+ builder_->EmitNormalized (VarBinding (new_var, new_value));
116+ }
117+ }
118+
119+ void VisitBinding_ (const MatchCastNode* binding) override {
120+ // Emit canonicalization bindings before processing the binding
121+ auto original_sinfo = GetStructInfo (binding->var );
122+ if (NeedsCanonicalization (original_sinfo)) {
123+ CanonicalizeStructInfoAndEmit (original_sinfo);
124+ }
125+
126+ ExprMutator::VisitBinding_ (binding);
85127 }
86128
87- /* !
88- * \brief Override VisitVarDef to canonicalize struct_info
89- *
90- * This is where we intercept variable definitions and canonicalize any
91- * compound PrimExpr in their TensorStructInfo shapes.
92- */
93129 Var VisitVarDef (const Var& var) override {
94130 auto sinfo = GetStructInfo (var);
95-
96- // Check if we need to canonicalize the struct_info
97131 auto canonical_sinfo = CanonicalizeStructInfo (sinfo);
98132
99133 if (canonical_sinfo.same_as (sinfo)) {
100- // No changes needed
101134 return ExprMutator::VisitVarDef (var);
102135 }
103136
104- // Create a new var with canonicalized strcut_info
137+ // Create a new var with canonicalized struct_info
138+ Var canonical_var;
105139 if (var->IsInstance <DataflowVarNode>()) {
106- return DataflowVar (var->vid , canonical_sinfo, var->span );
140+ canonical_var = DataflowVar (var->vid , canonical_sinfo, var->span );
141+ } else {
142+ canonical_var = Var (var->vid , canonical_sinfo, var->span );
107143 }
108- return Var (var->vid , canonical_sinfo, var->span );
144+
145+ return ExprMutator::VisitVarDef (canonical_var);
109146 }
110147
111148 private:
149+ /* !
150+ * \brief Check if struct_info needs canonicalization
151+ */
152+ bool NeedsCanonicalization (const StructInfo& sinfo) {
153+ if (auto tensor_sinfo = sinfo.as <TensorStructInfoNode>()) {
154+ if (!tensor_sinfo->shape .defined ()) {
155+ return false ;
156+ }
157+ auto shape_expr = tensor_sinfo->shape .as <ShapeExprNode>();
158+ if (!shape_expr) {
159+ return false ;
160+ }
161+ for (const PrimExpr& dim : shape_expr->values ) {
162+ if (!IsCanonicalPrimExpr (dim)) {
163+ return true ;
164+ }
165+ }
166+ return false ;
167+ } else if (auto tuple_sinfo = sinfo.as <TupleStructInfoNode>()) {
168+ for (const StructInfo& field : tuple_sinfo->fields ) {
169+ if (NeedsCanonicalization (field)) {
170+ return true ;
171+ }
172+ }
173+ return false ;
174+ }
175+ return false ;
176+ }
177+
178+ /* !
179+ * Canonicalize struct info and emit necessary bindings
180+ */
181+ void CanonicalizeStructInfoAndEmit (const StructInfo& sinfo) {
182+ if (auto tensor_sinfo = sinfo.as <TensorStructInfoNode>()) {
183+ CanonicalizeTensorStructInfoAndEmit (ffi::GetRef<TensorStructInfo>(tensor_sinfo));
184+ } else if (auto tuple_sinfo = sinfo.as <TupleStructInfoNode>()) {
185+ for (const StructInfo& field : tuple_sinfo->fields ) {
186+ CanonicalizeStructInfoAndEmit (field);
187+ }
188+ }
189+ }
190+
191+ /* !
192+ * Canonicalize tensor struct info and emit necessary bindings
193+ */
194+ void CanonicalizeTensorStructInfoAndEmit (const TensorStructInfo& sinfo) {
195+ if (!sinfo->shape .defined ()) {
196+ return ;
197+ }
198+
199+ auto shape_expr = sinfo->shape .as <ShapeExprNode>();
200+ if (!shape_expr) {
201+ return ;
202+ }
203+
204+ // Emit bindings for each compound dimension
205+ for (const PrimExpr& dim : shape_expr->values ) {
206+ if (!IsCanonicalPrimExpr (dim)) {
207+ CanonicalizeDimension (dim);
208+ }
209+ }
210+ }
211+
112212 /* !
113213 * \brief Canonicalize struct info by lifting compound shape expressions
114214 */
@@ -140,7 +240,7 @@ class ShapeExprCanonicalizer : public ExprMutator {
140240 bool changed = false ;
141241
142242 for (const PrimExpr& dim : shape_expr->values ) {
143- PrimExpr canonical_dim = CanonicalizeDimension (dim);
243+ PrimExpr canonical_dim = GetCanonicalDimension (dim);
144244 canonical_dims.push_back (canonical_dim);
145245 changed |= !canonical_dim.same_as (dim);
146246 }
@@ -174,15 +274,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
174274 }
175275
176276 /* !
177- * \brief Canonicalize a single shape dimension
178- *
179- * If the dimension is a compound PrimExpr:
180- * 1. Emit a ShapeExpr binding containing the compound expression
181- * 2. Create a fresh symbolic tir::Var
182- * 3. Emit a MatchCast to bind the computed value to the symbolic var
183- * 4. Return the symbolic var
277+ * \brief Get the canonical form of a dimension (returns the symbolic var if already emitted)
184278 */
185- PrimExpr CanonicalizeDimension (const PrimExpr& dim) {
279+ PrimExpr GetCanonicalDimension (const PrimExpr& dim) {
186280 // If already canonical, return as is
187281 if (IsCanonicalPrimExpr (dim)) {
188282 return dim;
@@ -193,9 +287,42 @@ class ShapeExprCanonicalizer : public ExprMutator {
193287 return it->second ;
194288 }
195289
196- // Create a fresh symbolic variable
290+ // Create a fresh symbolic variable, but don't emit yet
197291 tir::Var symbolic_var = CreateFreshSymbolicVar (dim->dtype );
198292
293+ compound_expr_to_var_[dim] = symbolic_var;
294+
295+ return symbolic_var;
296+ }
297+
298+ /* !
299+ * \brief Emit bindings for a single compound dimension
300+ *
301+ * If the dimension is a compound PrimExpr:
302+ * 1. Emit a ShapeExpr binding containing the compound expression
303+ * 2. Create a fresh symbolic tir::Var
304+ * 3. Emit a MatchCast to bind the computed value to the symbolic var
305+ */
306+ void CanonicalizeDimension (const PrimExpr& dim) {
307+ // If already canonical, nothing to emit
308+ if (IsCanonicalPrimExpr (dim)) {
309+ return ;
310+ }
311+
312+ // Check If we've already emitted bindings for this expression
313+ auto it = compound_expr_to_var_.find (dim);
314+ if (it == compound_expr_to_var_.end ()) {
315+ // This should not happen if GetCanonicalDimension was called first
316+ return ;
317+ }
318+
319+ // Check if we've already emitted the bindings
320+ if (emitted_bindings_.count (dim)) {
321+ return ;
322+ }
323+
324+ tir::Var symbolic_var = it->second ;
325+
199326 // Emit shape binding: shape_var = R.shape([compound_expr])
200327 ShapeExpr shape_value ({dim});
201328 Var shape_var = builder_->Emit (shape_value);
@@ -206,10 +333,8 @@ class ShapeExprCanonicalizer : public ExprMutator {
206333 Var match_cast_var (" _" , match_sinfo);
207334 builder_->EmitNormalized (MatchCast (match_cast_var, shape_var, match_sinfo));
208335
209- // Cache the mapping to avoid duplicate bindings
210- compound_expr_to_var_[dim] = symbolic_var;
211-
212- return symbolic_var;
336+ // Mark as emitted
337+ emitted_bindings_.insert (dim);
213338 }
214339
215340 /* !
@@ -223,6 +348,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
223348 // Cache to avoid creating duplicate bindings for the same compound expression
224349 std::unordered_map<PrimExpr, tir::Var, StructuralHash, StructuralEqual> compound_expr_to_var_;
225350
351+ // Track which compound expressions have had their bindings emitted
352+ std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> emitted_bindings_;
353+
226354 // Counter for generating unique symbolic variable names
227355 int symbolic_var_counter_ = 0 ;
228356};
0 commit comments