Skip to content

Commit b3fa6cb

Browse files
authored
[AOT][Testing] Print output values on test failure (#16611)
This commit enhances the AOT test harness to print the "actual" and "reference" values when there is a mismatch. This helps when debugging a failing test. Sample output: ``` Actual, Reference 8.502946, 8.887751 9.810405, 9.108611 8.563767, 9.041000 10.019511, 9.190888 .... ```
1 parent 3ec0ca5 commit b3fa6cb

3 files changed

Lines changed: 123 additions & 15 deletions

File tree

python/tvm/testing/aot.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -425,17 +425,26 @@ def fake_tensor(source, source_index, packed_index):
425425
main_file.write("\n")
426426

427427

428-
def _emit_main_compare(main_file, outputs, output_tolerance, mod_name, use_interface_c=False):
428+
def _emit_main_compare(
429+
main_file,
430+
outputs,
431+
output_tolerance,
432+
mod_name,
433+
use_interface_c=False,
434+
print_output_on_mismatch=False,
435+
):
429436
for key in outputs:
430437
sanitized_tensor_name = re.sub(r"\W", "_", key)
431438
expected_data_name = _mangle_name(mod_name, f"expected_output_data_{sanitized_tensor_name}")
432439
is_float_dtype = outputs[key].dtype == "float32"
433440

434441
comparison_function = "abs"
435442
tolerance = output_tolerance or 0
443+
value_format_specifier = "%d"
436444
if is_float_dtype:
437445
comparison_function = "fabs"
438446
tolerance = output_tolerance or 0.001
447+
value_format_specifier = "%f"
439448

440449
data_length_var_name = (
441450
_mangle_name(mod_name, f"output_data_{sanitized_tensor_name}") + "_len"
@@ -447,15 +456,34 @@ def _emit_main_compare(main_file, outputs, output_tolerance, mod_name, use_inter
447456
)
448457
else:
449458
actual_data_name = _mangle_name(mod_name, f"output_data_{sanitized_tensor_name}")
450-
main_file.write(
451-
f"for (int i = 0; i<{data_length_var_name}; i++) {{\n"
452-
f"\tif ({comparison_function}({actual_data_name}[i]-"
453-
f"{expected_data_name}[i]) > {tolerance}) {{\n"
454-
f'\t\tprintf("{AOT_FAILURE_TOKEN}\\n");\n'
455-
f"\t\treturn -1;\n"
456-
f"\t}}\n"
457-
f"}}"
458-
)
459+
460+
if print_output_on_mismatch:
461+
main_file.write(
462+
f"int mismatch = 0;"
463+
f'printf("Actual, Reference\\n");\n'
464+
f"for (int i = 0; i<{data_length_var_name}; i++) {{\n"
465+
f"\tif ({comparison_function}({actual_data_name}[i]-"
466+
f"{expected_data_name}[i]) > {tolerance}) {{\n"
467+
f'\t\tprintf("{value_format_specifier}, {value_format_specifier}\\n"'
468+
f", {actual_data_name}[i], {expected_data_name}[i]);\n"
469+
f"\t\tmismatch = 1;\n"
470+
f"\t}}\n"
471+
f"}}"
472+
f"if (mismatch == 1) {{\n"
473+
f'\tprintf("{AOT_FAILURE_TOKEN}\\n");\n'
474+
f"\treturn -1;\n"
475+
f"}}"
476+
)
477+
else:
478+
main_file.write(
479+
f"for (int i = 0; i<{data_length_var_name}; i++) {{\n"
480+
f"\tif ({comparison_function}({actual_data_name}[i]-"
481+
f"{expected_data_name}[i]) > {tolerance}) {{\n"
482+
f'\t\tprintf("{AOT_FAILURE_TOKEN}\\n");\n'
483+
f"\t\treturn -1;\n"
484+
f"\t}}\n"
485+
f"}}"
486+
)
459487

460488

461489
def _emit_main_init_memory_manager(main_file):
@@ -500,6 +528,7 @@ def _create_main(
500528
use_stack_allocator=True,
501529
use_workspace_io=False,
502530
debug_last_error=False,
531+
print_output_on_mismatch=False,
503532
):
504533
file_path = pathlib.Path(f"{output_path}/" + test_name).resolve()
505534
# create header file
@@ -568,7 +597,12 @@ def _create_main(
568597
for compiled_model in compiled_models:
569598
model = compiled_model.model
570599
_emit_main_compare(
571-
main_file, model.outputs, model.output_tolerance, model.name, interface_api == "c"
600+
main_file,
601+
model.outputs,
602+
model.output_tolerance,
603+
model.name,
604+
interface_api == "c",
605+
print_output_on_mismatch,
572606
)
573607
_emit_main_epilogue(main_file, custom_epilogue)
574608

@@ -709,6 +743,7 @@ def run_and_check(
709743
use_workspace_io: bool = False,
710744
debug_last_error: bool = False,
711745
checker: Optional[Callable[[str], bool]] = None,
746+
print_output_on_mismatch: bool = False,
712747
):
713748
"""
714749
This method uses the original test data and compiled runtime.Modules
@@ -789,6 +824,7 @@ def run_and_check_body(base_path):
789824
use_stack_allocator,
790825
use_workspace_io,
791826
debug_last_error,
827+
print_output_on_mismatch,
792828
)
793829

794830
if checker and (not checker(base_path)):
@@ -832,7 +868,10 @@ def run_and_check_body(base_path):
832868
_subprocess_check_log_output(run_command, build_path, run_log_path)
833869

834870
with open(run_log_path) as run_log:
835-
assert AOT_SUCCESS_TOKEN in run_log.read()
871+
run_log_out = run_log.read()
872+
if print_output_on_mismatch and AOT_FAILURE_TOKEN in run_log_out:
873+
print(run_log_out)
874+
assert AOT_SUCCESS_TOKEN in run_log_out
836875

837876
return True
838877

@@ -861,15 +900,21 @@ def compile_and_run(
861900
schedule_name: str = None,
862901
debug_last_error: bool = False,
863902
checker: Optional[Callable[[str], bool]] = None,
903+
print_output_on_mismatch: bool = False,
864904
) -> bool:
865905
"""This is a wrapper API to compile and run models as test for AoT
866906
867907
Parameters
868908
----------
869909
test_dir : str
870-
This path will contain build, codegen, include directories
871-
verbose: bool
872-
Prints commands to build and run AOT test runner
910+
This path will contain build, codegen, include directories.
911+
912+
verbose : bool
913+
Prints commands to build and run AOT test runner.
914+
915+
print_output_on_mismatch : bool
916+
Print both the output and reference values side-by-side
917+
when there is a mismatch.
873918
"""
874919

875920
if target_opts:
@@ -904,6 +949,7 @@ def compile_and_run(
904949
verbose=verbose,
905950
debug_last_error=debug_last_error,
906951
checker=checker,
952+
print_output_on_mismatch=print_output_on_mismatch,
907953
)
908954

909955

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""
19+
Tests for the AOT test harness.
20+
"""
21+
22+
import pytest
23+
import numpy as np
24+
25+
import tvm
26+
from tvm import relay
27+
from tvm.testing.aot import AOTTestRunner, compile_and_run, AOTTestModel
28+
29+
30+
def test_output_on_mismatch_option():
31+
"""
32+
Test the print_output_on_mismatch option when there is a mismatch.
33+
"""
34+
interface_api = "packed"
35+
use_unpacked_api = True
36+
test_runner = AOTTestRunner()
37+
dtype = "float32"
38+
39+
two = relay.add(relay.const(1, dtype=dtype), relay.const(1, dtype=dtype))
40+
func = relay.Function([], two)
41+
outputs = {
42+
"output": np.array(
43+
[
44+
0,
45+
]
46+
).astype(dtype)
47+
}
48+
49+
msg = ".*Actual, Reference\n2.000000, 0.000000\nAOT_TEST_FAILURE.*"
50+
with pytest.raises(RuntimeError, match=msg):
51+
compile_and_run(
52+
AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={}, outputs=outputs),
53+
test_runner,
54+
interface_api,
55+
use_unpacked_api,
56+
print_output_on_mismatch=True,
57+
)
58+
59+
60+
if __name__ == "__main__":
61+
tvm.testing.main()

tests/python/relay/aot/test_crt_aot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def test_conv_with_params(interface_api, use_unpacked_api, test_runner):
9393
test_runner,
9494
interface_api,
9595
use_unpacked_api,
96+
print_output_on_mismatch=True,
9697
)
9798

9899

0 commit comments

Comments
 (0)