@@ -530,9 +530,7 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack):
530
530
)
531
531
532
532
times = []
533
- logging .debug (f"candidate { candidate_id } benchmark_results: { benchmark_results } " )
534
533
for benchmark_result in benchmark_results :
535
- logging .debug (f"candidate { candidate_id } benchmark_result: { benchmark_result } " )
536
534
benchmark_name = benchmark_result .benchmark_name
537
535
# With multiple benchmark results, there will be `real_time_mean`, but
538
536
# not with single iteration benchmark results, so ignore the mean time
@@ -818,6 +816,63 @@ def compile(
818
816
return compiled_candidates
819
817
820
818
819
+ def select_best_benchmark_results (
820
+ candidate_results : list [BenchmarkResult ],
821
+ baseline_results : list [BenchmarkResult ],
822
+ num_candidates : Optional [int ],
823
+ ) -> list [BenchmarkResult ]:
824
+ filtered_candidate_results = [r for r in candidate_results if math .isfinite (r .time )]
825
+ if len (filtered_candidate_results ) == 0 :
826
+ logging .error ("No successful candidate benchmarks." )
827
+ return []
828
+ fallback_baseline_time : Optional [float ] = None
829
+ filtered_baseline_results : list [BenchmarkResult ] = []
830
+ for r in baseline_results :
831
+ if math .isfinite (r .time ):
832
+ filtered_baseline_results .append (r )
833
+ fallback_baseline_time = r .time
834
+ else :
835
+ logging .warning (f"Baseline on device { r .device_id } failed." )
836
+ if fallback_baseline_time is None :
837
+ logging .warning (
838
+ f"All baseline benchmarks failed. Baselines will not be used to select top candidates"
839
+ )
840
+ baseline_times_by_device = {}
841
+ for r in filtered_baseline_results :
842
+ baseline_times_by_device [r .device_id ] = r .time
843
+
844
+ # Select top candidates
845
+ def get_speedup (result : BenchmarkResult ) -> float :
846
+ if result .device_id in baseline_times_by_device :
847
+ return result .time / baseline_times_by_device [result .device_id ]
848
+ assert fallback_baseline_time is not None , "expected fallback_baseline_time"
849
+ return result .time / fallback_baseline_time
850
+
851
+ num_top_candidates = len (filtered_candidate_results )
852
+ if num_candidates is not None :
853
+ num_top_candidates = num_candidates
854
+
855
+ # Sort by the speedup over baseline on the same device. If a device failed
856
+ # the baseline benchmark, then use the fallback baseline. If there is no
857
+ # successful baseline, then the best we can do is to sort by the actual
858
+ # time.
859
+ sorting_key = get_speedup
860
+ if fallback_baseline_time is None :
861
+ sorting_key = lambda result : result .time
862
+ best_results = sorted (filtered_candidate_results , key = sorting_key )[
863
+ :num_top_candidates
864
+ ]
865
+ logging .info (f"Selected top[{ len (best_results )} ]:" )
866
+
867
+ for r in best_results :
868
+ if fallback_baseline_time is not None :
869
+ speedup = f"{ round (get_speedup (r ) * 100 , 2 )} % of baseline"
870
+ else :
871
+ speedup = "baseline unavailable"
872
+ logging .info (f"Candidate { r .candidate_id } time: { r .time } ({ speedup } )" )
873
+ return best_results
874
+
875
+
821
876
def benchmark (
822
877
args : argparse .Namespace ,
823
878
path_config : PathConfig ,
@@ -827,6 +882,9 @@ def benchmark(
827
882
num_candidates : Optional [int ] = None ,
828
883
):
829
884
logging .debug ("benchmark()" )
885
+ if len (compiled_candidates ) == 0 :
886
+ logging .warning ("No candidates to benchmark." )
887
+ return []
830
888
831
889
task_list = [
832
890
BenchmarkPack (
@@ -838,7 +896,7 @@ def benchmark(
838
896
if i != 0
839
897
]
840
898
worker_context_queue = create_worker_context_queue (args .devices )
841
- candidate_results = multiprocess_progress_wrapper (
899
+ candidate_results : list [ BenchmarkResult ] = multiprocess_progress_wrapper (
842
900
num_worker = len (args .devices ),
843
901
task_list = task_list ,
844
902
function = run_iree_benchmark_module_command ,
@@ -855,32 +913,19 @@ def benchmark(
855
913
candidate_tracker = candidate_trackers [0 ],
856
914
)
857
915
] * len (args .devices )
858
- baseline_results = multiprocess_progress_wrapper (
916
+ baseline_results : list [ BenchmarkResult ] = multiprocess_progress_wrapper (
859
917
num_worker = len (args .devices ),
860
918
task_list = baseline_task_list ,
861
919
function = run_iree_benchmark_module_command ,
862
920
initializer = init_worker_context ,
863
921
initializer_inputs = (worker_context_queue ,),
864
922
)
865
- baseline_times_by_device = {}
866
- for r in baseline_results :
867
- baseline_times_by_device [r .device_id ] = r .time
868
923
869
- # Select top candidates
870
- def get_speedup (result : BenchmarkResult ) -> float :
871
- return result .time / baseline_times_by_device [result .device_id ]
872
-
873
- num_top_candidates = len (candidate_results )
874
- if num_candidates is not None :
875
- num_top_candidates = num_candidates
876
- best_results = sorted (candidate_results , key = get_speedup )[:num_top_candidates ]
877
- logging .info (f"Selected top[{ len (best_results )} ]:" )
878
-
879
- for r in best_results :
880
- speedup = round (get_speedup (r ) * 100 , 2 )
881
- logging .info (
882
- f"Candidate { r .candidate_id } time: { r .time } ({ speedup } % of baseline)"
883
- )
924
+ best_results : list [BenchmarkResult ] = select_best_benchmark_results (
925
+ candidate_results = candidate_results ,
926
+ baseline_results = baseline_results ,
927
+ num_candidates = num_candidates ,
928
+ )
884
929
885
930
top_candidates = [result .candidate_id for result in best_results ]
886
931
return top_candidates
0 commit comments