|
16 | 16 | # under the License. |
17 | 17 | # pylint: disable=missing-docstring |
18 | 18 |
|
| 19 | +import functools |
19 | 20 | import numbers |
20 | 21 | from typing import Any, Optional, Tuple, Union |
21 | 22 |
|
22 | 23 | from tvm import relax, tir |
23 | | -from tvm.ir import Type |
24 | | -from tvm.relax import StructInfo, Expr |
| 24 | +from tvm.ir import Type, structural_equal |
| 25 | +from tvm.relax import Expr, StructInfo |
25 | 26 | from tvm.relax.utils import convert_to_expr |
26 | 27 | from tvm.script.ir_builder.relax.frame import BlockFrame |
27 | 28 |
|
|
32 | 33 | from .entry import MatchCastPair |
33 | 34 |
|
34 | 35 |
|
35 | | -def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: |
| 36 | +def bind_assign_value( |
| 37 | + self: Parser, |
| 38 | + node: doc.expr, |
| 39 | + var_name: str, |
| 40 | + value: Any, |
| 41 | + anno_sinfo: Optional[StructInfo] = None, |
| 42 | +) -> Any: |
36 | 43 | var_table = self.var_table.get() |
37 | 44 |
|
38 | 45 | if isinstance(value, tir.Var): |
@@ -64,19 +71,21 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - |
64 | 71 | value = convert_to_expr(value) |
65 | 72 | if isinstance(value, numbers.Number): |
66 | 73 | value = R.const(value) |
| 74 | + |
67 | 75 | if isinstance(value, relax.Expr): |
68 | | - var = R.emit(value) |
69 | | - # It's an internal check, so directly use assert here. |
70 | | - assert var is not None |
71 | | - IRBuilder.name(var_name, var) |
72 | | - return var |
| 76 | + var = R.emit(value, anno_sinfo) |
73 | 77 | elif isinstance(value, MatchCastPair): |
| 78 | + if anno_sinfo is not None and not structural_equal(anno_sinfo, value.struct_info): |
| 79 | + self.report_error( |
| 80 | + node, "Cannot specify inconsistent annotation for a match cast pair. " |
| 81 | + ) |
74 | 82 | var = R.emit_match_cast(value.value, value.struct_info) |
75 | | - IRBuilder.name(var_name, var) |
76 | | - return var |
77 | 83 | else: |
78 | 84 | raise TypeError(f"Unsupported type {type(value)} in assignment") |
79 | 85 |
|
| 86 | + IRBuilder.name(var_name, var) |
| 87 | + return var |
| 88 | + |
80 | 89 |
|
81 | 90 | # pylint: disable=inconsistent-return-statements |
82 | 91 | def eval_type_annotation( |
@@ -213,16 +222,13 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: |
213 | 222 | def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: |
214 | 223 | lhs = node.target |
215 | 224 | rhs = self.eval_expr(node.value) |
216 | | - ann_sinfo = self.visit_tvm_annotation(node.annotation) |
| 225 | + anno_sinfo = self.visit_tvm_annotation(node.annotation) |
217 | 226 | self.eval_assign( |
218 | 227 | target=lhs, |
219 | 228 | source=rhs, |
220 | | - bind_value=bind_assign_value, |
| 229 | + bind_value=functools.partial(bind_assign_value, anno_sinfo=anno_sinfo), |
221 | 230 | allow_shadowing=True, |
222 | 231 | ) |
223 | | - var = self.var_table.get().get(lhs.id) |
224 | | - assert isinstance(var, relax.Var) |
225 | | - R.ir.annotate_struct_info(var, ann_sinfo) |
226 | 232 |
|
227 | 233 |
|
228 | 234 | @dispatch.register(token="relax", type_name="Return") |
|
0 commit comments