Skip to content

Commit 2c56c58

Browse files
committed
Redo typing of ExecutorTypes and FutureTypes
1 parent 886c66f commit 2c56c58

File tree

3 files changed

+2077
-84
lines changed

3 files changed

+2077
-84
lines changed

adaptive/runner.py

+56-80
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import warnings
1515
from contextlib import suppress
1616
from datetime import datetime, timedelta
17-
from typing import Any, Callable, Union
17+
from typing import TYPE_CHECKING, Any, Callable, Union
1818

1919
import loky
2020
from _asyncio import Future, Task
@@ -27,42 +27,57 @@
2727
SequenceLearner,
2828
)
2929
from adaptive.notebook_integration import in_ipynb, live_info, live_plot
30+
from adaptive.utils import SequentialExecutor
3031

31-
_ThirdPartyClient = []
32-
_ThirdPartyExecutor = [loky.reusable_executor._ReusablePoolExecutor]
33-
_FutureTypes = [concurrent.Future, Future, Task]
32+
ExecutorTypes: TypeAlias = Union[
33+
concurrent.ProcessPoolExecutor,
34+
concurrent.ThreadPoolExecutor,
35+
SequentialExecutor,
36+
loky.reusable_executor._ReusablePoolExecutor,
37+
]
38+
FutureTypes: TypeAlias = Union[concurrent.Future, Future, Task]
3439

40+
if TYPE_CHECKING:
41+
import holoviews
3542

3643
try:
3744
from typing import TypeAlias
3845
except ImportError:
3946
from typing_extensions import TypeAlias
4047

48+
try:
49+
from typing import Literal
50+
except ImportError:
51+
from typing_extensions import Literal
52+
53+
4154
try:
4255
import ipyparallel
4356
from ipyparallel.client.asyncresult import AsyncResult
4457

4558
with_ipyparallel = True
46-
_ThirdPartyClient.append(ipyparallel.Client)
47-
_ThirdPartyExecutor.append(ipyparallel.client.view.ViewExecutor)
48-
_FutureTypes.append(AsyncResult)
59+
ExecutorTypes: TypeAlias = Union[
60+
ExecutorTypes, ipyparallel.Client, ipyparallel.client.view.ViewExecutor
61+
]
62+
FutureTypes: TypeAlias = Union[FutureTypes, AsyncResult]
4963
except ModuleNotFoundError:
5064
with_ipyparallel = False
5165

5266
try:
5367
import distributed
5468

5569
with_distributed = True
56-
_ThirdPartyClient.append(distributed.Client)
57-
_ThirdPartyExecutor.append(distributed.cfexecutor.ClientExecutor)
70+
ExecutorTypes: TypeAlias = Union[
71+
ExecutorTypes, distributed.Client, distributed.cfexecutor.ClientExecutor
72+
]
5873
except ModuleNotFoundError:
5974
with_distributed = False
6075

6176
try:
6277
import mpi4py.futures
6378

6479
with_mpi4py = True
65-
_ThirdPartyExecutor.append(mpi4py.futures.MPIPoolExecutor)
80+
ExecutorTypes: TypeAlias = Union[ExecutorTypes, mpi4py.futures.MPIPoolExecutor]
6681
except ModuleNotFoundError:
6782
with_mpi4py = False
6883

@@ -72,10 +87,6 @@
7287
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
7388

7489

75-
_ThirdPartyClient: TypeAlias = Union[tuple(_ThirdPartyClient)]
76-
_ThirdPartyExecutor: TypeAlias = Union[tuple(_ThirdPartyExecutor)]
77-
_FutureTypes: TypeAlias = Union[tuple(_FutureTypes)]
78-
7990
# -- Runner definitions
8091

8192
if platform.system() == "Linux":
@@ -93,29 +104,8 @@
93104
# -- Internal executor-related, things
94105

95106

96-
class SequentialExecutor(concurrent.Executor):
97-
"""A trivial executor that runs functions synchronously.
98-
99-
This executor is mainly for testing.
100-
"""
101-
102-
def submit(self, fn: Callable, *args, **kwargs) -> _FutureTypes:
103-
fut: concurrent.Future = concurrent.Future()
104-
try:
105-
fut.set_result(fn(*args, **kwargs))
106-
except Exception as e:
107-
fut.set_exception(e)
108-
return fut
109-
110-
def map(self, fn, *iterable, timeout=None, chunksize=1):
111-
return map(fn, iterable)
112-
113-
def shutdown(self, wait=True):
114-
pass
115-
116-
117107
def _ensure_executor(
118-
executor: _ThirdPartyClient | concurrent.Executor | None,
108+
executor: ExecutorTypes | None,
119109
) -> concurrent.Executor:
120110
if executor is None:
121111
executor = concurrent.ProcessPoolExecutor()
@@ -128,18 +118,14 @@ def _ensure_executor(
128118
return executor.get_executor()
129119
else:
130120
raise TypeError(
121+
# TODO: check if this is correct. Isn't MPI,loky supported?
131122
"Only a concurrent.futures.Executor, distributed.Client,"
132123
" or ipyparallel.Client can be used."
133124
)
134125

135126

136127
def _get_ncores(
137-
ex: (
138-
_ThirdPartyExecutor
139-
| concurrent.ProcessPoolExecutor
140-
| concurrent.ThreadPoolExecutor
141-
| SequentialExecutor
142-
),
128+
ex: (ExecutorTypes),
143129
) -> int:
144130
"""Return the maximum number of cores that an executor can use."""
145131
if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
@@ -244,14 +230,7 @@ def __init__(
244230
npoints_goal: int | None = None,
245231
end_time_goal: datetime | None = None,
246232
duration_goal: timedelta | int | float | None = None,
247-
executor: (
248-
_ThirdPartyClient
249-
| _ThirdPartyExecutor
250-
| concurrent.ProcessPoolExecutor
251-
| concurrent.ThreadPoolExecutor
252-
| SequentialExecutor
253-
| None
254-
) = None,
233+
executor: (ExecutorTypes | None) = None,
255234
ntasks: int = None,
256235
log: bool = False,
257236
shutdown_executor: bool = False,
@@ -356,7 +335,7 @@ def overhead(self) -> float:
356335

357336
def _process_futures(
358337
self,
359-
done_futs: set[_FutureTypes],
338+
done_futs: set[FutureTypes],
360339
) -> None:
361340
for fut in done_futs:
362341
pid = self._pending_tasks.pop(fut)
@@ -381,7 +360,7 @@ def _process_futures(
381360

382361
def _get_futures(
383362
self,
384-
) -> list[_FutureTypes]:
363+
) -> list[FutureTypes]:
385364
# Launch tasks to replace the ones that completed
386365
# on the last iteration, making sure to fill workers
387366
# that have started since the last iteration.
@@ -403,7 +382,7 @@ def _get_futures(
403382
futures = list(self._pending_tasks.keys())
404383
return futures
405384

406-
def _remove_unfinished(self) -> list[_FutureTypes]:
385+
def _remove_unfinished(self) -> list[FutureTypes]:
407386
# remove points with 'None' values from the learner
408387
self.learner.remove_unfinished()
409388
# cancel any outstanding tasks
@@ -540,14 +519,7 @@ def __init__(
540519
npoints_goal: int | None = None,
541520
end_time_goal: datetime | None = None,
542521
duration_goal: timedelta | int | float | None = None,
543-
executor: (
544-
_ThirdPartyClient
545-
| _ThirdPartyExecutor
546-
| concurrent.ProcessPoolExecutor
547-
| concurrent.ThreadPoolExecutor
548-
| SequentialExecutor
549-
| None
550-
) = None,
522+
executor: (ExecutorTypes | None) = None,
551523
ntasks: int | None = None,
552524
log: bool = False,
553525
shutdown_executor: bool = False,
@@ -573,7 +545,7 @@ def __init__(
573545
)
574546
self._run()
575547

576-
def _submit(self, x: tuple[float, ...] | float | int) -> _FutureTypes:
548+
def _submit(self, x: tuple[float, ...] | float | int) -> FutureTypes:
577549
return self.executor.submit(self.learner.function, x)
578550

579551
def _run(self) -> None:
@@ -706,14 +678,7 @@ def __init__(
706678
npoints_goal: int | None = None,
707679
end_time_goal: datetime | None = None,
708680
duration_goal: timedelta | int | float | None = None,
709-
executor: (
710-
_ThirdPartyClient
711-
| _ThirdPartyExecutor
712-
| concurrent.ProcessPoolExecutor
713-
| concurrent.ThreadPoolExecutor
714-
| SequentialExecutor
715-
| None
716-
) = None,
681+
executor: (ExecutorTypes | None) = None,
717682
ntasks: int | None = None,
718683
log: bool = False,
719684
shutdown_executor: bool = False,
@@ -807,7 +772,14 @@ def cancel(self) -> None:
807772
"""
808773
self.task.cancel()
809774

810-
def live_plot(self, *, plotter=None, update_interval=2, name=None, normalize=True):
775+
def live_plot(
776+
self,
777+
*,
778+
plotter: Callable[[BaseLearner], holoviews.Element] | None = None,
779+
update_interval: float = 2.0,
780+
name: str = None,
781+
normalize: bool = True,
782+
) -> holoviews.DynamicMap:
811783
"""Live plotting of the learner's data.
812784
813785
Parameters
@@ -831,10 +803,14 @@ def live_plot(self, *, plotter=None, update_interval=2, name=None, normalize=Tru
831803
The plot that automatically updates every `update_interval`.
832804
"""
833805
return live_plot(
834-
self, plotter=plotter, update_interval=update_interval, name=name
806+
self,
807+
plotter=plotter,
808+
update_interval=update_interval,
809+
name=name,
810+
normalize=normalize,
835811
)
836812

837-
def live_info(self, *, update_interval=0.1):
813+
def live_info(self, *, update_interval: float = 0.1) -> None:
838814
"""Display live information about the runner.
839815
840816
Returns an interactive ipywidget that can be
@@ -984,7 +960,10 @@ def simple(
984960
learner.tell(x, y)
985961

986962

987-
def replay_log(learner: BaseLearner, log) -> None:
963+
def replay_log(
964+
learner: BaseLearner,
965+
log: list[tuple[Literal["tell"], Any, Any] | tuple[Literal["ask"], int]],
966+
) -> None:
988967
"""Apply a sequence of method calls to a learner.
989968
990969
This is useful for debugging runners.
@@ -1002,8 +981,8 @@ def replay_log(learner: BaseLearner, log) -> None:
1002981

1003982
# --- Useful runner goals
1004983

1005-
1006-
def stop_after(*, seconds=0, minutes=0, hours=0) -> Callable:
984+
# TODO: deprecate
985+
def stop_after(*, seconds=0, minutes=0, hours=0) -> Callable[[BaseLearner], bool]:
1007986
"""Stop a runner after a specified time.
1008987
1009988
For example, to specify a runner that should stop after
@@ -1042,10 +1021,7 @@ def stop_after(*, seconds=0, minutes=0, hours=0) -> Callable:
10421021

10431022
class _TimeGoal:
10441023
def __init__(self, dt: timedelta | datetime | int | float):
1045-
if not isinstance(dt, (timedelta, datetime)):
1046-
self.dt = timedelta(seconds=dt)
1047-
else:
1048-
self.dt = dt
1024+
self.dt = dt if isinstance(dt, (timedelta, datetime)) else timedelta(seconds=dt)
10491025
self.start_time = None
10501026

10511027
def __call__(self, _):

adaptive/utils.py

+22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4+
import concurrent.futures as concurrent
45
import functools
56
import gzip
67
import inspect
@@ -155,3 +156,24 @@ def partial_function_from_dataframe(function, df, function_prefix: str = "functi
155156
" The DataFrame's value will be used."
156157
)
157158
return functools.partial(function, **kwargs)
159+
160+
161+
class SequentialExecutor(concurrent.Executor):
162+
"""A trivial executor that runs functions synchronously.
163+
164+
This executor is mainly for testing.
165+
"""
166+
167+
def submit(self, fn: Callable, *args, **kwargs) -> concurrent.Future:
168+
fut: concurrent.Future = concurrent.Future()
169+
try:
170+
fut.set_result(fn(*args, **kwargs))
171+
except Exception as e:
172+
fut.set_exception(e)
173+
return fut
174+
175+
def map(self, fn, *iterable, timeout=None, chunksize=1):
176+
return map(fn, iterable)
177+
178+
def shutdown(self, wait=True):
179+
pass

0 commit comments

Comments
 (0)