@@ -146,6 +146,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
146
146
the point is present in ``runner.failed``.
147
147
raise_if_retries_exceeded : bool, default: True
148
148
Raise the error after a point ``x`` failed `retries`.
149
+ dynamic_args_provider : callable, optional
150
+ A callable that takes the learner as its sole argument and returns additional
151
+ arguments to pass to the function being learned. This allows you to dynamically
152
+ adjust parameters of the function based on the current state of the learner.
153
+ If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
154
+ instead of just `function(x)`.
149
155
allow_running_forever : bool, default: False
150
156
Allow the runner to run forever when the goal is None.
151
157
@@ -188,6 +194,7 @@ def __init__(
188
194
shutdown_executor : bool = False ,
189
195
retries : int = 0 ,
190
196
raise_if_retries_exceeded : bool = True ,
197
+ dynamic_args_provider : Callable [[LearnerType ], Any ] | None = None ,
191
198
allow_running_forever : bool = False ,
192
199
):
193
200
self .executor = _ensure_executor (executor )
@@ -228,6 +235,8 @@ def __init__(
228
235
next , itertools .count ()
229
236
) # some unique id to be associated with each point
230
237
238
+ self .dynamic_args_provider = dynamic_args_provider
239
+
231
240
def _get_max_tasks (self ) -> int :
232
241
return self ._max_tasks or _get_ncores (self .executor )
233
242
@@ -432,6 +441,12 @@ class BlockingRunner(BaseRunner):
432
441
the point is present in ``runner.failed``.
433
442
raise_if_retries_exceeded : bool, default: True
434
443
Raise the error after a point ``x`` failed `retries`.
444
+ dynamic_args_provider : callable, optional
445
+ A callable that takes the learner as its sole argument and returns additional
446
+ arguments to pass to the function being learned. This allows you to dynamically
447
+ adjust parameters of the function based on the current state of the learner.
448
+ If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
449
+ instead of just `function(x)`.
435
450
436
451
Attributes
437
452
----------
@@ -476,6 +491,7 @@ def __init__(
476
491
shutdown_executor : bool = False ,
477
492
retries : int = 0 ,
478
493
raise_if_retries_exceeded : bool = True ,
494
+ dynamic_args_provider : Callable [[LearnerType ], Any ] | None = None ,
479
495
) -> None :
480
496
if inspect .iscoroutinefunction (learner .function ):
481
497
raise ValueError ("Coroutine functions can only be used with 'AsyncRunner'." )
@@ -497,7 +513,14 @@ def __init__(
497
513
self ._run ()
498
514
499
515
def _submit (self , x : tuple [float , ...] | float | int ) -> FutureTypes :
500
- return self .executor .submit (self .learner .function , x )
516
+ args = (
517
+ (x ,)
518
+ if not self .dynamic_args_provider
519
+ else (x , self .dynamic_args_provider (self .learner ))
520
+ )
521
+ if self .dynamic_args_provider :
522
+ return self .executor .submit (self .learner .function , * args )
523
+ return self .executor .submit (self .learner .function , * args )
501
524
502
525
def _run (self ) -> None :
503
526
first_completed = concurrent .FIRST_COMPLETED
@@ -582,8 +605,12 @@ class AsyncRunner(BaseRunner):
582
605
the point is present in ``runner.failed``.
583
606
raise_if_retries_exceeded : bool, default: True
584
607
Raise the error after a point ``x`` failed `retries`.
585
- allow_running_forever : bool, default: True
586
- If True, the runner will run forever if the goal is not provided.
608
+ dynamic_args_provider : callable, optional
609
+ A callable that takes the learner as its sole argument and returns additional
610
+ arguments to pass to the function being learned. This allows you to dynamically
611
+ adjust parameters of the function based on the current state of the learner.
612
+ If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
613
+ instead of just `function(x)`.
587
614
588
615
Attributes
589
616
----------
@@ -636,6 +663,7 @@ def __init__(
636
663
ioloop = None ,
637
664
retries : int = 0 ,
638
665
raise_if_retries_exceeded : bool = True ,
666
+ dynamic_args_provider : Callable [[LearnerType ], Any ] | None = None ,
639
667
) -> None :
640
668
if (
641
669
executor is None
@@ -666,6 +694,7 @@ def __init__(
666
694
shutdown_executor = shutdown_executor ,
667
695
retries = retries ,
668
696
raise_if_retries_exceeded = raise_if_retries_exceeded ,
697
+ dynamic_args_provider = dynamic_args_provider ,
669
698
allow_running_forever = True ,
670
699
)
671
700
self .ioloop = ioloop or asyncio .get_event_loop ()
@@ -694,10 +723,15 @@ def __init__(
694
723
695
724
def _submit (self , x : Any ) -> asyncio .Task | asyncio .Future :
696
725
ioloop = self .ioloop
726
+ args = (
727
+ (x ,)
728
+ if not self .dynamic_args_provider
729
+ else (x , self .dynamic_args_provider (self .learner ))
730
+ )
697
731
if inspect .iscoroutinefunction (self .learner .function ):
698
- return ioloop .create_task (self .learner .function (x ))
732
+ return ioloop .create_task (self .learner .function (* args ))
699
733
else :
700
- return ioloop .run_in_executor (self .executor , self .learner .function , x )
734
+ return ioloop .run_in_executor (self .executor , self .learner .function , * args )
701
735
702
736
def status (self ) -> str :
703
737
"""Return the runner status as a string.
0 commit comments