Skip to content

Commit 35ad7d0

Browse files
authored
[tuner] Filter out non-finite benchmark times (#799)
This PR fixes a bug where math.inf benchmark times could be selected as top candidates. Any non-finite times are now filtered out before the top candidates are selected. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent 500df58 commit 35ad7d0

File tree

2 files changed

+123
-22
lines changed

2 files changed

+123
-22
lines changed

tuner/tuner/libtuner.py

+67-22
Original file line numberDiff line numberDiff line change
@@ -530,9 +530,7 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack):
530530
)
531531

532532
times = []
533-
logging.debug(f"candidate {candidate_id} benchmark_results: {benchmark_results}")
534533
for benchmark_result in benchmark_results:
535-
logging.debug(f"candidate {candidate_id} benchmark_result: {benchmark_result}")
536534
benchmark_name = benchmark_result.benchmark_name
537535
# With multiple benchmark results, there will be `real_time_mean`, but
538536
# not with single iteration benchmark results, so ignore the mean time
@@ -818,6 +816,63 @@ def compile(
818816
return compiled_candidates
819817

820818

819+
def select_best_benchmark_results(
    candidate_results: list[BenchmarkResult],
    baseline_results: list[BenchmarkResult],
    num_candidates: Optional[int],
) -> list[BenchmarkResult]:
    """Select the fastest candidates, ranked by speedup over the per-device baseline.

    Non-finite (i.e. failed) benchmark times are filtered out of both the
    candidate and baseline results before ranking, so a math.inf time can
    never be selected as a top candidate.

    Args:
        candidate_results: Benchmark results for the tuned candidates.
        baseline_results: Baseline benchmark results, one per device.
        num_candidates: Maximum number of results to return; None keeps all.

    Returns:
        Up to `num_candidates` results, best (lowest time/baseline ratio)
        first. Empty list if no candidate benchmark succeeded.
    """
    # Drop candidates whose benchmark failed (time is inf/nan).
    finite_candidates = [r for r in candidate_results if math.isfinite(r.time)]
    if not finite_candidates:
        logging.error("No successful candidate benchmarks.")
        return []

    # Collect successful baselines keyed by device; remember the last
    # successful one as a fallback for devices whose own baseline failed.
    fallback_baseline_time: Optional[float] = None
    baseline_times_by_device = {}
    for r in baseline_results:
        if math.isfinite(r.time):
            baseline_times_by_device[r.device_id] = r.time
            fallback_baseline_time = r.time
        else:
            logging.warning(f"Baseline on device {r.device_id} failed.")
    if fallback_baseline_time is None:
        # Fix: this message was an f-string with no placeholders (lint F541).
        logging.warning(
            "All baseline benchmarks failed. Baselines will not be used to select top candidates"
        )

    def get_speedup(result: BenchmarkResult) -> float:
        # Prefer the baseline measured on the same device; otherwise fall
        # back to the last successful baseline from any device.
        if result.device_id in baseline_times_by_device:
            return result.time / baseline_times_by_device[result.device_id]
        assert fallback_baseline_time is not None, "expected fallback_baseline_time"
        return result.time / fallback_baseline_time

    num_top_candidates = len(finite_candidates)
    if num_candidates is not None:
        num_top_candidates = num_candidates

    # Sort by the speedup over baseline on the same device. If a device failed
    # the baseline benchmark, then use the fallback baseline. If there is no
    # successful baseline, then the best we can do is to sort by the actual
    # time.
    sorting_key = get_speedup
    if fallback_baseline_time is None:
        sorting_key = lambda result: result.time
    best_results = sorted(finite_candidates, key=sorting_key)[:num_top_candidates]
    logging.info(f"Selected top[{len(best_results)}]:")

    for r in best_results:
        if fallback_baseline_time is not None:
            speedup = f"{round(get_speedup(r) * 100, 2)}% of baseline"
        else:
            speedup = "baseline unavailable"
        logging.info(f"Candidate {r.candidate_id} time: {r.time} ({speedup})")
    return best_results
874+
875+
821876
def benchmark(
822877
args: argparse.Namespace,
823878
path_config: PathConfig,
@@ -827,6 +882,9 @@ def benchmark(
827882
num_candidates: Optional[int] = None,
828883
):
829884
logging.debug("benchmark()")
885+
if len(compiled_candidates) == 0:
886+
logging.warning("No candidates to benchmark.")
887+
return []
830888

831889
task_list = [
832890
BenchmarkPack(
@@ -838,7 +896,7 @@ def benchmark(
838896
if i != 0
839897
]
840898
worker_context_queue = create_worker_context_queue(args.devices)
841-
candidate_results = multiprocess_progress_wrapper(
899+
candidate_results: list[BenchmarkResult] = multiprocess_progress_wrapper(
842900
num_worker=len(args.devices),
843901
task_list=task_list,
844902
function=run_iree_benchmark_module_command,
@@ -855,32 +913,19 @@ def benchmark(
855913
candidate_tracker=candidate_trackers[0],
856914
)
857915
] * len(args.devices)
858-
baseline_results = multiprocess_progress_wrapper(
916+
baseline_results: list[BenchmarkResult] = multiprocess_progress_wrapper(
859917
num_worker=len(args.devices),
860918
task_list=baseline_task_list,
861919
function=run_iree_benchmark_module_command,
862920
initializer=init_worker_context,
863921
initializer_inputs=(worker_context_queue,),
864922
)
865-
baseline_times_by_device = {}
866-
for r in baseline_results:
867-
baseline_times_by_device[r.device_id] = r.time
868923

869-
# Select top candidates
870-
def get_speedup(result: BenchmarkResult) -> float:
871-
return result.time / baseline_times_by_device[result.device_id]
872-
873-
num_top_candidates = len(candidate_results)
874-
if num_candidates is not None:
875-
num_top_candidates = num_candidates
876-
best_results = sorted(candidate_results, key=get_speedup)[:num_top_candidates]
877-
logging.info(f"Selected top[{len(best_results)}]:")
878-
879-
for r in best_results:
880-
speedup = round(get_speedup(r) * 100, 2)
881-
logging.info(
882-
f"Candidate {r.candidate_id} time: {r.time} ({speedup}% of baseline)"
883-
)
924+
best_results: list[BenchmarkResult] = select_best_benchmark_results(
925+
candidate_results=candidate_results,
926+
baseline_results=baseline_results,
927+
num_candidates=num_candidates,
928+
)
884929

885930
top_candidates = [result.candidate_id for result in best_results]
886931
return top_candidates

tuner/tuner/libtuner_test.py

+56
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
import argparse
8+
import math
89
import pytest
910
import json
1011
from subprocess import CompletedProcess
@@ -175,5 +176,60 @@ def test_validate_devices_with_invalid_device() -> None:
175176
assert expected_call in mock_handle_error.call_args_list
176177

177178

179+
def test_select_best_benchmark_results() -> None:
    """Top candidates are ranked by per-device speedup, with graceful
    fallback when some or all baseline benchmarks fail."""

    def baselines(times: list[float]) -> list[libtuner.BenchmarkResult]:
        # One baseline result (candidate id 0) per device hip://0..hip://3.
        return [
            libtuner.BenchmarkResult(0, t, f"hip://{i}") for i, t in enumerate(times)
        ]

    candidate_results = [
        libtuner.BenchmarkResult(i + 1, t, f"hip://{i}")
        for i, t in enumerate([0.5, 0.3, 0.2, 0.1])
    ]

    # All baselines succeed: candidate 1 wins on speedup (0.5x of its
    # baseline) despite having the slowest absolute time.
    best_results: list[
        libtuner.BenchmarkResult
    ] = libtuner.select_best_benchmark_results(
        candidate_results=candidate_results,
        baseline_results=baselines([1.0, 0.1, 0.1, 0.1]),
        num_candidates=3,
    )
    assert [r.candidate_id for r in best_results] == [1, 4, 3]

    # Baseline on hip://0 failed: candidate 1 is ranked against the fallback
    # baseline and falls out of the top 3.
    best_results = libtuner.select_best_benchmark_results(
        candidate_results=candidate_results,
        baseline_results=baselines([math.inf, 0.1, 0.1, 0.1]),
        num_candidates=3,
    )
    assert [r.candidate_id for r in best_results] == [4, 3, 2]

    # Every baseline failed: selection falls back to raw benchmark time.
    best_results = libtuner.select_best_benchmark_results(
        candidate_results=candidate_results,
        baseline_results=baselines([math.inf] * 4),
        num_candidates=3,
    )
    assert [r.candidate_id for r in best_results] == [4, 3, 2]
232+
233+
178234
def test_enum_collision():
179235
from iree.compiler.dialects import linalg, vector, iree_gpu, iree_codegen, iree_input # type: ignore

0 commit comments

Comments
 (0)