Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dynamic_args_provider to Runner #472

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions adaptive/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
the point is present in ``runner.failed``.
raise_if_retries_exceeded : bool, default: True
Raise the error after a point ``x`` failed `retries`.
dynamic_args_provider : callable, optional
A callable that takes the learner as its sole argument and returns additional
arguments to pass to the function being learned. This allows you to dynamically
adjust parameters of the function based on the current state of the learner.
If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
instead of just `function(x)`.
allow_running_forever : bool, default: False
Allow the runner to run forever when the goal is None.

Expand Down Expand Up @@ -188,6 +194,7 @@ def __init__(
shutdown_executor: bool = False,
retries: int = 0,
raise_if_retries_exceeded: bool = True,
dynamic_args_provider: Callable[[LearnerType], Any] | None = None,
allow_running_forever: bool = False,
):
self.executor = _ensure_executor(executor)
Expand Down Expand Up @@ -228,6 +235,8 @@ def __init__(
next, itertools.count()
) # some unique id to be associated with each point

self.dynamic_args_provider = dynamic_args_provider

def _get_max_tasks(self) -> int:
return self._max_tasks or _get_ncores(self.executor)

Expand Down Expand Up @@ -432,6 +441,12 @@ class BlockingRunner(BaseRunner):
the point is present in ``runner.failed``.
raise_if_retries_exceeded : bool, default: True
Raise the error after a point ``x`` failed `retries`.
dynamic_args_provider : callable, optional
A callable that takes the learner as its sole argument and returns additional
arguments to pass to the function being learned. This allows you to dynamically
adjust parameters of the function based on the current state of the learner.
If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
instead of just `function(x)`.

Attributes
----------
Expand Down Expand Up @@ -476,6 +491,7 @@ def __init__(
shutdown_executor: bool = False,
retries: int = 0,
raise_if_retries_exceeded: bool = True,
dynamic_args_provider: Callable[[LearnerType], Any] | None = None,
) -> None:
if inspect.iscoroutinefunction(learner.function):
raise ValueError("Coroutine functions can only be used with 'AsyncRunner'.")
Expand All @@ -497,7 +513,14 @@ def __init__(
self._run()

def _submit(self, x: tuple[float, ...] | float | int) -> FutureTypes:
return self.executor.submit(self.learner.function, x)
args = (
(x,)
if not self.dynamic_args_provider
else (x, self.dynamic_args_provider(self.learner))
)
if self.dynamic_args_provider:
return self.executor.submit(self.learner.function, *args)
return self.executor.submit(self.learner.function, *args)

def _run(self) -> None:
first_completed = concurrent.FIRST_COMPLETED
Expand Down Expand Up @@ -582,8 +605,12 @@ class AsyncRunner(BaseRunner):
the point is present in ``runner.failed``.
raise_if_retries_exceeded : bool, default: True
Raise the error after a point ``x`` failed `retries`.
allow_running_forever : bool, default: True
If True, the runner will run forever if the goal is not provided.
dynamic_args_provider : callable, optional
A callable that takes the learner as its sole argument and returns additional
arguments to pass to the function being learned. This allows you to dynamically
adjust parameters of the function based on the current state of the learner.
If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
instead of just `function(x)`.

Attributes
----------
Expand Down Expand Up @@ -636,6 +663,7 @@ def __init__(
ioloop=None,
retries: int = 0,
raise_if_retries_exceeded: bool = True,
dynamic_args_provider: Callable[[LearnerType], Any] | None = None,
) -> None:
if (
executor is None
Expand Down Expand Up @@ -666,6 +694,7 @@ def __init__(
shutdown_executor=shutdown_executor,
retries=retries,
raise_if_retries_exceeded=raise_if_retries_exceeded,
dynamic_args_provider=dynamic_args_provider,
allow_running_forever=True,
)
self.ioloop = ioloop or asyncio.get_event_loop()
Expand Down Expand Up @@ -694,10 +723,15 @@ def __init__(

def _submit(self, x: Any) -> asyncio.Task | asyncio.Future:
ioloop = self.ioloop
args = (
(x,)
if not self.dynamic_args_provider
else (x, self.dynamic_args_provider(self.learner))
)
if inspect.iscoroutinefunction(self.learner.function):
return ioloop.create_task(self.learner.function(x))
return ioloop.create_task(self.learner.function(*args))
else:
return ioloop.run_in_executor(self.executor, self.learner.function, x)
return ioloop.run_in_executor(self.executor, self.learner.function, *args)

def status(self) -> str:
"""Return the runner status as a string.
Expand Down
Loading