Skip to content

Commit 93006f8

Browse files
Anndrey24 authored and thaisacs committed
[AOT][Testing] Improve output mismatch information on test failure (apache#16765)
Enhanced AOT test harness to include overall mismatch percentage and the individual mismatch positions from the output tensor for debugging test failures. Both of these are still gated behind `print_output_on_mismatch == True`. I also added tests to check for the presence and correctness of this new debug information.

Sample output:

```
Element [Position]: Actual, Reference
-------------------------------------
Element [0, 8, 8, 7]: 521.846313, 521.847412
Element [0, 9, 8, 51]: 478.874359, 478.875549
Element [0, 9, 9, 6]: 462.901001, 462.899658

Mismatched elements: 3 / 16384 (0.02%)
...
```
1 parent 8411756 commit 93006f8

2 files changed

Lines changed: 85 additions & 15 deletions

File tree

python/tvm/testing/aot.py

Lines changed: 34 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -476,20 +476,40 @@ def _emit_main_compare(
476476

477477
if print_output_on_mismatch:
478478
main_file.write(
479-
f"int mismatch = 0;"
480-
f'printf("Actual, Reference\\n");\n'
481-
f"for (int i = 0; i<{data_length_var_name}; i++) {{\n"
482-
f"\tif ({comparison_function}({actual_data_name}[i]-"
483-
f"{expected_data_name}[i]) > {tolerance}) {{\n"
484-
f'\t\tprintf("{value_format_specifier}, {value_format_specifier}\\n"'
485-
f", {actual_data_name}[i], {expected_data_name}[i]);\n"
486-
f"\t\tmismatch = 1;\n"
487-
f"\t}}\n"
488-
f"}}"
489-
f"if (mismatch == 1) {{\n"
490-
f'\tprintf("{AOT_FAILURE_TOKEN}\\n");\n'
491-
f"\treturn -1;\n"
492-
f"}}"
479+
f"""
480+
{{
481+
int mismatch = 0;
482+
int out_ndim = {outputs[key].ndim};
483+
int out_shape[] = {{{','.join(map(str, outputs[key].shape))}}};
484+
int out_indices[out_ndim];
485+
printf("Element [Position]: Actual, Reference\\n");
486+
printf("-------------------------------------\\n");
487+
for (int i = 0; i<{data_length_var_name}; i++) {{
488+
if ({comparison_function}({actual_data_name}[i] -
489+
{expected_data_name}[i]) > {tolerance}) {{
490+
int flat_index = i;
491+
for (int j = out_ndim - 1; j >= 0; j--){{
492+
out_indices[j] = flat_index % out_shape[j];
493+
flat_index /= out_shape[j];
494+
}}
495+
printf("Element [%d", out_indices[0]);
496+
for (int j = 1; j < out_ndim; j++)
497+
printf(", %d", out_indices[j]);
498+
printf("]: {value_format_specifier}, {value_format_specifier}\\n",
499+
{actual_data_name}[i], {expected_data_name}[i]);
500+
mismatch += 1;
501+
}}
502+
}}
503+
if (mismatch >= 1) {{
504+
float percent_mismatched =
505+
((float) mismatch) / ((float) {data_length_var_name}) * 100;
506+
printf("\\nMismatched elements: %d / %zu (%.2f%%)\\n",
507+
mismatch, {data_length_var_name}, percent_mismatched);
508+
printf("{AOT_FAILURE_TOKEN}\\n");
509+
return -1;
510+
}}
511+
}}
512+
"""
493513
)
494514
else:
495515
main_file.write(

tests/python/relay/aot/test_aot_test_harness.py

Lines changed: 51 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,57 @@ def test_output_on_mismatch_option():
4646
).astype(dtype)
4747
}
4848

49-
msg = ".*Actual, Reference\n2.000000, 0.000000\nAOT_TEST_FAILURE.*"
49+
msg = ".*Actual, Reference(\n|.)*2.000000, 0.000000(\n|.)*AOT_TEST_FAILURE.*"
50+
with pytest.raises(RuntimeError, match=msg):
51+
compile_and_run(
52+
AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={}, outputs=outputs),
53+
test_runner,
54+
interface_api,
55+
use_unpacked_api,
56+
print_output_on_mismatch=True,
57+
)
58+
59+
60+
def test_output_position_on_mismatch():
61+
"""
62+
Test the mismatch position output for the print_output_on_mismatch option.
63+
"""
64+
interface_api = "packed"
65+
use_unpacked_api = True
66+
test_runner = AOTTestRunner()
67+
dtype = "float32"
68+
69+
x = np.zeros(shape=(2, 2), dtype=dtype)
70+
x[-1, -1] = 1
71+
func = relay.Function([], relay.const(x, dtype=dtype))
72+
outputs = {"output": np.zeros(shape=(2, 2), dtype=dtype)}
73+
74+
msg = ".*Element \\[1, 1\\]:.*"
75+
with pytest.raises(RuntimeError, match=msg):
76+
compile_and_run(
77+
AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={}, outputs=outputs),
78+
test_runner,
79+
interface_api,
80+
use_unpacked_api,
81+
print_output_on_mismatch=True,
82+
)
83+
84+
85+
def test_mismatch_percentage():
86+
"""
87+
Test the mismatch percentage for the print_output_on_mismatch option.
88+
"""
89+
interface_api = "packed"
90+
use_unpacked_api = True
91+
test_runner = AOTTestRunner()
92+
dtype = "float32"
93+
94+
x = np.zeros(shape=(8,), dtype=dtype)
95+
x[0] = 1
96+
func = relay.Function([], relay.const(x, dtype=dtype))
97+
outputs = {"output": np.zeros(shape=(8,), dtype=dtype)}
98+
99+
msg = ".*Mismatched elements: 1 / 8 \\(12.50%\\).*"
50100
with pytest.raises(RuntimeError, match=msg):
51101
compile_and_run(
52102
AOTTestModel(module=tvm.IRModule.from_expr(func), inputs={}, outputs=outputs),

0 commit comments

Comments
 (0)