4
4
from contextlib import suppress
5
5
from functools import partial
6
6
from operator import itemgetter
7
+ from typing import Any , Callable , Dict , List , Set , Tuple , Union
7
8
8
9
import numpy as np
9
10
12
13
from adaptive .utils import cache_latest , named_product , restore
13
14
14
15
15
- def dispatch (child_functions , arg ) :
16
+ def dispatch (child_functions : List [ Callable ] , arg : Any ) -> Union [ Any ] :
16
17
index , x = arg
17
18
return child_functions [index ](x )
18
19
@@ -68,7 +69,9 @@ class BalancingLearner(BaseLearner):
68
69
behave in an undefined way. Change the `strategy` in that case.
69
70
"""
70
71
71
- def __init__ (self , learners , * , cdims = None , strategy = "loss_improvements" ):
72
+ def __init__ (
73
+ self , learners : List [BaseLearner ], * , cdims = None , strategy = "loss_improvements"
74
+ ) -> None :
72
75
self .learners = learners
73
76
74
77
# Naively we would make 'function' a method, but this causes problems
@@ -89,21 +92,21 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
89
92
self .strategy = strategy
90
93
91
94
@property
92
- def data (self ):
95
+ def data (self ) -> Dict [ Tuple [ int , Any ], Any ] :
93
96
data = {}
94
97
for i , l in enumerate (self .learners ):
95
98
data .update ({(i , p ): v for p , v in l .data .items ()})
96
99
return data
97
100
98
101
@property
99
- def pending_points (self ):
102
+ def pending_points (self ) -> Set [ Tuple [ int , Any ]] :
100
103
pending_points = set ()
101
104
for i , l in enumerate (self .learners ):
102
105
pending_points .update ({(i , p ) for p in l .pending_points })
103
106
return pending_points
104
107
105
108
@property
106
- def npoints (self ):
109
+ def npoints (self ) -> int :
107
110
return sum (l .npoints for l in self .learners )
108
111
109
112
@property
@@ -135,7 +138,9 @@ def strategy(self, strategy):
135
138
' strategy="npoints", or strategy="cycle" is implemented.'
136
139
)
137
140
138
- def _ask_and_tell_based_on_loss_improvements (self , n ):
141
+ def _ask_and_tell_based_on_loss_improvements (
142
+ self , n : int
143
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
139
144
selected = [] # tuples ((learner_index, point), loss_improvement)
140
145
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
141
146
for _ in range (n ):
@@ -158,7 +163,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
158
163
points , loss_improvements = map (list , zip (* selected ))
159
164
return points , loss_improvements
160
165
161
- def _ask_and_tell_based_on_loss (self , n ):
166
+ def _ask_and_tell_based_on_loss (
167
+ self , n : int
168
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
162
169
selected = [] # tuples ((learner_index, point), loss_improvement)
163
170
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
164
171
for _ in range (n ):
@@ -179,7 +186,9 @@ def _ask_and_tell_based_on_loss(self, n):
179
186
points , loss_improvements = map (list , zip (* selected ))
180
187
return points , loss_improvements
181
188
182
- def _ask_and_tell_based_on_npoints (self , n ):
189
+ def _ask_and_tell_based_on_npoints (
190
+ self , n : int
191
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
183
192
selected = [] # tuples ((learner_index, point), loss_improvement)
184
193
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
185
194
for _ in range (n ):
@@ -195,7 +204,9 @@ def _ask_and_tell_based_on_npoints(self, n):
195
204
points , loss_improvements = map (list , zip (* selected ))
196
205
return points , loss_improvements
197
206
198
- def _ask_and_tell_based_on_cycle (self , n ):
207
+ def _ask_and_tell_based_on_cycle (
208
+ self , n : int
209
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
199
210
points , loss_improvements = [], []
200
211
for _ in range (n ):
201
212
index = next (self ._cycle )
@@ -206,7 +217,9 @@ def _ask_and_tell_based_on_cycle(self, n):
206
217
207
218
return points , loss_improvements
208
219
209
- def ask (self , n , tell_pending = True ):
220
+ def ask (
221
+ self , n : int , tell_pending : bool = True
222
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
210
223
"""Chose points for learners."""
211
224
if n == 0 :
212
225
return [], []
@@ -217,20 +230,20 @@ def ask(self, n, tell_pending=True):
217
230
else :
218
231
return self ._ask_and_tell (n )
219
232
220
- def tell (self , x , y ) :
233
+ def tell (self , x : Tuple [ int , Any ], y : Any ) -> None :
221
234
index , x = x
222
235
self ._ask_cache .pop (index , None )
223
236
self ._loss .pop (index , None )
224
237
self ._pending_loss .pop (index , None )
225
238
self .learners [index ].tell (x , y )
226
239
227
- def tell_pending (self , x ) :
240
+ def tell_pending (self , x : Tuple [ int , Any ]) -> None :
228
241
index , x = x
229
242
self ._ask_cache .pop (index , None )
230
243
self ._loss .pop (index , None )
231
244
self .learners [index ].tell_pending (x )
232
245
233
- def _losses (self , real = True ):
246
+ def _losses (self , real : bool = True ) -> List [ float ] :
234
247
losses = []
235
248
loss_dict = self ._loss if real else self ._pending_loss
236
249
@@ -242,7 +255,7 @@ def _losses(self, real=True):
242
255
return losses
243
256
244
257
@cache_latest
245
- def loss (self , real = True ):
258
+ def loss (self , real : bool = True ) -> Union [ float ] :
246
259
losses = self ._losses (real )
247
260
return max (losses )
248
261
@@ -325,7 +338,9 @@ def remove_unfinished(self):
325
338
learner .remove_unfinished ()
326
339
327
340
@classmethod
328
- def from_product (cls , f , learner_type , learner_kwargs , combos ):
341
+ def from_product (
342
+ cls , f , learner_type , learner_kwargs , combos
343
+ ) -> "BalancingLearner" :
329
344
"""Create a `BalancingLearner` with learners of all combinations of
330
345
named variables’ values. The `cdims` will be set correctly, so calling
331
346
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -372,7 +387,7 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
372
387
learners .append (learner )
373
388
return cls (learners , cdims = arguments )
374
389
375
- def save (self , fname , compress = True ):
390
+ def save (self , fname : Callable , compress : bool = True ) -> None :
376
391
"""Save the data of the child learners into pickle files
377
392
in a directory.
378
393
@@ -410,7 +425,7 @@ def save(self, fname, compress=True):
410
425
for l in self .learners :
411
426
l .save (fname (l ), compress = compress )
412
427
413
- def load (self , fname , compress = True ):
428
+ def load (self , fname : Callable , compress : bool = True ) -> None :
414
429
"""Load the data of the child learners from pickle files
415
430
in a directory.
416
431
0 commit comments