Commit d18bfdf

Add TuningBudget class and modify runners to respect the budget

1 parent: ec30052

6 files changed: +258 -151 lines
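
Note: this page shows four of the six changed files. The TuningBudget class itself (and BudgetExceededConfig, imported below) is defined in one of the files not shown, presumably kernel_tuner/util.py, since the runners import both from there. For orientation only, here is a minimal sketch of what the class could look like; every name and detail below is an assumption inferred from the calls the runners make (TuningBudget(time_limit, max_fevals), is_done(), add_evaluations(n), add_time_spent(t)), not the actual definition:

class TuningBudget:
    """Hypothetical sketch -- the real definition is not shown in this diff."""

    def __init__(self, time_limit=None, max_fevals=None):
        self.time_limit = time_limit    # allowed time, None means unlimited
        self.max_fevals = max_fevals    # allowed evaluations, None means unlimited
        self.time_spent = 0.0
        self.evaluations = 0

    def add_time_spent(self, amount):
        self.time_spent += amount

    def add_evaluations(self, count):
        self.evaluations += count

    def is_done(self):
        # The budget is exhausted once either limit is reached
        if self.time_limit is not None and self.time_spent >= self.time_limit:
            return True
        if self.max_fevals is not None and self.evaluations >= self.max_fevals:
            return True
        return False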

kernel_tuner/interface.py

Lines changed: 23 additions & 7 deletions

@@ -620,11 +620,14 @@ def tune_kernel(
 
     # copy some values from strategy_options
     searchspace_construction_options = {}
+    max_fevals = None
+    time_limit = None
+
     if strategy_options:
         if "max_fevals" in strategy_options:
-            tuning_options["max_fevals"] = strategy_options["max_fevals"]
+            max_fevals = strategy_options["max_fevals"]
         if "time_limit" in strategy_options:
-            tuning_options["time_limit"] = strategy_options["time_limit"]
+            time_limit = strategy_options["time_limit"]
         if "searchspace_construction_options" in strategy_options:
             searchspace_construction_options = strategy_options["searchspace_construction_options"]
 
@@ -703,14 +706,27 @@ def preprocess_cache(filepath):
         print(f"Searchspace has {searchspace.size} configurations after restrictions.")
 
     # register the times and raise an exception if the budget is exceeded
-    if "time_limit" in tuning_options:
-        tuning_options["startup_time"] = perf_counter() - start_overhead_time
-        if tuning_options["startup_time"] > tuning_options["time_limit"]:
+    startup_time = perf_counter() - start_overhead_time
+
+    if time_limit is not None:
+        if startup_time > time_limit:
             raise RuntimeError(
-                f"The startup time of the tuning process ({tuning_options['startup_time']} seconds) has exceeded the time limit ({tuning_options['time_limit']} seconds). "
+                f"The startup time of the tuning process ({startup_time} seconds) has exceeded the time limit ({time_limit} seconds). "
                 "Please increase the time limit or decrease the size of the search space."
             )
-    tuning_options["start_time"] = perf_counter()
+
+        time_limit -= startup_time
+
+    if max_fevals is None or max_fevals > searchspace.size:
+        logging.info(f"evaluation limit has been adjusted from {max_fevals} to {searchspace.size} (search space size)")
+        max_fevals = searchspace.size
+
+    # Create the budget. Add the time spent on startup to the budget
+    budget = util.TuningBudget(time_limit, max_fevals)
+    tuning_options["time_limit"] = time_limit  # TODO: Is this used?
+    tuning_options["max_fevals"] = max_fevals  # TODO: Is this used?
+    tuning_options["budget"] = budget
 
     # call the strategy to execute the tuning process
    results = strategy.tune(searchspace, runner, tuning_options)
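
For context, both limits enter through strategy_options, as the first hunk shows. A minimal call, adapted from the canonical vector_add example in the Kernel Tuner documentation (all kernel and data details here are illustrative, not part of this commit); time_limit is in seconds, per the error message above:

import numpy as np
from kernel_tuner import tune_kernel

kernel_string = """
__global__ void vector_add(float *c, float *a, float *b, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"""

size = 10000000
a = np.random.randn(size).astype(np.float32)
b = np.random.randn(size).astype(np.float32)
c = np.zeros_like(a)
args = [c, a, b, np.int32(size)]

tune_params = {"block_size_x": [32, 64, 128, 256, 512, 1024]}

# The two options read at the top of this commit's first hunk: stop after
# 100 evaluated configurations or 60 seconds, whichever is hit first.
results, env = tune_kernel(
    "vector_add", kernel_string, size, args, tune_params,
    strategy_options={"max_fevals": 100, "time_limit": 60},
)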

kernel_tuner/runners/parallel.py

Lines changed: 67 additions & 38 deletions

@@ -3,10 +3,18 @@
 import logging
 import socket
 from time import perf_counter
+from typing import List, Optional
 from kernel_tuner.core import DeviceInterface
 from kernel_tuner.interface import Options
 from kernel_tuner.runners.runner import Runner
-from kernel_tuner.util import ErrorConfig, print_config_output, process_metrics, store_cache
+from kernel_tuner.util import (
+    BudgetExceededConfig,
+    ErrorConfig,
+    TuningBudget,
+    print_config_output,
+    process_metrics,
+    store_cache,
+)
 from datetime import datetime, timezone
 
 logger = logging.getLogger(__name__)

@@ -213,31 +221,31 @@ def shutdown(self):
     def available_parallelism(self):
         return len(self.workers)
 
-    def submit_jobs(self, jobs):
+    def submit_jobs(self, jobs, budget: TuningBudget):
         pending_jobs = deque(jobs)
         running_jobs = []
 
-        while pending_jobs or running_jobs:
-            should_wait = True
+        while pending_jobs and not budget.is_done():
+            job_was_submitted = False
 
             # If there is still work left, submit it now
-            if pending_jobs:
-                for i, worker in enumerate(list(self.workers)):
-                    if worker.is_available():
-                        # Push worker to back of list
-                        self.workers.pop(i)
-                        self.workers.append(worker)
+            for i, worker in enumerate(list(self.workers)):
+                if worker.is_available():
+                    # Push worker to back of list
+                    self.workers.pop(i)
+                    self.workers.append(worker)
 
-                        # Pop job and submit it
-                        job = pending_jobs.popleft()
-                        ref = worker.submit(*job)
-                        running_jobs.append(ref)
+                    # Pop job and submit it
+                    key, config = pending_jobs.popleft()
+                    ref = worker.submit(key, config)
+                    running_jobs.append(ref)
 
-                        should_wait = False
-                        break
+                    job_was_submitted = True
+                    budget.add_evaluations(1)
+                    break
 
             # If no work was submitted, wait until a worker is available
-            if should_wait:
+            if not job_was_submitted:
                 if not running_jobs:
                     raise RuntimeError("invalid state: no ray workers available")
 
@@ -246,14 +254,28 @@ def submit_jobs(self, jobs):
             for result in ready_jobs:
                 yield ray.get(result)
 
-    def run(self, parameter_space, tuning_options):
+        # If there are still pending jobs, then the budget has been exceeded.
+        # We return `None` to indicate that no result is available for these jobs.
+        while pending_jobs:
+            key, _ = pending_jobs.popleft()
+            yield (key, None)
+
+        # Wait until running jobs complete
+        while running_jobs:
+            ready_jobs, running_jobs = ray.wait(running_jobs, num_returns=1)
+
+            for result in ready_jobs:
+                yield ray.get(result)
+
+    def run(self, parameter_space, tuning_options) -> List[Optional[dict]]:
         metrics = tuning_options.metrics
         objective = tuning_options.objective
 
         jobs = []  # Jobs that need to be executed
         results = []  # Results that will be returned at the end
         key2index = dict()  # Used to insert job result back into `results`
-        duplicate_entries = []  # Used for duplicate entries in `parameter_space`
+
+        total_worker_time = 0
 
         # Select jobs which are not in the cache
         for index, config in enumerate(parameter_space):

@@ -262,28 +284,33 @@ def run(self, parameter_space, tuning_options):
 
             if key in tuning_options.cache:
                 params.update(tuning_options.cache[key])
-                params["compile_time"] = 0
-                params["verification_time"] = 0
-                params["benchmark_time"] = 0
+
+                # Simulate compile, verification, and benchmark time
+                tuning_options.budget.add_time_spent(params["compile_time"])
+                tuning_options.budget.add_time_spent(params["verification_time"])
+                tuning_options.budget.add_time_spent(params["benchmark_time"])
                 results.append(params)
             else:
-                if key not in key2index:
-                    key2index[key] = index
-                else:
-                    duplicate_entries.append((key2index[key], index))
+                assert key not in key2index, "duplicate jobs submitted"
+                key2index[key] = index
 
                 jobs.append((key, params))
                 results.append(None)
 
-        total_worker_time = 0
 
         # Submit jobs and wait for them to finish
-        for key, result in self.submit_jobs(jobs):
+        for key, result in self.submit_jobs(jobs, tuning_options.budget):
+            # `None` indicates that no result is available since the budget is exceeded.
+            # We can skip it, meaning that `results` contains `None`s for these entries.
+            if result is None:
+                continue
+
+            # Store the result into the output array
             results[key2index[key]] = result
 
             # Collect total time spent by worker
             total_worker_time += (
-                params["compile_time"] + params["verification_time"] + params["benchmark_time"]
+                result["compile_time"] + result["verification_time"] + result["benchmark_time"]
             )
 
             if isinstance(result.get(objective), ErrorConfig):

@@ -300,10 +327,6 @@
                 # add configuration to cache
                 store_cache(key, result, tuning_options.cachefile, tuning_options.cache)
 
-        # Copy each `i` to `j` for every `i,j` in `duplicate_entries`
-        for i, j in duplicate_entries:
-            results[j] = dict(results[i])
-
         total_time = 1000 * (perf_counter() - self.start_time)
         self.start_time = perf_counter()
 

@@ -313,14 +336,20 @@
         runner_time = total_time - strategy_time
         framework_time = max(runner_time * len(self.workers) - total_worker_time, 0)
 
+        num_valid_results = sum(bool(r) for r in results)  # Count the number of valid results
+
         # Post-process all the results
-        for params in results:
+        for result in results:
+            # Skip missing results
+            if not result:
+                continue
+
             # Amortize the time over all the results
-            params["strategy_time"] = strategy_time / len(results)
-            params["framework_time"] = framework_time / len(results)
+            result["strategy_time"] = strategy_time / num_valid_results
+            result["framework_time"] = framework_time / num_valid_results
 
             # only compute metrics on configs that have not errored
-            if metrics and not isinstance(params.get(objective), ErrorConfig):
-                params = process_metrics(params, metrics)
+            if not isinstance(result.get(objective), ErrorConfig):
+                result = process_metrics(result, metrics)
 
         return results
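
BudgetExceededConfig joins the import list above but is not referenced in the hunks shown on this page, so its definition must sit in one of the two changed files that are not displayed. Assuming it mirrors the existing ErrorConfig subclasses in kernel_tuner.util (ErrorConfig is imported right beside it), a plausible minimal sketch, purely hypothetical:

class BudgetExceededConfig(ErrorConfig):
    """Hypothetical sketch: marks a configuration that was never evaluated
    because the TuningBudget ran out. Not shown in this diff."""

    def __str__(self):
        return "BudgetExceededConfig"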

kernel_tuner/runners/sequential.py

Lines changed: 6 additions & 3 deletions

@@ -70,6 +70,7 @@ def run(self, parameter_space, tuning_options):
 
         # iterate over parameter space
         for element in parameter_space:
+            tuning_options.budget.add_evaluations(1)
             params = dict(zip(tuning_options.tune_params.keys(), element))
 
             if stop_criterion_reached(tuning_options):
@@ -82,9 +83,11 @@ def run(self, parameter_space, tuning_options):
             x_int = ",".join([str(i) for i in element])
             if tuning_options.cache and x_int in tuning_options.cache:
                 params.update(tuning_options.cache[x_int])
-                params["compile_time"] = 0
-                params["verification_time"] = 0
-                params["benchmark_time"] = 0
+
+                # Simulate compile, verification, and benchmark time
+                tuning_options.budget.add_time_spent(params["compile_time"])
+                tuning_options.budget.add_time_spent(params["verification_time"])
+                tuning_options.budget.add_time_spent(params["benchmark_time"])
             else:
                 # attempt to warmup the GPU by running the first config in the parameter space and ignoring the result
                 if not self.warmed_up:

kernel_tuner/runners/simulation.py

Lines changed: 26 additions & 32 deletions

@@ -54,6 +54,7 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob
         self.kernel_options = kernel_options
 
         self.start_time = perf_counter()
+        self.total_simulated_time = 0
         self.last_strategy_start_time = self.start_time
         self.last_strategy_time = 0
         self.units = {}
@@ -64,7 +65,7 @@ def get_device_info(self):
     def get_environment(self, tuning_options):
         env = self.dev.get_environment()
         env["simulation"] = True
-        env["simulated_time"] = tuning_options.simulated_time
+        env["simulated_time"] = self.total_simulated_time
        return env
 
     def run(self, parameter_space, tuning_options):
@@ -89,55 +90,48 @@ def run(self, parameter_space, tuning_options):
         # iterate over parameter space
         for element in parameter_space:
 
-            if util.stop_criterion_reached(tuning_options):
-                return results
-
+            # Append `None` to indicate that the tuning budget has been exceeded
+            if tuning_options.budget.is_done():
+                results.append(None)
+                continue
+
             # check if element is in the cache
-            x_int = ",".join([str(i) for i in element])
-            if tuning_options.cache and x_int in tuning_options.cache:
-                result = tuning_options.cache[x_int].copy()
+            key = ",".join([str(i) for i in element])
+
+            if key in tuning_options.cache:
+                # Get from cache and create a copy
+                result = dict(tuning_options.cache[key])
 
                 # only compute metrics on configs that have not errored
                 if tuning_options.metrics and not isinstance(result.get(tuning_options.objective), util.ErrorConfig):
                     result = util.process_metrics(result, tuning_options.metrics)
 
-                # Simulate behavior of sequential runner that when a configuration is
-                # served from the cache by the sequential runner, the compile_time,
-                # verification_time, and benchmark_time are set to 0.
-                # This step is only performed in the simulation runner when a configuration
-                # is served from the cache beyond the first timel. That is, when the
-                # configuration is already counted towards the unique_results.
-                # It is the responsibility of cost_func to add configs to unique_results.
-                if x_int in tuning_options.unique_results:
-                    result["compile_time"] = 0
-                    result["verification_time"] = 0
-                    result["benchmark_time"] = 0
-
-                else:
-                    # configuration is evaluated for the first time, print to the console
-                    util.print_config_output(
-                        tuning_options.tune_params, result, self.quiet, tuning_options.metrics, self.units
-                    )
+                # configuration is evaluated for the first time, print to the console
+                util.print_config_output(
+                    tuning_options.tune_params, result, self.quiet, tuning_options.metrics, self.units
+                )
 
                 # Everything but the strategy time and framework time are simulated,
                 result["strategy_time"] = strategy_time_per_config
 
+                # Simulate the evaluation of this configuration
+                tuning_options.budget.add_evaluations(1)
+                tuning_options.budget.add_time_spent(result["compile_time"])
+                tuning_options.budget.add_time_spent(result["verification_time"])
+                tuning_options.budget.add_time_spent(result["benchmark_time"])
+
                 try:
-                    simulated_time = result["compile_time"] + result["verification_time"] + result["benchmark_time"]
-                    tuning_options.simulated_time += simulated_time
+                    self.total_simulated_time += result["compile_time"] + result["verification_time"] + result["benchmark_time"]
                 except KeyError:
-                    if "time_limit" in tuning_options:
-                        raise RuntimeError(
-                            "Cannot use simulation mode with a time limit on a cache file that does not have full compile, verification, and benchmark timings on all configurations"
-                        )
+                    raise RuntimeError(
+                        "Cannot use simulation mode with a time limit on a cache file that does not have full compile, verification, and benchmark timings on all configurations"
+                    )
 
                 total_time = 1000 * (perf_counter() - self.start_time)
                 self.start_time = perf_counter()
                 result["framework_time"] = total_time
 
                 results.append(result)
-                if x_int not in tuning_options.unique_results:
-                    tuning_options.unique_results[x_int] = result
                 continue
 
             # if the configuration is not in the cache and not within restrictions, simulate an InvalidConfig with warning
