Skip to content

Commit ed9aa56

Browse files
authored
[Relax][Analysis] Handle recursive functions in CollectVarUsage (#17224)
* [Relax][Analysis] Handle recursive functions in CollectVarUsage Prior to this commit, the `relax::analysis::CollectVarUsage` utility treated a local function definition as in-scope after visiting the body of the local function. As a result, recursive calls from a local function were incorrectly identified as calls to an undefined variable. This commit updates the `CollectVarUsage` to treat a local function definition as in-scope when inspecting the function body. This change is similar to the change made for structural equality in #16756. * lint fixes
1 parent 32063b0 commit ed9aa56

File tree

2 files changed

+100
-2
lines changed

2 files changed

+100
-2
lines changed

src/relax/analysis/udchain.cc

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class UDChain : relax::ExprVisitor {
5555

5656
private:
5757
Map<Var, Expr> bound_values;
58+
std::unordered_set<Var> forward_declarations;
5859
std::unordered_map<Var, support::OrderedSet<Var>> usage_map;
5960
support::OrderedSet<Var> outputs;
6061

@@ -71,9 +72,20 @@ class UDChain : relax::ExprVisitor {
7172
cur_user_ = cache;
7273
}
7374

75+
void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override {
76+
// A local Relax function may be recursively defined. References to
77+
// `binding->var` that appear within `func` are valid.
78+
DefineVar(binding->var);
79+
forward_declarations.insert(binding->var);
80+
ExprVisitor::VisitBinding_(binding, func);
81+
}
82+
7483
void VisitVarDef(const Var& var) override {
75-
CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition";
76-
usage_map[var] = {};
84+
if (forward_declarations.count(var)) {
85+
forward_declarations.erase(var);
86+
} else {
87+
DefineVar(var);
88+
}
7789
}
7890
void VisitExpr_(const VarNode* op) override {
7991
auto var = GetRef<Var>(op);
@@ -89,6 +101,11 @@ class UDChain : relax::ExprVisitor {
89101
cur_user_ = nullptr;
90102
ExprVisitor::VisitExpr_(op);
91103
}
104+
105+
void DefineVar(const Var& var) {
106+
CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition";
107+
usage_map[var] = {};
108+
}
92109
};
93110

94111
std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> FunctionUseDef(

tests/python/relax/test_transform_dead_code_elimination.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,5 +658,86 @@ def subsubroutine(A: R.Tensor) -> R.Tensor:
658658
tvm.ir.assert_structural_equal(Expected, After)
659659

660660

661+
def test_recursively_defined_lambda():
662+
"""DCE may be applied to recursively-defined functions
663+
664+
While most expressions may only contain references to
665+
previously-defined variables, local Relax function definitions may
666+
contain references to themselves.
667+
668+
This is a regression test. In previous implementations, the
669+
recursive use of `while_loop` resulted in an error, as
670+
`while_loop` was not considered in-scope by the `CollectVarUsage`
671+
utility until after the body of `while_loop` had been visited.
672+
673+
"""
674+
675+
@I.ir_module
676+
class Before:
677+
@R.function
678+
def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
679+
@R.function
680+
def while_loop(
681+
i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
682+
) -> R.Tensor((2, 3), "float32"):
683+
cond = R.call_pure_packed(
684+
"test.vm.less", i, R.const(10), sinfo_args=R.Tensor((), dtype="bool")
685+
)
686+
c = R.const(1, dtype="int32")
687+
if cond:
688+
new_i = R.add(i, c)
689+
new_s = R.add(s, x)
690+
r = while_loop(new_i, new_s)
691+
else:
692+
r = s
693+
return r
694+
695+
gv = while_loop(R.const(0), x)
696+
return gv
697+
698+
Expected = Before
699+
700+
verify(Before, Expected)
701+
702+
703+
def test_recursively_defined_closure():
704+
"""DCE may be applied to recursively-defined closures
705+
706+
This test is identical to `test_recursively_defined_lambda`,
707+
except that the threshold for recursion is defined in an enclosed
708+
variable outside of the recursive function.
709+
710+
"""
711+
712+
@I.ir_module
713+
class Before:
714+
@R.function
715+
def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
716+
threshold = R.const(10)
717+
718+
@R.function
719+
def while_loop(
720+
i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
721+
) -> R.Tensor((2, 3), "float32"):
722+
cond = R.call_pure_packed(
723+
"test.vm.less", i, threshold, sinfo_args=R.Tensor((), dtype="bool")
724+
)
725+
c = R.const(1, dtype="int32")
726+
if cond:
727+
new_i = R.add(i, c)
728+
new_s = R.add(s, x)
729+
r = while_loop(new_i, new_s)
730+
else:
731+
r = s
732+
return r
733+
734+
gv = while_loop(R.const(0), x)
735+
return gv
736+
737+
Expected = Before
738+
739+
verify(Before, Expected)
740+
741+
661742
if __name__ == "__main__":
662743
tvm.testing.main()

0 commit comments

Comments
 (0)