Skip to content

Commit 3d96623

Browse files
authored
Fix InternalError in StaticPlanBlockMemory when visiting DataflowBlockNode (#17501)
This PR fixes an internal error #17488 This error happens because the visitor class StorageAllocatorBaseVisitor does not correctly handle DataflowBlockNode instances. Specifically, the VisitBindingBlock_ method is not overridden for DataflowBlockNode, leading to an empty block_stack_ when it is expected to contain the current block. To fix this issue, we need to override the VisitBindingBlock_ method for const DataflowBlockNode* in the StorageAllocatorBaseVisitor class. By doing so, we ensure that the block_stack_ is correctly managed when visiting dataflow blocks, similar to how it is managed for regular binding blocks.
1 parent c7e9292 commit 3d96623

2 files changed

Lines changed: 50 additions & 0 deletions

File tree

src/relax/transform/static_plan_block_memory.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,15 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
314314
SetTokens(binding->var.get(), token_map_[binding->value.get()]);
315315
}
316316

317+
void VisitBindingBlock_(const DataflowBlockNode* block) override {
318+
// We maintain a block stack for token allocation-site and use-site check.
319+
block_stack_.push_back(block);
320+
ExprVisitor::VisitBindingBlock_(block);
321+
ICHECK(!block_stack_.empty());
322+
ICHECK(block_stack_.back() == block);
323+
block_stack_.pop_back();
324+
}
325+
317326
void VisitExpr_(const TupleNode* tuple) final {
318327
Array<Tokens> tokens;
319328
tokens.reserve(tuple->fields.size());

tests/python/relax/test_transform_static_plan_block_memory.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,5 +1504,46 @@ def main() -> R.Tensor((128,), dtype="float32"):
15041504
tvm.ir.assert_structural_equal(after, Expected)
15051505

15061506

1507+
def test_with_dataflow():
1508+
@I.ir_module
1509+
class Before:
1510+
@T.prim_func
1511+
def exp(A: T.handle, B: T.handle):
1512+
T.evaluate(0)
1513+
1514+
@R.function
1515+
def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
1516+
cls = Before
1517+
with R.dataflow():
1518+
alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
1519+
R.shape([10]), R.dtype("float32"), runtime_device_index=0
1520+
)
1521+
_: R.Tuple() = cls.exp(x, alloc)
1522+
gv: R.Tensor((10,), dtype="float32") = alloc
1523+
R.output(gv)
1524+
return gv
1525+
1526+
@I.ir_module
1527+
class Expected:
1528+
@T.prim_func
1529+
def exp(A: T.handle, B: T.handle):
1530+
T.evaluate(0)
1531+
1532+
@R.function
1533+
def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
1534+
cls = Expected
1535+
with R.dataflow():
1536+
alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
1537+
R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global")
1538+
)
1539+
cls.exp(x, alloc)
1540+
gv: R.Tensor((10,), dtype="float32") = alloc
1541+
R.output(gv)
1542+
return gv
1543+
1544+
after = relax.transform.StaticPlanBlockMemory()(Before)
1545+
tvm.ir.assert_structural_equal(after, Expected)
1546+
1547+
15071548
if __name__ == "__main__":
15081549
tvm.testing.main()

0 commit comments

Comments
 (0)