Skip to content

Commit ae81656

Browse files
MasterJH5574thaisacs
authored andcommitted
[Fix] Fix the purity flag of "vm.call_tir_dyn" and "kill" ops (apache#16773)
This PR fixes the purity flag of `relax.vm.call_tir_dyn` and another few "kill" ops. Their purity flags were set to True, which made them possible to be removed by `remove_all_unused`. * `relax.vm.call_tir_dyn` works by mutating the input args in place, which is not pure. * though the "kill" ops have no actions so far, their semantics suggest that they are impure. A regression test is added to prevent the unexpected removal from happening again.
1 parent 2d45323 commit ae81656

4 files changed

Lines changed: 44 additions & 17 deletions

File tree

src/relax/op/op.cc

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -921,8 +921,8 @@ RELAY_REGISTER_OP("relax.memory.kill_storage")
921921
.set_num_inputs(1)
922922
.add_argument("storage", "Expr", "The storage to be killed.")
923923
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
924-
// deallocation also isn't considered a "visible effect" as far as purity is concerned
925-
.set_attr<Bool>("FPurity", Bool(true));
924+
// We mark this as impure so it wouldn't be removed by "remove_all_unused"
925+
.set_attr<Bool>("FPurity", Bool(false));
926926

927927
Expr MakeMemKillStorage(Expr storage) {
928928
static const Op& op = Op::Get("relax.memory.kill_storage");
@@ -937,8 +937,8 @@ RELAY_REGISTER_OP("relax.memory.kill_tensor")
937937
.set_num_inputs(1)
938938
.add_argument("tensor", "Expr", "The tensor to be killed.")
939939
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
940-
// memory deallocation also isn't considered a "visible effect" as far as purity is concerned
941-
.set_attr<Bool>("FPurity", Bool(true));
940+
// We mark this as impure so it wouldn't be removed by "remove_all_unused"
941+
.set_attr<Bool>("FPurity", Bool(false));
942942

943943
Expr MakeMemKillTensor(Expr tensor) {
944944
static const Op& op = Op::Get("relax.memory.kill_tensor");
@@ -1013,8 +1013,8 @@ TVM_REGISTER_OP("relax.vm.kill_object")
10131013
.set_num_inputs(1)
10141014
.add_argument("obj", "Expr", "The object to be killed.")
10151015
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
1016-
// deallocation also isn't considered a "visible effect" as far as purity is concerned
1017-
.set_attr<Bool>("FPurity", Bool(true));
1016+
// We mark this as impure so it wouldn't be removed by "remove_all_unused"
1017+
.set_attr<Bool>("FPurity", Bool(false));
10181018

10191019
Expr MakeVMKillObject(Expr obj) {
10201020
static const Op& op = Op::Get("relax.vm.kill_object");
@@ -1031,7 +1031,8 @@ RELAY_REGISTER_OP("relax.vm.call_tir_dyn")
10311031
.add_argument("args", "Tuple",
10321032
"The input arguments (list of tensors and last argument is ShapeExpr)")
10331033
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
1034-
.set_attr<Bool>("FPurity", Bool(true));
1034+
// "relax.vm.call_tir_dyn" works in an in-place way, which is impure.
1035+
.set_attr<Bool>("FPurity", Bool(false));
10351036

10361037
Expr MakeCallTIRDyn(Expr func, Tuple args) {
10371038
static const Op& op = Op::Get("relax.vm.call_tir_dyn");

tests/python/relax/test_analysis.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,21 @@
1919

2020
import tvm
2121
import tvm.testing
22-
from tvm import tir
2322
from tvm import relax as rx
23+
from tvm import tir
2424
from tvm.relax.analysis import (
25-
has_reshape_pattern,
26-
udchain,
27-
remove_all_unused,
28-
name_to_binding,
29-
all_vars,
3025
all_global_vars,
31-
free_vars,
26+
all_vars,
3227
bound_vars,
28+
free_vars,
29+
has_reshape_pattern,
30+
name_to_binding,
31+
remove_all_unused,
32+
udchain,
3333
)
34-
from tvm.script import relax as R, tir as T
34+
from tvm.script import ir as I
35+
from tvm.script import relax as R
36+
from tvm.script import tir as T
3537

3638

3739
def var_name_set(vars: List[Union[rx.Var, rx.GlobalVar]]) -> Set[str]:
@@ -352,6 +354,30 @@ def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
352354
tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)
353355

354356

357+
def test_retain_calls_to_impure_builtin_ops():
358+
@I.ir_module
359+
class Module:
360+
@T.prim_func(private=True)
361+
def my_tir(A: T.handle, B: T.handle, n: T.int64):
362+
T.evaluate(0)
363+
364+
@R.function(pure=False)
365+
def main(x: R.Tensor(("n",), "float32")):
366+
cls = Module
367+
n = T.int64()
368+
storage = R.memory.alloc_storage((n * 4,), 0, "global", "float32")
369+
alloc = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), "float32")
370+
# "call_tir_dyn" is impure which shouldn't be removed.
371+
R.vm.call_tir_dyn(cls.my_tir, (x, alloc, R.shape([n])))
372+
# "kill_tensor"/"kill_storage" are impure which shouldn't be removed.
373+
R.memory.kill_tensor(alloc)
374+
R.memory.kill_storage(storage)
375+
return x
376+
377+
after = remove_all_unused(Module["main"])
378+
tvm.ir.assert_structural_equal(after, Module["main"], map_free_vars=True)
379+
380+
355381
def test_name_to_binding_var_shadowing():
356382
@R.function
357383
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:

tests/python/relax/test_transform_cse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def sum(
435435
def test_do_not_eliminate_dtype():
436436
@I.ir_module
437437
class Before:
438-
@R.function
438+
@R.function(pure=False)
439439
def foo() -> R.Tensor((32, 64), "int32"):
440440
obj: R.Object = R.vm.alloc_storage(
441441
R.shape([24576]), runtime_device_index=0, dtype="uint8"

tests/python/relax/test_tvmscript_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1552,7 +1552,7 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")):
15521552

15531553

15541554
def test_vm_ops():
1555-
@R.function
1555+
@R.function(pure=False)
15561556
def foo(x: R.Tensor(("m", "n"), dtype="float32")):
15571557
m = T.int64()
15581558
n = T.int64()

0 commit comments

Comments
 (0)