diff --git a/adaptive/runner.py b/adaptive/runner.py index 4f877096..cd5d3987 100644 --- a/adaptive/runner.py +++ b/adaptive/runner.py @@ -146,6 +146,12 @@ class BaseRunner(metaclass=abc.ABCMeta): the point is present in ``runner.failed``. raise_if_retries_exceeded : bool, default: True Raise the error after a point ``x`` failed `retries`. + dynamic_args_provider : callable, optional + A callable that takes the learner as its sole argument and returns additional + arguments to pass to the function being learned. This allows you to dynamically + adjust parameters of the function based on the current state of the learner. + If provided, the function will be called as `function(x, dynamic_args_provider(learner))` + instead of just `function(x)`. allow_running_forever : bool, default: False Allow the runner to run forever when the goal is None. @@ -188,6 +194,7 @@ def __init__( shutdown_executor: bool = False, retries: int = 0, raise_if_retries_exceeded: bool = True, + dynamic_args_provider: Callable[[LearnerType], Any] | None = None, allow_running_forever: bool = False, ): self.executor = _ensure_executor(executor) @@ -228,6 +235,8 @@ def __init__( next, itertools.count() ) # some unique id to be associated with each point + self.dynamic_args_provider = dynamic_args_provider + def _get_max_tasks(self) -> int: return self._max_tasks or _get_ncores(self.executor) @@ -432,6 +441,12 @@ class BlockingRunner(BaseRunner): the point is present in ``runner.failed``. raise_if_retries_exceeded : bool, default: True Raise the error after a point ``x`` failed `retries`. + dynamic_args_provider : callable, optional + A callable that takes the learner as its sole argument and returns additional + arguments to pass to the function being learned. This allows you to dynamically + adjust parameters of the function based on the current state of the learner. + If provided, the function will be called as `function(x, dynamic_args_provider(learner))` + instead of just `function(x)`. Attributes ---------- @@ -476,6 +491,7 @@ def __init__( shutdown_executor: bool = False, retries: int = 0, raise_if_retries_exceeded: bool = True, + dynamic_args_provider: Callable[[LearnerType], Any] | None = None, ) -> None: if inspect.iscoroutinefunction(learner.function): raise ValueError("Coroutine functions can only be used with 'AsyncRunner'.") @@ -497,7 +513,14 @@ def __init__( self._run() def _submit(self, x: tuple[float, ...] | float | int) -> FutureTypes: - return self.executor.submit(self.learner.function, x) + args = ( + (x,) + if not self.dynamic_args_provider + else (x, self.dynamic_args_provider(self.learner)) + ) + if self.dynamic_args_provider: + return self.executor.submit(self.learner.function, *args) + return self.executor.submit(self.learner.function, *args) def _run(self) -> None: first_completed = concurrent.FIRST_COMPLETED @@ -582,8 +605,12 @@ class AsyncRunner(BaseRunner): the point is present in ``runner.failed``. raise_if_retries_exceeded : bool, default: True Raise the error after a point ``x`` failed `retries`. - allow_running_forever : bool, default: True - If True, the runner will run forever if the goal is not provided. + dynamic_args_provider : callable, optional + A callable that takes the learner as its sole argument and returns additional + arguments to pass to the function being learned. This allows you to dynamically + adjust parameters of the function based on the current state of the learner. + If provided, the function will be called as `function(x, dynamic_args_provider(learner))` + instead of just `function(x)`. Attributes ---------- @@ -636,6 +663,7 @@ def __init__( ioloop=None, retries: int = 0, raise_if_retries_exceeded: bool = True, + dynamic_args_provider: Callable[[LearnerType], Any] | None = None, ) -> None: if ( executor is None @@ -666,6 +694,7 @@ def __init__( shutdown_executor=shutdown_executor, retries=retries, raise_if_retries_exceeded=raise_if_retries_exceeded, + dynamic_args_provider=dynamic_args_provider, allow_running_forever=True, ) self.ioloop = ioloop or asyncio.get_event_loop() @@ -694,10 +723,15 @@ def __init__( def _submit(self, x: Any) -> asyncio.Task | asyncio.Future: ioloop = self.ioloop + args = ( + (x,) + if not self.dynamic_args_provider + else (x, self.dynamic_args_provider(self.learner)) + ) if inspect.iscoroutinefunction(self.learner.function): - return ioloop.create_task(self.learner.function(x)) + return ioloop.create_task(self.learner.function(*args)) else: - return ioloop.run_in_executor(self.executor, self.learner.function, x) + return ioloop.run_in_executor(self.executor, self.learner.function, *args) def status(self) -> str: """Return the runner status as a string.