Skip to content

Commit a428473

Browse files
committed
Add dynamic_args_provider to Runner
1 parent 6ec39e1 commit a428473

File tree

1 file changed

+39
-5
lines changed

1 file changed

+39
-5
lines changed

adaptive/runner.py

+39-5
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
146146
the point is present in ``runner.failed``.
147147
raise_if_retries_exceeded : bool, default: True
148148
Raise the error after a point ``x`` failed `retries`.
149+
dynamic_args_provider : callable, optional
150+
A callable that takes the learner as its sole argument and returns additional
151+
arguments to pass to the function being learned. This allows you to dynamically
152+
adjust parameters of the function based on the current state of the learner.
153+
If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
154+
instead of just `function(x)`.
149155
allow_running_forever : bool, default: False
150156
Allow the runner to run forever when the goal is None.
151157
@@ -188,6 +194,7 @@ def __init__(
188194
shutdown_executor: bool = False,
189195
retries: int = 0,
190196
raise_if_retries_exceeded: bool = True,
197+
dynamic_args_provider: Callable[[LearnerType], Any] | None = None,
191198
allow_running_forever: bool = False,
192199
):
193200
self.executor = _ensure_executor(executor)
@@ -228,6 +235,8 @@ def __init__(
228235
next, itertools.count()
229236
) # some unique id to be associated with each point
230237

238+
self.dynamic_args_provider = dynamic_args_provider
239+
231240
def _get_max_tasks(self) -> int:
232241
return self._max_tasks or _get_ncores(self.executor)
233242

@@ -432,6 +441,12 @@ class BlockingRunner(BaseRunner):
432441
the point is present in ``runner.failed``.
433442
raise_if_retries_exceeded : bool, default: True
434443
Raise the error after a point ``x`` failed `retries`.
444+
dynamic_args_provider : callable, optional
445+
A callable that takes the learner as its sole argument and returns additional
446+
arguments to pass to the function being learned. This allows you to dynamically
447+
adjust parameters of the function based on the current state of the learner.
448+
If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
449+
instead of just `function(x)`.
435450
436451
Attributes
437452
----------
@@ -476,6 +491,7 @@ def __init__(
476491
shutdown_executor: bool = False,
477492
retries: int = 0,
478493
raise_if_retries_exceeded: bool = True,
494+
dynamic_args_provider: Callable[[LearnerType], Any] | None = None,
479495
) -> None:
480496
if inspect.iscoroutinefunction(learner.function):
481497
raise ValueError("Coroutine functions can only be used with 'AsyncRunner'.")
@@ -497,7 +513,14 @@ def __init__(
497513
self._run()
498514

499515
def _submit(self, x: tuple[float, ...] | float | int) -> FutureTypes:
500-
return self.executor.submit(self.learner.function, x)
516+
args = (
517+
(x,)
518+
if not self.dynamic_args_provider
519+
else (x, self.dynamic_args_provider(self.learner))
520+
)
521+
if self.dynamic_args_provider:
522+
return self.executor.submit(self.learner.function, *args)
523+
return self.executor.submit(self.learner.function, *args)
501524

502525
def _run(self) -> None:
503526
first_completed = concurrent.FIRST_COMPLETED
@@ -582,8 +605,12 @@ class AsyncRunner(BaseRunner):
582605
the point is present in ``runner.failed``.
583606
raise_if_retries_exceeded : bool, default: True
584607
Raise the error after a point ``x`` failed `retries`.
585-
allow_running_forever : bool, default: True
586-
If True, the runner will run forever if the goal is not provided.
608+
dynamic_args_provider : callable, optional
609+
A callable that takes the learner as its sole argument and returns additional
610+
arguments to pass to the function being learned. This allows you to dynamically
611+
adjust parameters of the function based on the current state of the learner.
612+
If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
613+
instead of just `function(x)`.
587614
588615
Attributes
589616
----------
@@ -636,6 +663,7 @@ def __init__(
636663
ioloop=None,
637664
retries: int = 0,
638665
raise_if_retries_exceeded: bool = True,
666+
dynamic_args_provider: Callable[[LearnerType], Any] | None = None,
639667
) -> None:
640668
if (
641669
executor is None
@@ -666,6 +694,7 @@ def __init__(
666694
shutdown_executor=shutdown_executor,
667695
retries=retries,
668696
raise_if_retries_exceeded=raise_if_retries_exceeded,
697+
dynamic_args_provider=dynamic_args_provider,
669698
allow_running_forever=True,
670699
)
671700
self.ioloop = ioloop or asyncio.get_event_loop()
@@ -694,10 +723,15 @@ def __init__(
694723

695724
def _submit(self, x: Any) -> asyncio.Task | asyncio.Future:
696725
ioloop = self.ioloop
726+
args = (
727+
(x,)
728+
if not self.dynamic_args_provider
729+
else (x, self.dynamic_args_provider(self.learner))
730+
)
697731
if inspect.iscoroutinefunction(self.learner.function):
698-
return ioloop.create_task(self.learner.function(x))
732+
return ioloop.create_task(self.learner.function(*args))
699733
else:
700-
return ioloop.run_in_executor(self.executor, self.learner.function, x)
734+
return ioloop.run_in_executor(self.executor, self.learner.function, *args)
701735

702736
def status(self) -> str:
703737
"""Return the runner status as a string.

0 commit comments

Comments
 (0)