33import logging
44import socket
55from time import perf_counter
6+ from typing import List , Optional
67from kernel_tuner .core import DeviceInterface
78from kernel_tuner .interface import Options
89from kernel_tuner .runners .runner import Runner
9- from kernel_tuner .util import ErrorConfig , print_config_output , process_metrics , store_cache
10+ from kernel_tuner .util import (
11+ BudgetExceededConfig ,
12+ ErrorConfig ,
13+ TuningBudget ,
14+ print_config_output ,
15+ process_metrics ,
16+ store_cache ,
17+ )
1018from datetime import datetime , timezone
1119
1220logger = logging .getLogger (__name__ )
@@ -213,31 +221,31 @@ def shutdown(self):
def available_parallelism(self):
    """Report the available parallelism as the number of entries in ``self.workers``."""
    worker_pool = self.workers
    return len(worker_pool)
215223
216- def submit_jobs (self , jobs ):
224+ def submit_jobs (self , jobs , budget : TuningBudget ):
217225 pending_jobs = deque (jobs )
218226 running_jobs = []
219227
220- while pending_jobs or running_jobs :
221- should_wait = True
228+ while pending_jobs and not budget . is_done () :
229+ job_was_submitted = False
222230
223231 # If there is still work left, submit it now
224- if pending_jobs :
225- for i , worker in enumerate (list (self .workers )):
226- if worker .is_available ():
227- # Push worker to back of list
228- self .workers .pop (i )
229- self .workers .append (worker )
232+ for i , worker in enumerate (list (self .workers )):
233+ if worker .is_available ():
234+ # Push worker to back of list
235+ self .workers .pop (i )
236+ self .workers .append (worker )
230237
231- # Pop job and submit it
232- job = pending_jobs .popleft ()
233- ref = worker .submit (* job )
234- running_jobs .append (ref )
238+ # Pop job and submit it
239+ key , config = pending_jobs .popleft ()
240+ ref = worker .submit (key , config )
241+ running_jobs .append (ref )
235242
236- should_wait = False
237- break
243+ job_was_submitted = True
244+ budget .add_evaluations (1 )
245+ break
238246
239247 # If no work was submitted, wait until a worker is available
240- if should_wait :
248+ if not job_was_submitted :
241249 if not running_jobs :
242250 raise RuntimeError ("invalid state: no ray workers available" )
243251
@@ -246,14 +254,28 @@ def submit_jobs(self, jobs):
246254 for result in ready_jobs :
247255 yield ray .get (result )
248256
249- def run (self , parameter_space , tuning_options ):
257+ # If there are still pending jobs, then the budget has been exceeded.
258+ # We return `None` to indicate that no result is available for these jobs.
259+ while pending_jobs :
260+ key , _ = pending_jobs .popleft ()
261+ yield (key , None )
262+
263+ # Wait until running jobs complete
264+ while running_jobs :
265+ ready_jobs , running_jobs = ray .wait (running_jobs , num_returns = 1 )
266+
267+ for result in ready_jobs :
268+ yield ray .get (result )
269+
270+ def run (self , parameter_space , tuning_options ) -> List [Optional [dict ]]:
250271 metrics = tuning_options .metrics
251272 objective = tuning_options .objective
252273
253274 jobs = [] # Jobs that need to be executed
254275 results = [] # Results that will be returned at the end
255276 key2index = dict () # Used to insert job result back into `results`
256- duplicate_entries = [] # Used for duplicate entries in `parameter_space`
277+
278+ total_worker_time = 0
257279
258280 # Select jobs which are not in the cache
259281 for index , config in enumerate (parameter_space ):
@@ -262,28 +284,33 @@ def run(self, parameter_space, tuning_options):
262284
263285 if key in tuning_options .cache :
264286 params .update (tuning_options .cache [key ])
265- params ["compile_time" ] = 0
266- params ["verification_time" ] = 0
267- params ["benchmark_time" ] = 0
287+
288+ # Simulate compile, verification, and benchmark time
289+ tuning_options .budget .add_time_spent (params ["compile_time" ])
290+ tuning_options .budget .add_time_spent (params ["verification_time" ])
291+ tuning_options .budget .add_time_spent (params ["benchmark_time" ])
268292 results .append (params )
269293 else :
270- if key not in key2index :
271- key2index [key ] = index
272- else :
273- duplicate_entries .append ((key2index [key ], index ))
294+ assert key not in key2index , "duplicate jobs submitted"
295+ key2index [key ] = index
274296
275297 jobs .append ((key , params ))
276298 results .append (None )
277299
278- total_worker_time = 0
279300
280301 # Submit jobs and wait for them to finish
281- for key , result in self .submit_jobs (jobs ):
302+ for key , result in self .submit_jobs (jobs , tuning_options .budget ):
303+ # `None` indicates that no result is available because the budget was exceeded.
304+ # We skip it, which leaves `None` in `results` for these entries.
305+ if result is None :
306+ continue
307+
308+ # Store the result into the output array
282309 results [key2index [key ]] = result
283310
284311 # Collect total time spent by worker
285312 total_worker_time += (
286- params ["compile_time" ] + params ["verification_time" ] + params ["benchmark_time" ]
313+ result ["compile_time" ] + result ["verification_time" ] + result ["benchmark_time" ]
287314 )
288315
289316 if isinstance (result .get (objective ), ErrorConfig ):
@@ -300,10 +327,6 @@ def run(self, parameter_space, tuning_options):
300327 # add configuration to cache
301328 store_cache (key , result , tuning_options .cachefile , tuning_options .cache )
302329
303- # Copy each `i` to `j` for every `i,j` in `duplicate_entries`
304- for i , j in duplicate_entries :
305- results [j ] = dict (results [i ])
306-
307330 total_time = 1000 * (perf_counter () - self .start_time )
308331 self .start_time = perf_counter ()
309332
@@ -313,14 +336,20 @@ def run(self, parameter_space, tuning_options):
313336 runner_time = total_time - strategy_time
314337 framework_time = max (runner_time * len (self .workers ) - total_worker_time , 0 )
315338
339+ num_valid_results = sum (bool (r ) for r in results ) # Count the number of valid results
340+
316341 # Post-process all the results
317- for params in results :
342+ for result in results :
343+ # Skip missing results
344+ if not result :
345+ continue
346+
318347 # Amortize the time over all the results
319- params ["strategy_time" ] = strategy_time / len ( results )
320- params ["framework_time" ] = framework_time / len ( results )
348+ result ["strategy_time" ] = strategy_time / num_valid_results
349+ result ["framework_time" ] = framework_time / num_valid_results
321350
322351 # only compute metrics on configs that have not errored
323- if metrics and not isinstance (params .get (objective ), ErrorConfig ):
324- params = process_metrics (params , metrics )
352+ if not isinstance (result .get (objective ), ErrorConfig ):
353+ result = process_metrics (result , metrics )
325354
326355 return results
0 commit comments