14
14
import warnings
15
15
from contextlib import suppress
16
16
from datetime import datetime , timedelta
17
- from typing import Any , Callable , Union
17
+ from typing import TYPE_CHECKING , Any , Callable , Union
18
18
19
19
import loky
20
20
from _asyncio import Future , Task
27
27
SequenceLearner ,
28
28
)
29
29
from adaptive .notebook_integration import in_ipynb , live_info , live_plot
30
+ from adaptive .utils import SequentialExecutor
30
31
31
- _ThirdPartyClient = []
32
- _ThirdPartyExecutor = [loky .reusable_executor ._ReusablePoolExecutor ]
33
- _FutureTypes = [concurrent .Future , Future , Task ]
32
+ ExecutorTypes : TypeAlias = Union [
33
+ concurrent .ProcessPoolExecutor ,
34
+ concurrent .ThreadPoolExecutor ,
35
+ SequentialExecutor ,
36
+ loky .reusable_executor ._ReusablePoolExecutor ,
37
+ ]
38
+ FutureTypes : TypeAlias = Union [concurrent .Future , Future , Task ]
34
39
40
+ if TYPE_CHECKING :
41
+ import holoviews
35
42
36
43
try :
37
44
from typing import TypeAlias
38
45
except ImportError :
39
46
from typing_extensions import TypeAlias
40
47
48
+ try :
49
+ from typing import Literal
50
+ except ImportError :
51
+ from typing_extensions import Literal
52
+
53
+
41
54
try :
42
55
import ipyparallel
43
56
from ipyparallel .client .asyncresult import AsyncResult
44
57
45
58
with_ipyparallel = True
46
- _ThirdPartyClient .append (ipyparallel .Client )
47
- _ThirdPartyExecutor .append (ipyparallel .client .view .ViewExecutor )
48
- _FutureTypes .append (AsyncResult )
59
+ ExecutorTypes : TypeAlias = Union [
60
+ ExecutorTypes , ipyparallel .Client , ipyparallel .client .view .ViewExecutor
61
+ ]
62
+ FutureTypes : TypeAlias = Union [FutureTypes , AsyncResult ]
49
63
except ModuleNotFoundError :
50
64
with_ipyparallel = False
51
65
52
66
try :
53
67
import distributed
54
68
55
69
with_distributed = True
56
- _ThirdPartyClient .append (distributed .Client )
57
- _ThirdPartyExecutor .append (distributed .cfexecutor .ClientExecutor )
70
+ ExecutorTypes : TypeAlias = Union [
71
+ ExecutorTypes , distributed .Client , distributed .cfexecutor .ClientExecutor
72
+ ]
58
73
except ModuleNotFoundError :
59
74
with_distributed = False
60
75
61
76
try :
62
77
import mpi4py .futures
63
78
64
79
with_mpi4py = True
65
- _ThirdPartyExecutor . append ( mpi4py .futures .MPIPoolExecutor )
80
+ ExecutorTypes : TypeAlias = Union [ ExecutorTypes , mpi4py .futures .MPIPoolExecutor ]
66
81
except ModuleNotFoundError :
67
82
with_mpi4py = False
68
83
72
87
asyncio .set_event_loop_policy (uvloop .EventLoopPolicy ())
73
88
74
89
75
- _ThirdPartyClient : TypeAlias = Union [tuple (_ThirdPartyClient )]
76
- _ThirdPartyExecutor : TypeAlias = Union [tuple (_ThirdPartyExecutor )]
77
- _FutureTypes : TypeAlias = Union [tuple (_FutureTypes )]
78
-
79
90
# -- Runner definitions
80
91
81
92
if platform .system () == "Linux" :
93
104
# -- Internal executor-related, things
94
105
95
106
96
- class SequentialExecutor (concurrent .Executor ):
97
- """A trivial executor that runs functions synchronously.
98
-
99
- This executor is mainly for testing.
100
- """
101
-
102
- def submit (self , fn : Callable , * args , ** kwargs ) -> _FutureTypes :
103
- fut : concurrent .Future = concurrent .Future ()
104
- try :
105
- fut .set_result (fn (* args , ** kwargs ))
106
- except Exception as e :
107
- fut .set_exception (e )
108
- return fut
109
-
110
- def map (self , fn , * iterable , timeout = None , chunksize = 1 ):
111
- return map (fn , iterable )
112
-
113
- def shutdown (self , wait = True ):
114
- pass
115
-
116
-
117
107
def _ensure_executor (
118
- executor : _ThirdPartyClient | concurrent . Executor | None ,
108
+ executor : ExecutorTypes | None ,
119
109
) -> concurrent .Executor :
120
110
if executor is None :
121
111
executor = concurrent .ProcessPoolExecutor ()
@@ -128,18 +118,14 @@ def _ensure_executor(
128
118
return executor .get_executor ()
129
119
else :
130
120
raise TypeError (
121
+ # TODO: check if this is correct. Isn't MPI,loky supported?
131
122
"Only a concurrent.futures.Executor, distributed.Client,"
132
123
" or ipyparallel.Client can be used."
133
124
)
134
125
135
126
136
127
def _get_ncores (
137
- ex : (
138
- _ThirdPartyExecutor
139
- | concurrent .ProcessPoolExecutor
140
- | concurrent .ThreadPoolExecutor
141
- | SequentialExecutor
142
- ),
128
+ ex : (ExecutorTypes ),
143
129
) -> int :
144
130
"""Return the maximum number of cores that an executor can use."""
145
131
if with_ipyparallel and isinstance (ex , ipyparallel .client .view .ViewExecutor ):
@@ -244,14 +230,7 @@ def __init__(
244
230
npoints_goal : int | None = None ,
245
231
end_time_goal : datetime | None = None ,
246
232
duration_goal : timedelta | int | float | None = None ,
247
- executor : (
248
- _ThirdPartyClient
249
- | _ThirdPartyExecutor
250
- | concurrent .ProcessPoolExecutor
251
- | concurrent .ThreadPoolExecutor
252
- | SequentialExecutor
253
- | None
254
- ) = None ,
233
+ executor : (ExecutorTypes | None ) = None ,
255
234
ntasks : int = None ,
256
235
log : bool = False ,
257
236
shutdown_executor : bool = False ,
@@ -356,7 +335,7 @@ def overhead(self) -> float:
356
335
357
336
def _process_futures (
358
337
self ,
359
- done_futs : set [_FutureTypes ],
338
+ done_futs : set [FutureTypes ],
360
339
) -> None :
361
340
for fut in done_futs :
362
341
pid = self ._pending_tasks .pop (fut )
@@ -381,7 +360,7 @@ def _process_futures(
381
360
382
361
def _get_futures (
383
362
self ,
384
- ) -> list [_FutureTypes ]:
363
+ ) -> list [FutureTypes ]:
385
364
# Launch tasks to replace the ones that completed
386
365
# on the last iteration, making sure to fill workers
387
366
# that have started since the last iteration.
@@ -403,7 +382,7 @@ def _get_futures(
403
382
futures = list (self ._pending_tasks .keys ())
404
383
return futures
405
384
406
- def _remove_unfinished (self ) -> list [_FutureTypes ]:
385
+ def _remove_unfinished (self ) -> list [FutureTypes ]:
407
386
# remove points with 'None' values from the learner
408
387
self .learner .remove_unfinished ()
409
388
# cancel any outstanding tasks
@@ -540,14 +519,7 @@ def __init__(
540
519
npoints_goal : int | None = None ,
541
520
end_time_goal : datetime | None = None ,
542
521
duration_goal : timedelta | int | float | None = None ,
543
- executor : (
544
- _ThirdPartyClient
545
- | _ThirdPartyExecutor
546
- | concurrent .ProcessPoolExecutor
547
- | concurrent .ThreadPoolExecutor
548
- | SequentialExecutor
549
- | None
550
- ) = None ,
522
+ executor : (ExecutorTypes | None ) = None ,
551
523
ntasks : int | None = None ,
552
524
log : bool = False ,
553
525
shutdown_executor : bool = False ,
@@ -573,7 +545,7 @@ def __init__(
573
545
)
574
546
self ._run ()
575
547
576
- def _submit (self , x : tuple [float , ...] | float | int ) -> _FutureTypes :
548
+ def _submit (self , x : tuple [float , ...] | float | int ) -> FutureTypes :
577
549
return self .executor .submit (self .learner .function , x )
578
550
579
551
def _run (self ) -> None :
@@ -706,14 +678,7 @@ def __init__(
706
678
npoints_goal : int | None = None ,
707
679
end_time_goal : datetime | None = None ,
708
680
duration_goal : timedelta | int | float | None = None ,
709
- executor : (
710
- _ThirdPartyClient
711
- | _ThirdPartyExecutor
712
- | concurrent .ProcessPoolExecutor
713
- | concurrent .ThreadPoolExecutor
714
- | SequentialExecutor
715
- | None
716
- ) = None ,
681
+ executor : (ExecutorTypes | None ) = None ,
717
682
ntasks : int | None = None ,
718
683
log : bool = False ,
719
684
shutdown_executor : bool = False ,
@@ -807,7 +772,14 @@ def cancel(self) -> None:
807
772
"""
808
773
self .task .cancel ()
809
774
810
- def live_plot (self , * , plotter = None , update_interval = 2 , name = None , normalize = True ):
775
+ def live_plot (
776
+ self ,
777
+ * ,
778
+ plotter : Callable [[BaseLearner ], holoviews .Element ] | None = None ,
779
+ update_interval : float = 2.0 ,
780
+ name : str = None ,
781
+ normalize : bool = True ,
782
+ ) -> holoviews .DynamicMap :
811
783
"""Live plotting of the learner's data.
812
784
813
785
Parameters
@@ -831,10 +803,14 @@ def live_plot(self, *, plotter=None, update_interval=2, name=None, normalize=Tru
831
803
The plot that automatically updates every `update_interval`.
832
804
"""
833
805
return live_plot (
834
- self , plotter = plotter , update_interval = update_interval , name = name
806
+ self ,
807
+ plotter = plotter ,
808
+ update_interval = update_interval ,
809
+ name = name ,
810
+ normalize = normalize ,
835
811
)
836
812
837
- def live_info (self , * , update_interval = 0.1 ):
813
+ def live_info (self , * , update_interval : float = 0.1 ) -> None :
838
814
"""Display live information about the runner.
839
815
840
816
Returns an interactive ipywidget that can be
@@ -984,7 +960,10 @@ def simple(
984
960
learner .tell (x , y )
985
961
986
962
987
- def replay_log (learner : BaseLearner , log ) -> None :
963
+ def replay_log (
964
+ learner : BaseLearner ,
965
+ log : list [tuple [Literal ["tell" ], Any , Any ] | tuple [Literal ["ask" ], int ]],
966
+ ) -> None :
988
967
"""Apply a sequence of method calls to a learner.
989
968
990
969
This is useful for debugging runners.
@@ -1002,8 +981,8 @@ def replay_log(learner: BaseLearner, log) -> None:
1002
981
1003
982
# --- Useful runner goals
1004
983
1005
-
1006
- def stop_after (* , seconds = 0 , minutes = 0 , hours = 0 ) -> Callable :
984
+ # TODO: deprecate
985
+ def stop_after (* , seconds = 0 , minutes = 0 , hours = 0 ) -> Callable [[ BaseLearner ], bool ] :
1007
986
"""Stop a runner after a specified time.
1008
987
1009
988
For example, to specify a runner that should stop after
@@ -1042,10 +1021,7 @@ def stop_after(*, seconds=0, minutes=0, hours=0) -> Callable:
1042
1021
1043
1022
class _TimeGoal :
1044
1023
def __init__ (self , dt : timedelta | datetime | int | float ):
1045
- if not isinstance (dt , (timedelta , datetime )):
1046
- self .dt = timedelta (seconds = dt )
1047
- else :
1048
- self .dt = dt
1024
+ self .dt = dt if isinstance (dt , (timedelta , datetime )) else timedelta (seconds = dt )
1049
1025
self .start_time = None
1050
1026
1051
1027
def __call__ (self , _ ):
0 commit comments