@@ -146,6 +146,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
146
146
the point is present in ``runner.failed``.
147
147
raise_if_retries_exceeded : bool, default: True
148
148
Raise the error after a point ``x`` failed `retries`.
149
+ dynamic_args_provider : callable, optional
150
+ A callable that takes the learner as its sole argument and returns additional
151
+ arguments to pass to the function being learned. This allows you to dynamically
152
+ adjust parameters of the function based on the current state of the learner.
153
+ If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
154
+ instead of just `function(x)`.
149
155
allow_running_forever : bool, default: False
150
156
Allow the runner to run forever when the goal is None.
151
157
@@ -188,6 +194,7 @@ def __init__(
188
194
shutdown_executor : bool = False ,
189
195
retries : int = 0 ,
190
196
raise_if_retries_exceeded : bool = True ,
197
+ dynamic_args_provider : Callable [[LearnerType ], Any ] | None = None ,
191
198
allow_running_forever : bool = False ,
192
199
):
193
200
self .executor = _ensure_executor (executor )
@@ -228,6 +235,8 @@ def __init__(
228
235
next , itertools .count ()
229
236
) # some unique id to be associated with each point
230
237
238
+ self .dynamic_args_provider = dynamic_args_provider
239
+
231
240
def _get_max_tasks (self ) -> int :
232
241
return self ._max_tasks or _get_ncores (self .executor )
233
242
@@ -432,6 +441,12 @@ class BlockingRunner(BaseRunner):
432
441
the point is present in ``runner.failed``.
433
442
raise_if_retries_exceeded : bool, default: True
434
443
Raise the error after a point ``x`` failed `retries`.
444
+ dynamic_args_provider : callable, optional
445
+ A callable that takes the learner as its sole argument and returns additional
446
+ arguments to pass to the function being learned. This allows you to dynamically
447
+ adjust parameters of the function based on the current state of the learner.
448
+ If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
449
+ instead of just `function(x)`.
435
450
436
451
Attributes
437
452
----------
@@ -476,6 +491,7 @@ def __init__(
476
491
shutdown_executor : bool = False ,
477
492
retries : int = 0 ,
478
493
raise_if_retries_exceeded : bool = True ,
494
+ dynamic_args_provider : Callable [[LearnerType ], Any ] | None = None ,
479
495
) -> None :
480
496
if inspect .iscoroutinefunction (learner .function ):
481
497
raise ValueError ("Coroutine functions can only be used with 'AsyncRunner'." )
@@ -497,6 +513,10 @@ def __init__(
497
513
self ._run ()
498
514
499
515
def _submit (self , x : tuple [float , ...] | float | int ) -> FutureTypes :
516
+ if self .dynamic_args_provider :
517
+ return self .executor .submit (
518
+ self .learner .function , x , self .dynamic_args_provider (self .learner )
519
+ )
500
520
return self .executor .submit (self .learner .function , x )
501
521
502
522
def _run (self ) -> None :
@@ -582,8 +602,12 @@ class AsyncRunner(BaseRunner):
582
602
the point is present in ``runner.failed``.
583
603
raise_if_retries_exceeded : bool, default: True
584
604
Raise the error after a point ``x`` failed `retries`.
585
- allow_running_forever : bool, default: True
586
- If True, the runner will run forever if the goal is not provided.
605
+ dynamic_args_provider : callable, optional
606
+ A callable that takes the learner as its sole argument and returns additional
607
+ arguments to pass to the function being learned. This allows you to dynamically
608
+ adjust parameters of the function based on the current state of the learner.
609
+ If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
610
+ instead of just `function(x)`.
587
611
588
612
Attributes
589
613
----------
@@ -636,6 +660,7 @@ def __init__(
636
660
ioloop = None ,
637
661
retries : int = 0 ,
638
662
raise_if_retries_exceeded : bool = True ,
663
+ dynamic_args_provider : Callable [[LearnerType ], Any ] | None = None ,
639
664
) -> None :
640
665
if (
641
666
executor is None
@@ -666,6 +691,7 @@ def __init__(
666
691
shutdown_executor = shutdown_executor ,
667
692
retries = retries ,
668
693
raise_if_retries_exceeded = raise_if_retries_exceeded ,
694
+ dynamic_args_provider = dynamic_args_provider ,
669
695
allow_running_forever = True ,
670
696
)
671
697
self .ioloop = ioloop or asyncio .get_event_loop ()
@@ -694,10 +720,15 @@ def __init__(
694
720
695
721
def _submit (self , x : Any ) -> asyncio .Task | asyncio .Future :
696
722
ioloop = self .ioloop
723
+ args = (
724
+ (x ,)
725
+ if not self .dynamic_args_provider
726
+ else (x , self .dynamic_args_provider (self .learner ))
727
+ )
697
728
if inspect .iscoroutinefunction (self .learner .function ):
698
- return ioloop .create_task (self .learner .function (x ))
729
+ return ioloop .create_task (self .learner .function (* args ))
699
730
else :
700
- return ioloop .run_in_executor (self .executor , self .learner .function , x )
731
+ return ioloop .run_in_executor (self .executor , self .learner .function , * args )
701
732
702
733
def status (self ) -> str :
703
734
"""Return the runner status as a string.
0 commit comments