Skip to content

Commit df62dcf

Browse files
committed
Add dynamic_args_provider to Runner
1 parent 6ec39e1 commit df62dcf

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

adaptive/runner.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
146146
the point is present in ``runner.failed``.
147147
raise_if_retries_exceeded : bool, default: True
148148
Raise the error after a point ``x`` failed `retries`.
149+
dynamic_args_provider : callable, optional
150+
A callable that takes the learner as its sole argument and returns additional
151+
arguments to pass to the function being learned. This allows you to dynamically
152+
adjust parameters of the function based on the current state of the learner.
153+
If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
154+
instead of just `function(x)`.
149155
allow_running_forever : bool, default: False
150156
Allow the runner to run forever when the goal is None.
151157
@@ -188,6 +194,7 @@ def __init__(
188194
shutdown_executor: bool = False,
189195
retries: int = 0,
190196
raise_if_retries_exceeded: bool = True,
197+
dynamic_args_provider: Callable[[LearnerType], Any] | None = None,
191198
allow_running_forever: bool = False,
192199
):
193200
self.executor = _ensure_executor(executor)
@@ -228,6 +235,8 @@ def __init__(
228235
next, itertools.count()
229236
) # some unique id to be associated with each point
230237

238+
self.dynamic_args_provider = dynamic_args_provider
239+
231240
def _get_max_tasks(self) -> int:
232241
return self._max_tasks or _get_ncores(self.executor)
233242

@@ -432,6 +441,12 @@ class BlockingRunner(BaseRunner):
432441
the point is present in ``runner.failed``.
433442
raise_if_retries_exceeded : bool, default: True
434443
Raise the error after a point ``x`` failed `retries`.
444+
dynamic_args_provider : callable, optional
445+
A callable that takes the learner as its sole argument and returns additional
446+
arguments to pass to the function being learned. This allows you to dynamically
447+
adjust parameters of the function based on the current state of the learner.
448+
If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
449+
instead of just `function(x)`.
435450
436451
Attributes
437452
----------
@@ -476,6 +491,7 @@ def __init__(
476491
shutdown_executor: bool = False,
477492
retries: int = 0,
478493
raise_if_retries_exceeded: bool = True,
494+
dynamic_args_provider: Callable[[LearnerType], Any] | None = None,
479495
) -> None:
480496
if inspect.iscoroutinefunction(learner.function):
481497
raise ValueError("Coroutine functions can only be used with 'AsyncRunner'.")
@@ -497,6 +513,10 @@ def __init__(
497513
self._run()
498514

499515
def _submit(self, x: tuple[float, ...] | float | int) -> FutureTypes:
516+
if self.dynamic_args_provider:
517+
return self.executor.submit(
518+
self.learner.function, x, self.dynamic_args_provider(self.learner)
519+
)
500520
return self.executor.submit(self.learner.function, x)
501521

502522
def _run(self) -> None:
@@ -582,8 +602,12 @@ class AsyncRunner(BaseRunner):
582602
the point is present in ``runner.failed``.
583603
raise_if_retries_exceeded : bool, default: True
584604
Raise the error after a point ``x`` failed `retries`.
585-
allow_running_forever : bool, default: True
586-
If True, the runner will run forever if the goal is not provided.
605+
dynamic_args_provider : callable, optional
606+
A callable that takes the learner as its sole argument and returns additional
607+
arguments to pass to the function being learned. This allows you to dynamically
608+
adjust parameters of the function based on the current state of the learner.
609+
If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
610+
instead of just `function(x)`.
587611
588612
Attributes
589613
----------
@@ -636,6 +660,7 @@ def __init__(
636660
ioloop=None,
637661
retries: int = 0,
638662
raise_if_retries_exceeded: bool = True,
663+
dynamic_args_provider: Callable[[LearnerType], Any] | None = None,
639664
) -> None:
640665
if (
641666
executor is None
@@ -666,6 +691,7 @@ def __init__(
666691
shutdown_executor=shutdown_executor,
667692
retries=retries,
668693
raise_if_retries_exceeded=raise_if_retries_exceeded,
694+
dynamic_args_provider=dynamic_args_provider,
669695
allow_running_forever=True,
670696
)
671697
self.ioloop = ioloop or asyncio.get_event_loop()
@@ -694,10 +720,15 @@ def __init__(
694720

695721
def _submit(self, x: Any) -> asyncio.Task | asyncio.Future:
696722
ioloop = self.ioloop
723+
args = (
724+
(x,)
725+
if not self.dynamic_args_provider
726+
else (x, self.dynamic_args_provider(self.learner))
727+
)
697728
if inspect.iscoroutinefunction(self.learner.function):
698-
return ioloop.create_task(self.learner.function(x))
729+
return ioloop.create_task(self.learner.function(*args))
699730
else:
700-
return ioloop.run_in_executor(self.executor, self.learner.function, x)
731+
return ioloop.run_in_executor(self.executor, self.learner.function, *args)
701732

702733
def status(self) -> str:
703734
"""Return the runner status as a string.

0 commit comments

Comments
 (0)