Skip to content

Commit d62a4e9

Browse files
committed
[TVMScript] Avoid segfault from invalid TVMScript
Prior to this commit, after the `DiagnosticContext` prints its error, it overwrites the `DiagnosticRenderer` with a NULL renderer. If a second call to `DiagnosticContext::Render` occurs, it will segfault. This appears to be intended to prevent double-printing of error messages, but double-printing error messages is much worse than a segfault. In addition, `DiagnosticContext::Render` should only be called once. There's a common pattern in the parser where it will wrap exceptions in `DiagnosticError`, but re-raise exceptions that are already a `DiagnosticError`. This requires every such location to include `except DiagnosticError: raise`, and can easily be missed. This PR makes two changes: First, the `DiagnosticRenderer` is updated to have a no-op callback rather than a NULL callback. Second, the re-raising of `DiagnosticError` is moved to `Parser.report_error`, so that it does not need to be handled separately at several independent locations in the TVMScript parser.
1 parent 72b75fe commit d62a4e9

File tree

7 files changed

+35
-32
lines changed

7 files changed

+35
-32
lines changed

python/tvm/script/parser/core/evaluator.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,8 @@ def _visit(self, node: doc.AST) -> Any:
267267
value = self._eval_slice(fields)
268268
else:
269269
value = self._eval_expr(node.__class__(**fields))
270-
except Exception as e: # pylint: disable=broad-except,invalid-name
271-
self.parser.report_error(node, e)
270+
except Exception as err: # pylint: disable=broad-except
271+
self.parser.report_error(node, err)
272272
return self._add_intermediate_result(value)
273273

274274
def _eval_lambda(self, node: doc.Lambda) -> Any:
@@ -286,8 +286,8 @@ def _eval_lambda(self, node: doc.Lambda) -> Any:
286286
"""
287287
try:
288288
value = self._eval_expr(node)
289-
except Exception as e: # pylint: disable=broad-except,invalid-name
290-
self.parser.report_error(node, str(e))
289+
except Exception as err: # pylint: disable=broad-except
290+
self.parser.report_error(node, err)
291291
return self._add_intermediate_result(value)
292292

293293
def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
@@ -463,9 +463,8 @@ def eval_assign(
463463
"""
464464
try:
465465
return _eval_assign(target, source)
466-
except Exception as e: # pylint: disable=broad-except,invalid-name
467-
parser.report_error(target, f"Failed to evaluate assignment: {str(e)}")
468-
raise
466+
except Exception as err: # pylint: disable=broad-except
467+
parser.report_error(target, err)
469468

470469

471470
def _eval_expr(

python/tvm/script/parser/core/parser.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -307,11 +307,8 @@ def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
307307
def _wrapper(self: "Parser", node: doc.AST) -> None:
308308
try:
309309
return func(self, node)
310-
except DiagnosticError:
311-
raise
312-
except Exception as e: # pylint: disable=broad-except,invalid-name
313-
self.report_error(node, e)
314-
raise
310+
except Exception as err: # pylint: disable=broad-except
311+
self.report_error(node, err)
315312

316313
return _wrapper
317314

@@ -496,7 +493,6 @@ def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
496493
return self._duplicate_lhs_check(target.value)
497494
else:
498495
self.report_error(target, "Invalid type in assign statement")
499-
raise NotImplementedError
500496

501497
def eval_assign(
502498
self,
@@ -534,9 +530,7 @@ def eval_assign(
534530
self.var_table.add(k, var, allow_shadowing)
535531
return var_values
536532

537-
def report_error(
538-
self, node: doc.AST, err: Union[Exception, str]
539-
) -> None: # pylint: disable=no-self-use
533+
def report_error(self, node: doc.AST, err: Union[Exception, str]) -> None: # pylint: disable=no-self-use
540534
"""The error reporting when parsing.
541535
542536
Parameters
@@ -547,6 +541,12 @@ def report_error(
547541
err: Union[Exception, str]
548542
The error to report.
549543
"""
544+
545+
# If the error is already being raised as a DiagnosticError,
546+
# re-raise it without wrapping it in a DiagnosticContext.
547+
if isinstance(err, DiagnosticError):
548+
raise err
549+
550550
# Only take the last line of the error message
551551
if isinstance(err, TVMError):
552552
msg = list(filter(None, str(err).split("\n")))[-1]
@@ -595,11 +595,8 @@ def visit(self, node: doc.AST) -> None:
595595
raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
596596
try:
597597
func(node)
598-
except DiagnosticError:
599-
raise
600-
except Exception as e: # pylint: disable=broad-except,invalid-name
601-
self.report_error(node, str(e))
602-
raise
598+
except Exception as err: # pylint: disable=broad-except
599+
self.report_error(node, err)
603600

604601
def visit_body(self, node: List[doc.stmt]) -> Any:
605602
"""The general body visiting method.

python/tvm/script/parser/relax/parser.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,7 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
105105
annotation = self.eval_expr(node)
106106
return _normalize_struct_info_proxy(annotation)
107107
except Exception as err:
108-
self.report_error(node, str(err))
109-
raise err
108+
self.report_error(node, err)
110109

111110

112111
def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo:
@@ -116,7 +115,6 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St
116115
return _normalize_struct_info(struct_info, var_table)
117116
except Exception as err:
118117
self.report_error(node, err)
119-
raise err
120118

121119

122120
def is_called(node: Any, func_name: str) -> bool:

python/tvm/script/parser/tir/parser.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) ->
6464
return value
6565
else:
6666
self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
67-
raise NotImplementedError
6867

6968

7069
def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
@@ -100,7 +99,6 @@ def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> A
10099
return value
101100
else:
102101
self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
103-
raise NotImplementedError
104102

105103

106104
def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:

src/ir/diagnostic.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ void DiagnosticContext::Render() {
127127
}
128128

129129
if (errs) {
130-
(*this)->renderer = DiagnosticRenderer();
130+
(*this)->renderer = DiagnosticRenderer([](DiagnosticContext) {});
131+
// (*this)->diagnostics.clear();
131132
LOG(FATAL) << "DiagnosticError: one or more error diagnostics were "
132133
<< "emitted, please check diagnostic render for output.";
133134
}

tests/python/relax/test_tvmscript_parser.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,14 @@ def f(x: R.Tensor):
178178
R.add(x, x)
179179
return x
180180

181+
def test_incorrect_tensor_shape():
182+
with pytest.raises(tvm.error.DiagnosticError):
183+
184+
@R.function
185+
def f(x: R.Tensor([16])):
186+
y: R.Tensor(16) = R.add(x, x)
187+
return y
188+
181189

182190
def test_simple_module():
183191
@I.ir_module
@@ -1838,7 +1846,7 @@ def mul_add(x: R.Tensor) -> R.Tensor:
18381846
_check(InputModule, OutputModule)
18391847

18401848

1841-
def test_context_aware_parsing():
1849+
def test_context_aware_parsing(monkeypatch):
18421850
@tvm.script.ir_module
18431851
class Module:
18441852
@T.prim_func
@@ -1863,7 +1871,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32
18631871
def _break_env(self, *args):
18641872
raise RuntimeError("Fail to pass context-aware parsing")
18651873

1866-
tvm.ir.GlobalVar.__call__ = _break_env
1874+
monkeypatch.setattr(tvm.ir.GlobalVar, '__call__', _break_env)
18671875

18681876
_check(Module)
18691877

tests/python/tvmscript/test_tvmscript_printer_highlight.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import tvm.testing
2222
from tvm import relay
2323
from tvm.script import tir as T
24-
from tvm.script.highlight import cprint
24+
from tvm.script.highlight import cprint, _format
2525

2626

2727
def test_highlight_script():
@@ -58,12 +58,14 @@ def test_cprint():
5858
# Print nodes with `script` method, e.g. PrimExpr
5959
cprint(tvm.tir.Var("v", "int32") + 1)
6060

61-
# Cannot print non-Python-style codes if black installed
61+
# Cannot print non-Python-style codes when using the black
62+
# formatter. This error comes from `_format`, used internally by
63+
# `cprint`, and doesn't occur when using the `ruff` formatter.
6264
try:
6365
import black
6466

6567
with pytest.raises(ValueError):
66-
cprint("if (a == 1) { a +=1; }")
68+
_format("if (a == 1) { a +=1; }", formatter="black")
6769
except ImportError:
6870
pass
6971

0 commit comments

Comments
 (0)