Skip to content

Commit b3f7800

Browse files
committed
Refactor VisitExpr_
1 parent 25bbab1 commit b3f7800

1 file changed

Lines changed: 163 additions & 35 deletions

File tree

src/relax/transform/canonicalize_shape_expr.cc

Lines changed: 163 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -62,53 +62,153 @@ bool IsCanonicalPrimExpr(const PrimExpr& expr) {
6262
* \brief Mutator to canonicalize ShapeExpr in struct info
6363
*
6464
* This pass handles ShapeExpr canonicalization by:
65-
* 1. Detecting compound PrimExpr in ShapeExpr dimensions
66-
* 2. Lifting them into separate ShapeExpr bindings
65+
* 1. Detecting compound PrimExpr in variable struct_info
66+
* 2. Emitting ShapeExpr bindings to compute expressions
6767
* 3. Using MatchCast to extract values into fresh symbolic tir::Var
68-
* 4. Replacing compound expressions with these canonical vars
68+
* 4. Replacing compound expressions with these canonical vars in struct_info
6969
*/
7070
class ShapeExprCanonicalizer : public ExprMutator {
7171
public:
7272
using ExprMutator::VisitExpr_;
7373

7474
Expr VisitExpr_(const FunctionNode* func) override {
7575
// Reset state for each function
76-
auto cached_compound_to_var = compound_expr_to_var_;
77-
auto cached_counter = symbolic_var_counter_;
76+
symbolic_var_counter_ = 0;
77+
compound_expr_to_var_.clear();
78+
emitted_bindings_.clear();
7879

79-
auto result = ExprMutator::VisitExpr_(func);
80+
// Process the function body
81+
Expr new_body = VisitExpr(func->body);
8082

81-
compound_expr_to_var_ = cached_compound_to_var;
82-
symbolic_var_counter_ = cached_counter;
83+
if (new_body.same_as(func->body)) {
84+
return ffi::GetRef<Function>(func);
85+
}
8386

84-
return result;
87+
return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs,
88+
func->span);
89+
}
90+
91+
void VisitBinding_(const VarBindingNode* binding) override {
92+
// Emit canonicalization bindings before processing the binding
93+
auto original_sinfo = GetStructInfo(binding->var);
94+
if (NeedsCanonicalization(original_sinfo)) {
95+
CanonicalizeStructInfoAndEmit(original_sinfo);
96+
}
97+
98+
// Compute the updated var and value
99+
Expr new_value = this->VisitExpr(binding->value);
100+
Var new_var = this->VisitVarDef(binding->var);
101+
102+
// Update var_remap_ to register the new variable
103+
this->var_remap_[binding->var->vid] = new_var;
104+
this->var_remap_[new_var->vid] = new_var;
105+
106+
// If the var's struct_info changed and no longer matches the value, emit MatchCast
107+
StructInfo var_sinfo = GetStructInfo(new_var);
108+
StructInfo value_sinfo = GetStructInfo(new_value);
109+
StructuralEqual struct_equal;
110+
if (!struct_equal(var_sinfo, value_sinfo)) {
111+
MatchCast match_cast = MatchCast(new_var, new_value, var_sinfo);
112+
builder_->EmitNormalized(match_cast);
113+
builder_->AddDefinitionToScope(match_cast->var);
114+
} else {
115+
builder_->EmitNormalized(VarBinding(new_var, new_value));
116+
}
117+
}
118+
119+
void VisitBinding_(const MatchCastNode* binding) override {
120+
// Emit canonicalization bindings before processing the binding
121+
auto original_sinfo = GetStructInfo(binding->var);
122+
if (NeedsCanonicalization(original_sinfo)) {
123+
CanonicalizeStructInfoAndEmit(original_sinfo);
124+
}
125+
126+
ExprMutator::VisitBinding_(binding);
85127
}
86128

87-
/*!
88-
* \brief Override VisitVarDef to canonicalize struct_info
89-
*
90-
* This is where we intercept variable definitions and canonicalize any
91-
* compound PrimExpr in their TensorStructInfo shapes.
92-
*/
93129
Var VisitVarDef(const Var& var) override {
94130
auto sinfo = GetStructInfo(var);
95-
96-
// Check if we need to canonicalize the struct_info
97131
auto canonical_sinfo = CanonicalizeStructInfo(sinfo);
98132

99133
if (canonical_sinfo.same_as(sinfo)) {
100-
// No changes needed
101134
return ExprMutator::VisitVarDef(var);
102135
}
103136

104-
// Create a new var with canonicalized strcut_info
137+
// Create a new var with canonicalized struct_info
138+
Var canonical_var;
105139
if (var->IsInstance<DataflowVarNode>()) {
106-
return DataflowVar(var->vid, canonical_sinfo, var->span);
140+
canonical_var = DataflowVar(var->vid, canonical_sinfo, var->span);
141+
} else {
142+
canonical_var = Var(var->vid, canonical_sinfo, var->span);
107143
}
108-
return Var(var->vid, canonical_sinfo, var->span);
144+
145+
return ExprMutator::VisitVarDef(canonical_var);
109146
}
110147

111148
private:
149+
/*!
150+
* \brief Check if struct_info needs canonicalization
151+
*/
152+
bool NeedsCanonicalization(const StructInfo& sinfo) {
153+
if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
154+
if (!tensor_sinfo->shape.defined()) {
155+
return false;
156+
}
157+
auto shape_expr = tensor_sinfo->shape.as<ShapeExprNode>();
158+
if (!shape_expr) {
159+
return false;
160+
}
161+
for (const PrimExpr& dim : shape_expr->values) {
162+
if (!IsCanonicalPrimExpr(dim)) {
163+
return true;
164+
}
165+
}
166+
return false;
167+
} else if (auto tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
168+
for (const StructInfo& field : tuple_sinfo->fields) {
169+
if (NeedsCanonicalization(field)) {
170+
return true;
171+
}
172+
}
173+
return false;
174+
}
175+
return false;
176+
}
177+
178+
/*!
179+
* Canonicalize struct info and emit necessary bindings
180+
*/
181+
void CanonicalizeStructInfoAndEmit(const StructInfo& sinfo) {
182+
if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
183+
CanonicalizeTensorStructInfoAndEmit(ffi::GetRef<TensorStructInfo>(tensor_sinfo));
184+
} else if (auto tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
185+
for (const StructInfo& field : tuple_sinfo->fields) {
186+
CanonicalizeStructInfoAndEmit(field);
187+
}
188+
}
189+
}
190+
191+
/*!
192+
* Canonicalize tensor struct info and emit necessary bindings
193+
*/
194+
void CanonicalizeTensorStructInfoAndEmit(const TensorStructInfo& sinfo) {
195+
if (!sinfo->shape.defined()) {
196+
return;
197+
}
198+
199+
auto shape_expr = sinfo->shape.as<ShapeExprNode>();
200+
if (!shape_expr) {
201+
return;
202+
}
203+
204+
// Emit bindings for each compound dimension
205+
for (const PrimExpr& dim : shape_expr->values) {
206+
if (!IsCanonicalPrimExpr(dim)) {
207+
CanonicalizeDimension(dim);
208+
}
209+
}
210+
}
211+
112212
/*!
113213
* \brief Canonicalize struct info by lifting compound shape expressions
114214
*/
@@ -140,7 +240,7 @@ class ShapeExprCanonicalizer : public ExprMutator {
140240
bool changed = false;
141241

142242
for (const PrimExpr& dim : shape_expr->values) {
143-
PrimExpr canonical_dim = CanonicalizeDimension(dim);
243+
PrimExpr canonical_dim = GetCanonicalDimension(dim);
144244
canonical_dims.push_back(canonical_dim);
145245
changed |= !canonical_dim.same_as(dim);
146246
}
@@ -174,15 +274,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
174274
}
175275

176276
/*!
177-
* \brief Canonicalize a single shape dimension
178-
*
179-
* If the dimension is a compound PrimExpr:
180-
* 1. Emit a ShapeExpr binding containing the compound expression
181-
* 2. Create a fresh symbolic tir::Var
182-
* 3. Emit a MatchCast to bind the computed value to the symbolic var
183-
* 4. Return the symbolic var
277+
* \brief Get the canonical form of a dimension (returns the symbolic var if already emitted)
184278
*/
185-
PrimExpr CanonicalizeDimension(const PrimExpr& dim) {
279+
PrimExpr GetCanonicalDimension(const PrimExpr& dim) {
186280
// If already canonical, return as is
187281
if (IsCanonicalPrimExpr(dim)) {
188282
return dim;
@@ -193,9 +287,42 @@ class ShapeExprCanonicalizer : public ExprMutator {
193287
return it->second;
194288
}
195289

196-
// Create a fresh symbolic variable
290+
// Create a fresh symbolic variable, but don't emit yet
197291
tir::Var symbolic_var = CreateFreshSymbolicVar(dim->dtype);
198292

293+
compound_expr_to_var_[dim] = symbolic_var;
294+
295+
return symbolic_var;
296+
}
297+
298+
/*!
299+
* \brief Emit bindings for a single compound dimension
300+
*
301+
* If the dimension is a compound PrimExpr:
302+
* 1. Emit a ShapeExpr binding containing the compound expression
303+
* 2. Create a fresh symbolic tir::Var
304+
* 3. Emit a MatchCast to bind the computed value to the symbolic var
305+
*/
306+
void CanonicalizeDimension(const PrimExpr& dim) {
307+
// If already canonical, nothing to emit
308+
if (IsCanonicalPrimExpr(dim)) {
309+
return;
310+
}
311+
312+
// Check If we've already emitted bindings for this expression
313+
auto it = compound_expr_to_var_.find(dim);
314+
if (it == compound_expr_to_var_.end()) {
315+
// This should not happen if GetCanonicalDimension was called first
316+
return;
317+
}
318+
319+
// Check if we've already emitted the bindings
320+
if (emitted_bindings_.count(dim)) {
321+
return;
322+
}
323+
324+
tir::Var symbolic_var = it->second;
325+
199326
// Emit shape binding: shape_var = R.shape([compound_expr])
200327
ShapeExpr shape_value({dim});
201328
Var shape_var = builder_->Emit(shape_value);
@@ -206,10 +333,8 @@ class ShapeExprCanonicalizer : public ExprMutator {
206333
Var match_cast_var("_", match_sinfo);
207334
builder_->EmitNormalized(MatchCast(match_cast_var, shape_var, match_sinfo));
208335

209-
// Cache the mapping to avoid duplicate bindings
210-
compound_expr_to_var_[dim] = symbolic_var;
211-
212-
return symbolic_var;
336+
// Mark as emitted
337+
emitted_bindings_.insert(dim);
213338
}
214339

215340
/*!
@@ -223,6 +348,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
223348
// Cache to avoid creating duplicate bindings for the same compound expression
224349
std::unordered_map<PrimExpr, tir::Var, StructuralHash, StructuralEqual> compound_expr_to_var_;
225350

351+
// Track which compound expressions have had their bindings emitted
352+
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> emitted_bindings_;
353+
226354
// Counter for generating unique symbolic variable names
227355
int symbolic_var_counter_ = 0;
228356
};

0 commit comments

Comments
 (0)