|
47 | 47 | from .tir.node import Slice, BufferSlice |
48 | 48 | from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler |
49 | 49 | from .tir.special_stmt import SpecialStmt |
| 50 | +from .tir import ty |
50 | 51 |
|
51 | 52 |
|
52 | 53 | class CallArgumentReader(object): |
@@ -447,7 +448,9 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: |
447 | 448 | # add parameters of function |
448 | 449 | for arg in node.params: |
449 | 450 | # Note that this case is for T.match_buffer syntax sugar |
450 | | - if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)): |
| 451 | + if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance( |
| 452 | + self.transform(arg.ty.func_name), ty.GenericBufferType |
| 453 | + ): |
451 | 454 | result = self.handle_match_buffer_type(arg.ty, arg.name) |
452 | 455 | if not isinstance(result, buffer.Buffer): |
453 | 456 | self.report_error( |
@@ -1138,6 +1141,25 @@ def transform_TypeTuple(self, node): |
1138 | 1141 | """ |
1139 | 1142 | return [self.transform(value) for value in node.values] |
1140 | 1143 |
|
| 1144 | + def transform_TypeApply(self, node): |
| 1145 | + func = self.transform(node.func_name) |
| 1146 | + |
| 1147 | + if not isinstance(func, ty.TypeGeneric): |
| 1148 | + self.report_error(f"Expected a type but found {type(func).__name__}", node.span) |
| 1149 | + |
| 1150 | + param_types = [] |
| 1151 | + for param in node.params: |
| 1152 | + param_type = self.transform(param) |
| 1153 | + if not isinstance(param_type, ty.TypeGeneric): |
| 1154 | + self.report_error(f"Expected a type but found {type(param).__name__}", param.span) |
| 1155 | + |
| 1156 | + param_types.append(param_type) |
| 1157 | + |
| 1158 | + if len(param_types) == 1: |
| 1159 | + return func[param_types[0]] |
| 1160 | + else: |
| 1161 | + return func[param_types] |
| 1162 | + |
1141 | 1163 | def handle_match_buffer_type(self, node, buffer_name): |
1142 | 1164 | """special function to handle syntax sugar for match buffer. |
1143 | 1165 |
|
|
0 commit comments