Skip to content

Commit 4557682

Browse files
authored
[Parser] Core Parser (apache#40)
1 parent 232a51d commit 4557682

14 files changed

Lines changed: 589 additions & 2 deletions

File tree

python/tvm/script/builder/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,5 @@
1616
# under the License.
1717
# pylint: disable=unused-import
1818
"""Namespace for the TVMScript Builder API."""
19-
20-
2119
from .builder import Builder, def_, def_many
2220
from .frame import Frame, IRModuleFrame

python/tvm/script/builder/tir/axis.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,7 @@ def reduce(dom, binding, dtype="int32") -> IterVar:
3636

3737
def remap(kinds, bindings, dtype="int32") -> IterVar:
3838
return _ffi_api.AxisRemap(kinds, bindings, dtype) # pylint: disable=no-member # type: ignore
39+
40+
41+
S = spatial
42+
R = reduce

python/tvm/script/builder/tir/prim_func_frame.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tvm.tir.buffer import Buffer
2222
from tvm.tir.expr import Var
2323

24+
from ..builder import Builder
2425
from . import _ffi_api
2526
from .base import TIRFrame
2627

@@ -36,3 +37,6 @@ def prim_func(name) -> PrimFuncFrame:
3637

3738
def arg(name, obj) -> Union[Var, Buffer]:
3839
return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore
40+
41+
42+
setattr(prim_func, "dispatch_token", "tir")
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the Licens.
17+
"""The parser"""
18+
from . import dispatch, parser, tir
19+
from .entry import parse
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""The dispatcher"""
18+
19+
import ast
20+
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple
21+
22+
if TYPE_CHECKING:
23+
from .parser import Parser
24+
25+
26+
ParseMethod = Callable[
27+
["Parser", ast.AST],
28+
None,
29+
]
30+
31+
32+
class DispatchTable:
33+
"""Dispatch table for parse methods"""
34+
35+
_instance: Optional["DispatchTable"] = None
36+
table: Dict[Tuple[str, str], ParseMethod]
37+
38+
def __init__(self):
39+
self.table = {}
40+
41+
42+
DispatchTable._instance = DispatchTable() # pylint: disable=protected-access
43+
44+
45+
def register(
46+
token: str,
47+
type_name: str,
48+
):
49+
"""Register a method for a dispatch token and type name"""
50+
51+
def f(method: ParseMethod):
52+
DispatchTable._instance.table[ # pylint: disable=protected-access
53+
(token, type_name)
54+
] = method
55+
56+
return f
57+
58+
59+
def get(
60+
token: str,
61+
type_name: str,
62+
default: Optional[ParseMethod] = None,
63+
) -> Optional[ParseMethod]:
64+
return DispatchTable._instance.table.get( # pylint: disable=protected-access
65+
(token, type_name),
66+
default,
67+
)

python/tvm/script/parse/entry.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""The entry point of TVM parser."""
18+
import ast
19+
import inspect
20+
from typing import Any, Dict, Optional, Union
21+
22+
from ..builder import Builder
23+
from .parser import Parser
24+
25+
26+
class SourceCode:
27+
source_name: str
28+
start_line: int
29+
start_column: int
30+
source: str
31+
full_source: str
32+
33+
def __init__(self, program: Union[str, ast.AST]):
34+
if isinstance(program, str):
35+
self.source_name = "<str>"
36+
self.start_line = 1
37+
self.start_column = 0
38+
self.source = program
39+
self.full_source = program
40+
else:
41+
self.source_name = inspect.getsourcefile(program) # type: ignore
42+
lines, self.start_line = inspect.getsourcelines(program) # type: ignore
43+
44+
if lines:
45+
self.start_column = len(lines[0]) - len(lines[0].lstrip())
46+
else:
47+
self.start_column = 0
48+
if self.start_column and lines:
49+
self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
50+
else:
51+
self.source = ""
52+
try:
53+
# It will cause a problem when running in Jupyter Notebook.
54+
# `mod` will be <module '__main__'>, which is a built-in module
55+
# and `getsource` will throw a TypeError
56+
mod = inspect.getmodule(program)
57+
if mod:
58+
self.full_source = inspect.getsource(mod)
59+
else:
60+
self.full_source = self.source
61+
except TypeError:
62+
# It's a work around for Jupyter problem.
63+
# Since `findsource` is an internal API of inspect, we just use it
64+
# as a fallback method.
65+
src, _ = inspect.findsource(program) # type: ignore
66+
self.full_source = "".join(src)
67+
68+
def as_ast(self) -> ast.AST:
69+
return ast.parse(self.source)
70+
71+
72+
def parse(
73+
program: Union[ast.AST, Any, str],
74+
extra_vars: Optional[Dict[str, Any]] = None,
75+
):
76+
program_ast = SourceCode(program).as_ast()
77+
parser = Parser()
78+
with Builder() as builder:
79+
with parser.var_table.with_frame():
80+
if extra_vars:
81+
for k, v in extra_vars.items():
82+
parser.var_table.add(k, v)
83+
parser.visit(program_ast)
84+
return builder.get()
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""AST Evaluation"""
18+
import ast
19+
from typing import Any, Dict, Optional, Union
20+
21+
22+
def eval_expr(
23+
node: Union[ast.expr, ast.Expression],
24+
dict_globals: Optional[Dict[str, Any]],
25+
) -> Any:
26+
if isinstance(node, ast.expr):
27+
node = ast.Expression(body=node)
28+
assert isinstance(node, ast.Expression)
29+
if dict_globals is None:
30+
dict_globals = {}
31+
node = ast.fix_missing_locations(node)
32+
exe = compile(node, filename="<ast>", mode="eval")
33+
return eval(exe, dict_globals) # pylint: disable=eval-used
34+
35+
36+
def eval_assign(
37+
target: ast.expr,
38+
source: Any,
39+
) -> Dict[str, Any]:
40+
assert isinstance(target, ast.expr)
41+
RHS_VAR_NAME = "__tvm_rhs_var__" # pylint: disable=invalid-name
42+
rhs_var_name = RHS_VAR_NAME
43+
dict_locals = {rhs_var_name: source}
44+
mod = ast.fix_missing_locations(
45+
ast.Module(
46+
body=[
47+
ast.Assign(
48+
targets=[target],
49+
value=ast.Name(
50+
id=rhs_var_name,
51+
ctx=ast.Load(),
52+
),
53+
)
54+
],
55+
type_ignores=[],
56+
)
57+
)
58+
exe = compile(mod, filename="<ast>", mode="exec")
59+
exec(exe, {}, dict_locals) # pylint: disable=exec-used
60+
del dict_locals[rhs_var_name]
61+
return dict_locals

python/tvm/script/parse/parser.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""The core parser"""
18+
import ast
19+
from typing import Any, Dict, List, Optional, Union
20+
21+
from ..builder import def_
22+
from . import dispatch
23+
from .evaluator import eval_assign, eval_expr
24+
from .utils import deferred
25+
from .var_table import VarTable
26+
27+
28+
def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
29+
for token in [self.dispatch_tokens[-1], "default"]:
30+
func = dispatch.get(token=token, type_name=type_name, default=None)
31+
if func is not None:
32+
return func
33+
return lambda self, node: self.generic_visit(node)
34+
35+
36+
def _handle_function(self: "Parser", node: ast.FunctionDef) -> None:
37+
if not node.decorator_list:
38+
self.report_error(node, "Function must be decorated")
39+
# TODO: only the last decorator is parsed
40+
decorator = self.eval_expr(node.decorator_list[-1])
41+
if hasattr(decorator, "dispatch_token"):
42+
token = decorator.dispatch_token
43+
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
44+
if func is not None:
45+
func(self, node)
46+
return
47+
self.report_error(node, "The parser does not understand the decorator")
48+
49+
50+
class Parser(ast.NodeVisitor):
51+
"""The TVMScript parser"""
52+
53+
dispatch_tokens: List[str]
54+
var_table: VarTable
55+
56+
def __init__(self) -> None:
57+
self.dispatch_tokens = ["default"]
58+
self.var_table = VarTable()
59+
60+
def with_dispatch_token(self, token: str):
61+
def pop_token():
62+
self.dispatch_tokens.pop()
63+
64+
self.dispatch_tokens.append(token)
65+
return deferred(pop_token)
66+
67+
def eval_expr(
68+
self,
69+
node: Union[ast.Expression, ast.expr],
70+
extra_vars: Optional[Dict[str, Any]] = None,
71+
) -> Any:
72+
var_values = self.var_table.get()
73+
if extra_vars is not None:
74+
for k, v in extra_vars.items():
75+
var_values[k] = v
76+
return eval_expr(node, var_values)
77+
78+
def eval_assign(
79+
self,
80+
target: ast.expr,
81+
source: Any,
82+
) -> Dict[str, Any]:
83+
var_values = eval_assign(target, source)
84+
for k, v in var_values.items():
85+
def_(k, v)
86+
self.var_table.add(k, v)
87+
return var_values
88+
89+
def report_error(self, node: ast.AST, msg: str) -> None: # pylint: disable=no-self-use
90+
raise SyntaxError(f"At {node.lineno}:{node.col_offset}: {msg}")
91+
92+
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # pylint: disable=invalid-name
93+
_handle_function(self, node)
94+
95+
def visit_body(self, node: List[ast.stmt]) -> Any:
96+
for stmt in node:
97+
self.visit(stmt)
98+
99+
def visit_arguments(self, node: ast.arguments) -> Any:
100+
_dispatch(self, "arguments")(self, node)
101+
102+
def visit_For(self, node: ast.For) -> Any: # pylint: disable=invalid-name
103+
_dispatch(self, "For")(self, node)
104+
105+
def visit_With(self, node: ast.With) -> Any: # pylint: disable=invalid-name
106+
_dispatch(self, "With")(self, node)
107+
108+
def visit_Assign(self, node: ast.Assign) -> Any: # pylint: disable=invalid-name
109+
_dispatch(self, "Assign")(self, node)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from . import tir

0 commit comments

Comments
 (0)