1
+ import collections .abc
1
2
import itertools
2
3
import math
3
- from collections . abc import Iterable
4
+ import numbers
4
5
from copy import deepcopy
6
+ from typing import (
7
+ Any ,
8
+ Callable ,
9
+ Dict ,
10
+ Iterable ,
11
+ List ,
12
+ Literal ,
13
+ Optional ,
14
+ Sequence ,
15
+ Tuple ,
16
+ Union ,
17
+ )
5
18
6
19
import cloudpickle
7
20
import numpy as np
8
- import sortedcollections
9
- import sortedcontainers
21
+ from sortedcollections . recipes import ItemSortedDict
22
+ from sortedcontainers . sorteddict import SortedDict
10
23
11
24
from adaptive .learner .base_learner import BaseLearner , uses_nth_neighbors
12
25
from adaptive .learner .learnerND import volume
13
26
from adaptive .learner .triangulation import simplex_volume_in_embedding
14
27
from adaptive .notebook_integration import ensure_holoviews
28
+ from adaptive .types import Float
15
29
from adaptive .utils import cache_latest
16
30
31
+ Point = Tuple [Float , Float ]
32
+
17
33
18
34
@uses_nth_neighbors (0 )
19
- def uniform_loss (xs , ys ) :
35
+ def uniform_loss (xs : Point , ys : Any ) -> Float :
20
36
"""Loss function that samples the domain uniformly.
21
37
22
38
Works with `~adaptive.Learner1D` only.
@@ -36,17 +52,20 @@ def uniform_loss(xs, ys):
36
52
37
53
38
54
@uses_nth_neighbors (0 )
39
- def default_loss (xs , ys ):
55
+ def default_loss (
56
+ xs : Point ,
57
+ ys : Union [Tuple [Iterable [Float ], Iterable [Float ]], Point ],
58
+ ) -> float :
40
59
"""Calculate loss on a single interval.
41
60
42
61
Currently returns the rescaled length of the interval. If one of the
43
62
y-values is missing, returns 0 (so the intervals with missing data are
44
63
never touched. This behavior should be improved later.
45
64
"""
46
65
dx = xs [1 ] - xs [0 ]
47
- if isinstance (ys [0 ], Iterable ):
48
- dy = [abs (a - b ) for a , b in zip (* ys )]
49
- return np .hypot (dx , dy ).max ()
66
+ if isinstance (ys [0 ], collections . abc . Iterable ):
67
+ dy_vec = [abs (a - b ) for a , b in zip (* ys )]
68
+ return np .hypot (dx , dy_vec ).max ()
50
69
else :
51
70
dy = ys [1 ] - ys [0 ]
52
71
return np .hypot (dx , dy )
@@ -60,15 +79,21 @@ def abs_min_log_loss(xs, ys):
60
79
61
80
62
81
@uses_nth_neighbors (1 )
63
- def triangle_loss (xs , ys ):
82
+ def triangle_loss (
83
+ xs : Sequence [Optional [Float ]],
84
+ ys : Union [
85
+ Iterable [Optional [Float ]],
86
+ Iterable [Union [Iterable [Float ], None ]],
87
+ ],
88
+ ) -> float :
64
89
xs = [x for x in xs if x is not None ]
65
90
ys = [y for y in ys if y is not None ]
66
91
67
92
if len (xs ) == 2 : # we do not have enough points for a triangle
68
93
return xs [1 ] - xs [0 ]
69
94
70
95
N = len (xs ) - 2 # number of constructed triangles
71
- if isinstance (ys [0 ], Iterable ):
96
+ if isinstance (ys [0 ], collections . abc . Iterable ):
72
97
pts = [(x , * y ) for x , y in zip (xs , ys )]
73
98
vol = simplex_volume_in_embedding
74
99
else :
@@ -114,7 +139,9 @@ def resolution_loss(xs, ys):
114
139
return resolution_loss
115
140
116
141
117
- def curvature_loss_function (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
142
+ def curvature_loss_function (
143
+ area_factor : float = 1 , euclid_factor : float = 0.02 , horizontal_factor : float = 0.02
144
+ ) -> Callable :
118
145
# XXX: add a doc-string
119
146
@uses_nth_neighbors (1 )
120
147
def curvature_loss (xs , ys ):
@@ -133,7 +160,7 @@ def curvature_loss(xs, ys):
133
160
return curvature_loss
134
161
135
162
136
- def linspace (x_left , x_right , n ) :
163
+ def linspace (x_left : float , x_right : float , n : int ) -> List [ float ] :
137
164
"""This is equivalent to
138
165
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
139
166
but it is 15-30 times faster for small 'n'."""
@@ -145,17 +172,17 @@ def linspace(x_left, x_right, n):
145
172
return [x_left + step * i for i in range (1 , n )]
146
173
147
174
148
- def _get_neighbors_from_list (xs ) :
175
+ def _get_neighbors_from_list (xs : np . ndarray ) -> SortedDict :
149
176
xs = np .sort (xs )
150
177
xs_left = np .roll (xs , 1 ).tolist ()
151
178
xs_right = np .roll (xs , - 1 ).tolist ()
152
179
xs_left [0 ] = None
153
180
xs_right [- 1 ] = None
154
181
neighbors = {x : [x_L , x_R ] for x , x_L , x_R in zip (xs , xs_left , xs_right )}
155
- return sortedcontainers . SortedDict (neighbors )
182
+ return SortedDict (neighbors )
156
183
157
184
158
- def _get_intervals (x , neighbors , nth_neighbors ) :
185
+ def _get_intervals (x : float , neighbors : SortedDict , nth_neighbors : int ) -> Any :
159
186
nn = nth_neighbors
160
187
i = neighbors .index (x )
161
188
start = max (0 , i - nn - 1 )
@@ -208,8 +235,13 @@ class Learner1D(BaseLearner):
208
235
decorator for more information.
209
236
"""
210
237
211
- def __init__ (self , function , bounds , loss_per_interval = None ):
212
- self .function = function
238
+ def __init__ (
239
+ self ,
240
+ function : Callable ,
241
+ bounds : Tuple [float , float ],
242
+ loss_per_interval : Optional [Callable ] = None ,
243
+ ) -> None :
244
+ self .function = function # type: ignore
213
245
214
246
if hasattr (loss_per_interval , "nth_neighbors" ):
215
247
self .nth_neighbors = loss_per_interval .nth_neighbors
@@ -228,8 +260,8 @@ def __init__(self, function, bounds, loss_per_interval=None):
228
260
229
261
# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
230
262
# properties.
231
- self .neighbors = sortedcontainers . SortedDict ()
232
- self .neighbors_combined = sortedcontainers . SortedDict ()
263
+ self .neighbors = SortedDict ()
264
+ self .neighbors_combined = SortedDict ()
233
265
234
266
# Bounding box [[minx, maxx], [miny, maxy]].
235
267
self ._bbox = [list (bounds ), [np .inf , - np .inf ]]
@@ -247,10 +279,10 @@ def __init__(self, function, bounds, loss_per_interval=None):
247
279
248
280
self .bounds = list (bounds )
249
281
250
- self ._vdim = None
282
+ self ._vdim : Optional [ int ] = None
251
283
252
284
@property
253
- def vdim (self ):
285
+ def vdim (self ) -> int :
254
286
"""Length of the output of ``learner.function``.
255
287
If the output is unsized (when it's a scalar)
256
288
then `vdim = 1`.
@@ -275,35 +307,37 @@ def to_numpy(self):
275
307
return np .array ([(x , * np .atleast_1d (y )) for x , y in sorted (self .data .items ())])
276
308
277
309
@property
278
- def npoints (self ):
310
+ def npoints (self ) -> int :
279
311
"""Number of evaluated points."""
280
312
return len (self .data )
281
313
282
314
@cache_latest
283
- def loss (self , real = True ):
315
+ def loss (self , real : bool = True ) -> float :
284
316
losses = self .losses if real else self .losses_combined
285
317
if not losses :
286
318
return np .inf
287
319
max_interval , max_loss = losses .peekitem (0 )
288
320
return max_loss
289
321
290
- def _scale_x (self , x ) :
322
+ def _scale_x (self , x : Optional [ float ]) -> Optional [ float ] :
291
323
if x is None :
292
324
return None
293
325
return x / self ._scale [0 ]
294
326
295
- def _scale_y (self , y ):
327
+ def _scale_y (
328
+ self , y : Optional [Union [Float , np .ndarray ]]
329
+ ) -> Optional [Union [Float , np .ndarray ]]:
296
330
if y is None :
297
331
return None
298
332
y_scale = self ._scale [1 ] or 1
299
333
return y / y_scale
300
334
301
- def _get_point_by_index (self , ind ) :
335
+ def _get_point_by_index (self , ind : int ) -> Optional [ float ] :
302
336
if ind < 0 or ind >= len (self .neighbors ):
303
337
return None
304
338
return self .neighbors .keys ()[ind ]
305
339
306
- def _get_loss_in_interval (self , x_left , x_right ) :
340
+ def _get_loss_in_interval (self , x_left : float , x_right : float ) -> float :
307
341
assert x_left is not None and x_right is not None
308
342
309
343
if x_right - x_left < self ._dx_eps :
@@ -323,7 +357,9 @@ def _get_loss_in_interval(self, x_left, x_right):
323
357
# we need to compute the loss for this interval
324
358
return self .loss_per_interval (xs_scaled , ys_scaled )
325
359
326
- def _update_interpolated_loss_in_interval (self , x_left , x_right ):
360
+ def _update_interpolated_loss_in_interval (
361
+ self , x_left : float , x_right : float
362
+ ) -> None :
327
363
if x_left is None or x_right is None :
328
364
return
329
365
@@ -339,7 +375,7 @@ def _update_interpolated_loss_in_interval(self, x_left, x_right):
339
375
self .losses_combined [a , b ] = (b - a ) * loss / dx
340
376
a = b
341
377
342
- def _update_losses (self , x , real = True ):
378
+ def _update_losses (self , x : float , real : bool = True ) -> None :
343
379
"""Update all losses that depend on x"""
344
380
# When we add a new point x, we should update the losses
345
381
# (x_left, x_right) are the "real" neighbors of 'x'.
@@ -382,7 +418,7 @@ def _update_losses(self, x, real=True):
382
418
self .losses_combined [x , b ] = float ("inf" )
383
419
384
420
@staticmethod
385
- def _find_neighbors (x , neighbors ) :
421
+ def _find_neighbors (x : float , neighbors : SortedDict ) -> Any :
386
422
if x in neighbors :
387
423
return neighbors [x ]
388
424
pos = neighbors .bisect_left (x )
@@ -391,14 +427,14 @@ def _find_neighbors(x, neighbors):
391
427
x_right = keys [pos ] if pos != len (neighbors ) else None
392
428
return x_left , x_right
393
429
394
- def _update_neighbors (self , x , neighbors ) :
430
+ def _update_neighbors (self , x : float , neighbors : SortedDict ) -> None :
395
431
if x not in neighbors : # The point is new
396
432
x_left , x_right = self ._find_neighbors (x , neighbors )
397
433
neighbors [x ] = [x_left , x_right ]
398
434
neighbors .get (x_left , [None , None ])[1 ] = x
399
435
neighbors .get (x_right , [None , None ])[0 ] = x
400
436
401
- def _update_scale (self , x , y ) :
437
+ def _update_scale (self , x : float , y : Union [ Float , np . ndarray ]) -> None :
402
438
"""Update the scale with which the x and y-values are scaled.
403
439
404
440
For a learner where the function returns a single scalar the scale
@@ -425,7 +461,9 @@ def _update_scale(self, x, y):
425
461
self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
426
462
self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
427
463
428
- def tell (self , x , y ):
464
+ def tell (
465
+ self , x : float , y : Union [Float , Sequence [numbers .Number ], np .ndarray ]
466
+ ) -> None :
429
467
if x in self .data :
430
468
# The point is already evaluated before
431
469
return
@@ -460,15 +498,15 @@ def tell(self, x, y):
460
498
461
499
self ._oldscale = deepcopy (self ._scale )
462
500
463
- def tell_pending (self , x ) :
501
+ def tell_pending (self , x : float ) -> None :
464
502
if x in self .data :
465
503
# The point is already evaluated before
466
504
return
467
505
self .pending_points .add (x )
468
506
self ._update_neighbors (x , self .neighbors_combined )
469
507
self ._update_losses (x , real = False )
470
508
471
- def tell_many (self , xs , ys , * , force = False ):
509
+ def tell_many (self , xs : Sequence [ float ] , ys : Sequence [ Any ] , * , force = False ) -> None :
472
510
if not force and not (len (xs ) > 0.5 * len (self .data ) and len (xs ) > 2 ):
473
511
# Only run this more efficient method if there are
474
512
# at least 2 points and the amount of points added are
@@ -536,7 +574,7 @@ def tell_many(self, xs, ys, *, force=False):
536
574
# have an inf loss.
537
575
self ._update_interpolated_loss_in_interval (* ival )
538
576
539
- def ask (self , n , tell_pending = True ):
577
+ def ask (self , n : int , tell_pending : bool = True ) -> Any :
540
578
"""Return 'n' points that are expected to maximally reduce the loss."""
541
579
points , loss_improvements = self ._ask_points_without_adding (n )
542
580
@@ -546,7 +584,7 @@ def ask(self, n, tell_pending=True):
546
584
547
585
return points , loss_improvements
548
586
549
- def _ask_points_without_adding (self , n ) :
587
+ def _ask_points_without_adding (self , n : int ) -> Any :
550
588
"""Return 'n' points that are expected to maximally reduce the loss.
551
589
Without altering the state of the learner"""
552
590
# Find out how to divide the n points over the intervals
@@ -573,7 +611,8 @@ def _ask_points_without_adding(self, n):
573
611
# Add bound intervals to quals if bounds were missing.
574
612
if len (self .data ) + len (self .pending_points ) == 0 :
575
613
# We don't have any points, so return a linspace with 'n' points.
576
- return np .linspace (* self .bounds , n ).tolist (), [np .inf ] * n
614
+ a , b = self .bounds
615
+ return np .linspace (a , b , n ).tolist (), [np .inf ] * n
577
616
578
617
quals = loss_manager (self ._scale [0 ])
579
618
if len (missing_bounds ) > 0 :
@@ -609,7 +648,7 @@ def _ask_points_without_adding(self, n):
609
648
quals [(* xs , n + 1 )] = loss_qual * n / (n + 1 )
610
649
611
650
points = list (
612
- itertools .chain .from_iterable (linspace (* ival , n ) for (* ival , n ) in quals )
651
+ itertools .chain .from_iterable (linspace (a , b , n ) for (( a , b ) , n ) in quals )
613
652
)
614
653
615
654
loss_improvements = list (
@@ -624,11 +663,11 @@ def _ask_points_without_adding(self, n):
624
663
625
664
return points , loss_improvements
626
665
627
- def _loss (self , mapping , ival ) :
666
+ def _loss (self , mapping : ItemSortedDict , ival : Any ) -> Any :
628
667
loss = mapping [ival ]
629
668
return finite_loss (ival , loss , self ._scale [0 ])
630
669
631
- def plot (self , * , scatter_or_line = "scatter" ):
670
+ def plot (self , * , scatter_or_line : Literal [ "scatter" , "line" ] = "scatter" ):
632
671
"""Returns a plot of the evaluated data.
633
672
634
673
Parameters
@@ -663,17 +702,18 @@ def plot(self, *, scatter_or_line="scatter"):
663
702
664
703
return p .redim (x = dict (range = plot_bounds ))
665
704
666
- def remove_unfinished (self ):
705
+ def remove_unfinished (self ) -> None :
667
706
self .pending_points = set ()
668
707
self .losses_combined = deepcopy (self .losses )
669
708
self .neighbors_combined = deepcopy (self .neighbors )
670
709
671
- def _get_data (self ):
710
+ def _get_data (self ) -> Dict [ float , float ] :
672
711
return self .data
673
712
674
- def _set_data (self , data ) :
713
+ def _set_data (self , data : Dict [ float , float ]) -> None :
675
714
if data :
676
- self .tell_many (* zip (* data .items ()))
715
+ xs , ys = zip (* data .items ())
716
+ self .tell_many (xs , ys )
677
717
678
718
def __getstate__ (self ):
679
719
return (
@@ -694,16 +734,16 @@ def __setstate__(self, state):
694
734
self .losses_combined .update (losses_combined )
695
735
696
736
697
- def loss_manager (x_scale ) :
737
+ def loss_manager (x_scale : float ) -> ItemSortedDict :
698
738
def sort_key (ival , loss ):
699
739
loss , ival = finite_loss (ival , loss , x_scale )
700
740
return - loss , ival
701
741
702
- sorted_dict = sortedcollections . ItemSortedDict (sort_key )
742
+ sorted_dict = ItemSortedDict (sort_key )
703
743
return sorted_dict
704
744
705
745
706
- def finite_loss (ival , loss , x_scale ) :
746
+ def finite_loss (ival : Any , loss : float , x_scale : float ) -> Any :
707
747
"""Get the socalled finite_loss of an interval in order to be able to
708
748
sort intervals that have infinite loss."""
709
749
# If the loss is infinite we return the
0 commit comments