@@ -425,17 +425,26 @@ def fake_tensor(source, source_index, packed_index):
425425 main_file .write ("\n " )
426426
427427
428- def _emit_main_compare (main_file , outputs , output_tolerance , mod_name , use_interface_c = False ):
428+ def _emit_main_compare (
429+ main_file ,
430+ outputs ,
431+ output_tolerance ,
432+ mod_name ,
433+ use_interface_c = False ,
434+ print_output_on_mismatch = False ,
435+ ):
429436 for key in outputs :
430437 sanitized_tensor_name = re .sub (r"\W" , "_" , key )
431438 expected_data_name = _mangle_name (mod_name , f"expected_output_data_{ sanitized_tensor_name } " )
432439 is_float_dtype = outputs [key ].dtype == "float32"
433440
434441 comparison_function = "abs"
435442 tolerance = output_tolerance or 0
443+ value_format_specifier = "%d"
436444 if is_float_dtype :
437445 comparison_function = "fabs"
438446 tolerance = output_tolerance or 0.001
447+ value_format_specifier = "%f"
439448
440449 data_length_var_name = (
441450 _mangle_name (mod_name , f"output_data_{ sanitized_tensor_name } " ) + "_len"
@@ -447,15 +456,34 @@ def _emit_main_compare(main_file, outputs, output_tolerance, mod_name, use_inter
447456 )
448457 else :
449458 actual_data_name = _mangle_name (mod_name , f"output_data_{ sanitized_tensor_name } " )
450- main_file .write (
451- f"for (int i = 0; i<{ data_length_var_name } ; i++) {{\n "
452- f"\t if ({ comparison_function } ({ actual_data_name } [i]-"
453- f"{ expected_data_name } [i]) > { tolerance } ) {{\n "
454- f'\t \t printf("{ AOT_FAILURE_TOKEN } \\ n");\n '
455- f"\t \t return -1;\n "
456- f"\t }}\n "
457- f"}}"
458- )
459+
460+ if print_output_on_mismatch :
461+ main_file .write (
462+ f"int mismatch = 0;"
463+ f'printf("Actual, Reference\\ n");\n '
464+ f"for (int i = 0; i<{ data_length_var_name } ; i++) {{\n "
465+ f"\t if ({ comparison_function } ({ actual_data_name } [i]-"
466+ f"{ expected_data_name } [i]) > { tolerance } ) {{\n "
467+ f'\t \t printf("{ value_format_specifier } , { value_format_specifier } \\ n"'
468+ f", { actual_data_name } [i], { expected_data_name } [i]);\n "
469+ f"\t \t mismatch = 1;\n "
470+ f"\t }}\n "
471+ f"}}"
472+ f"if (mismatch == 1) {{\n "
473+ f'\t printf("{ AOT_FAILURE_TOKEN } \\ n");\n '
474+ f"\t return -1;\n "
475+ f"}}"
476+ )
477+ else :
478+ main_file .write (
479+ f"for (int i = 0; i<{ data_length_var_name } ; i++) {{\n "
480+ f"\t if ({ comparison_function } ({ actual_data_name } [i]-"
481+ f"{ expected_data_name } [i]) > { tolerance } ) {{\n "
482+ f'\t \t printf("{ AOT_FAILURE_TOKEN } \\ n");\n '
483+ f"\t \t return -1;\n "
484+ f"\t }}\n "
485+ f"}}"
486+ )
459487
460488
461489def _emit_main_init_memory_manager (main_file ):
@@ -500,6 +528,7 @@ def _create_main(
500528 use_stack_allocator = True ,
501529 use_workspace_io = False ,
502530 debug_last_error = False ,
531+ print_output_on_mismatch = False ,
503532):
504533 file_path = pathlib .Path (f"{ output_path } /" + test_name ).resolve ()
505534 # create header file
@@ -568,7 +597,12 @@ def _create_main(
568597 for compiled_model in compiled_models :
569598 model = compiled_model .model
570599 _emit_main_compare (
571- main_file , model .outputs , model .output_tolerance , model .name , interface_api == "c"
600+ main_file ,
601+ model .outputs ,
602+ model .output_tolerance ,
603+ model .name ,
604+ interface_api == "c" ,
605+ print_output_on_mismatch ,
572606 )
573607 _emit_main_epilogue (main_file , custom_epilogue )
574608
@@ -709,6 +743,7 @@ def run_and_check(
709743 use_workspace_io : bool = False ,
710744 debug_last_error : bool = False ,
711745 checker : Optional [Callable [[str ], bool ]] = None ,
746+ print_output_on_mismatch : bool = False ,
712747):
713748 """
714749 This method uses the original test data and compiled runtime.Modules
@@ -789,6 +824,7 @@ def run_and_check_body(base_path):
789824 use_stack_allocator ,
790825 use_workspace_io ,
791826 debug_last_error ,
827+ print_output_on_mismatch ,
792828 )
793829
794830 if checker and (not checker (base_path )):
@@ -832,7 +868,10 @@ def run_and_check_body(base_path):
832868 _subprocess_check_log_output (run_command , build_path , run_log_path )
833869
834870 with open (run_log_path ) as run_log :
835- assert AOT_SUCCESS_TOKEN in run_log .read ()
871+ run_log_out = run_log .read ()
872+ if print_output_on_mismatch and AOT_FAILURE_TOKEN in run_log_out :
873+ print (run_log_out )
874+ assert AOT_SUCCESS_TOKEN in run_log_out
836875
837876 return True
838877
@@ -861,15 +900,21 @@ def compile_and_run(
861900 schedule_name : str = None ,
862901 debug_last_error : bool = False ,
863902 checker : Optional [Callable [[str ], bool ]] = None ,
903+ print_output_on_mismatch : bool = False ,
864904) -> bool :
865905 """This is a wrapper API to compile and run models as test for AoT
866906
867907 Parameters
868908 ----------
869909 test_dir : str
870- This path will contain build, codegen, include directories
871- verbose: bool
872- Prints commands to build and run AOT test runner
910+ This path will contain build, codegen, include directories.
911+
912+ verbose : bool
913+ Prints commands to build and run AOT test runner.
914+
915+ print_output_on_mismatch : bool
916+ Print both the output and reference values side-by-side
917+ when there is a mismatch.
873918 """
874919
875920 if target_opts :
@@ -904,6 +949,7 @@ def compile_and_run(
904949 verbose = verbose ,
905950 debug_last_error = debug_last_error ,
906951 checker = checker ,
952+ print_output_on_mismatch = print_output_on_mismatch ,
907953 )
908954
909955
0 commit comments