import collections.abc
import itertools
import math
- import numbers
from copy import deepcopy
- from typing import (
-     Any,
-     Callable,
-     Dict,
-     Iterable,
-     List,
-     Literal,
-     Optional,
-     Sequence,
-     Tuple,
-     Union,
- )
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union

import cloudpickle
import numpy as np
from adaptive.learner.learnerND import volume
from adaptive.learner.triangulation import simplex_volume_in_embedding
from adaptive.notebook_integration import ensure_holoviews
- from adaptive.types import Float
+ from adaptive.types import Float, Int, Real
from adaptive.utils import cache_latest

- Point = Tuple[Float, Float]
+ # -- types --
+
+ # Commonly used types
+ Interval = Union[Tuple[float, float], Tuple[float, float, int]]
+ NeighborsType = Dict[float, List[Optional[float]]]
+
+ # Types for loss_per_interval functions
+ NoneFloat = Union[Float, None]
+ NoneArray = Union[np.ndarray, None]
+ XsType0 = Tuple[Float, Float]
+ YsType0 = Union[Tuple[Float, Float], Tuple[np.ndarray, np.ndarray]]
+ XsType1 = Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat]
+ YsType1 = Union[
+     Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat],
+     Tuple[NoneArray, NoneArray, NoneArray, NoneArray],
+ ]
+ XsTypeN = Tuple[NoneFloat, ...]
+ YsTypeN = Union[Tuple[NoneFloat, ...], Tuple[NoneArray, ...]]
+
+
+ __all__ = [
+     "uniform_loss",
+     "default_loss",
+     "abs_min_log_loss",
+     "triangle_loss",
+     "resolution_loss_function",
+     "curvature_loss_function",
+     "Learner1D",
+ ]


@uses_nth_neighbors(0)
- def uniform_loss(xs: Point, ys: Any) -> Float:
+ def uniform_loss(xs: XsType0, ys: YsType0) -> Float:
    """Loss function that samples the domain uniformly.

    Works with `~adaptive.Learner1D` only.
@@ -52,10 +68,7 @@ def uniform_loss(xs: Point, ys: Any) -> Float:


@uses_nth_neighbors(0)
- def default_loss(
-     xs: Point,
-     ys: Union[Tuple[Iterable[Float], Iterable[Float]], Point],
- ) -> float:
+ def default_loss(xs: XsType0, ys: YsType0) -> Float:
    """Calculate loss on a single interval.

    Currently returns the rescaled length of the interval. If one of the
@@ -64,28 +77,23 @@ def default_loss(
    """
    dx = xs[1] - xs[0]
    if isinstance(ys[0], collections.abc.Iterable):
-         dy_vec = [abs(a - b) for a, b in zip(*ys)]
+         dy_vec = np.array([abs(a - b) for a, b in zip(*ys)])
        return np.hypot(dx, dy_vec).max()
    else:
        dy = ys[1] - ys[0]
        return np.hypot(dx, dy)


@uses_nth_neighbors(0)
- def abs_min_log_loss(xs, ys):
+ def abs_min_log_loss(xs: XsType0, ys: YsType0) -> Float:
    """Calculate loss of a single interval that prioritizes the absolute minimum."""
-     ys = [np.log(np.abs(y).min()) for y in ys]
+     ys = tuple(np.log(np.abs(y).min()) for y in ys)
    return default_loss(xs, ys)


@uses_nth_neighbors(1)
- def triangle_loss(
-     xs: Sequence[Optional[Float]],
-     ys: Union[
-         Iterable[Optional[Float]],
-         Iterable[Union[Iterable[Float], None]],
-     ],
- ) -> float:
+ def triangle_loss(xs: XsType1, ys: YsType1) -> Float:
+     assert len(xs) == 4
    xs = [x for x in xs if x is not None]
    ys = [y for y in ys if y is not None]

@@ -102,7 +110,9 @@ def triangle_loss(
    return sum(vol(pts[i : i + 3]) for i in range(N)) / N


- def resolution_loss_function(min_length=0, max_length=1):
+ def resolution_loss_function(
+     min_length: Real = 0, max_length: Real = 1
+ ) -> Callable[[XsType0, YsType0], Float]:
    """Loss function that is similar to the `default_loss` function, but you
    can set the maximum and minimum size of an interval.

@@ -125,7 +135,7 @@ def resolution_loss_function(min_length=0, max_length=1):
    """

    @uses_nth_neighbors(0)
-     def resolution_loss(xs, ys):
+     def resolution_loss(xs: XsType0, ys: YsType0) -> Float:
        loss = uniform_loss(xs, ys)
        if loss < min_length:
            # Return zero such that this interval won't be chosen again
@@ -140,11 +150,11 @@ def resolution_loss(xs, ys):


def curvature_loss_function(
-     area_factor: float = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
- ) -> Callable:
+     area_factor: Real = 1, euclid_factor: Real = 0.02, horizontal_factor: Real = 0.02
+ ) -> Callable[[XsType1, YsType1], Float]:
    # XXX: add a doc-string
    @uses_nth_neighbors(1)
-     def curvature_loss(xs, ys):
+     def curvature_loss(xs: XsType1, ys: YsType1) -> Float:
        xs_middle = xs[1:3]
        ys_middle = ys[1:3]

@@ -160,7 +170,7 @@ def curvature_loss(xs, ys):
    return curvature_loss


- def linspace(x_left: float, x_right: float, n: int) -> List[float]:
+ def linspace(x_left: Real, x_right: Real, n: Int) -> List[Float]:
    """This is equivalent to
    'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
    but it is 15-30 times faster for small 'n'."""
@@ -172,7 +182,7 @@ def linspace(x_left: float, x_right: float, n: int) -> List[float]:
    return [x_left + step * i for i in range(1, n)]


- def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
+ def _get_neighbors_from_array(xs: np.ndarray) -> NeighborsType:
    xs = np.sort(xs)
    xs_left = np.roll(xs, 1).tolist()
    xs_right = np.roll(xs, -1).tolist()
@@ -182,7 +192,9 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
    return SortedDict(neighbors)


- def _get_intervals(x: float, neighbors: SortedDict, nth_neighbors: int) -> Any:
+ def _get_intervals(
+     x: float, neighbors: NeighborsType, nth_neighbors: int
+ ) -> List[Tuple[float, float]]:
    nn = nth_neighbors
    i = neighbors.index(x)
    start = max(0, i - nn - 1)
@@ -237,10 +249,10 @@ class Learner1D(BaseLearner):

    def __init__(
        self,
-         function: Callable,
-         bounds: Tuple[float, float],
-         loss_per_interval: Optional[Callable] = None,
-     ) -> None:
+         function: Callable[[Real], Union[Float, np.ndarray]],
+         bounds: Tuple[Real, Real],
+         loss_per_interval: Optional[Callable[[XsTypeN, YsTypeN], Float]] = None,
+     ):
        self.function = function  # type: ignore

        if hasattr(loss_per_interval, "nth_neighbors"):
@@ -255,13 +267,13 @@ def __init__(
        # the learners behavior in the tests.
        self._recompute_losses_factor = 2

-         self.data = {}
-         self.pending_points = set()
+         self.data: Dict[Real, Real] = {}
+         self.pending_points: Set[Real] = set()

        # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
        # properties.
-         self.neighbors = SortedDict()
-         self.neighbors_combined = SortedDict()
+         self.neighbors: NeighborsType = SortedDict()
+         self.neighbors_combined: NeighborsType = SortedDict()

        # Bounding box [[minx, maxx], [miny, maxy]].
        self._bbox = [list(bounds), [np.inf, -np.inf]]
@@ -319,14 +331,14 @@ def loss(self, real: bool = True) -> float:
        max_interval, max_loss = losses.peekitem(0)
        return max_loss

-     def _scale_x(self, x: Optional[float]) -> Optional[float]:
+     def _scale_x(self, x: Optional[Float]) -> Optional[Float]:
        if x is None:
            return None
        return x / self._scale[0]

    def _scale_y(
-         self, y: Optional[Union[Float, np.ndarray]]
-     ) -> Optional[Union[Float, np.ndarray]]:
+         self, y: Union[Float, np.ndarray, None]
+     ) -> Union[Float, np.ndarray, None]:
        if y is None:
            return None
        y_scale = self._scale[1] or 1
@@ -418,7 +430,7 @@ def _update_losses(self, x: float, real: bool = True) -> None:
            self.losses_combined[x, b] = float("inf")

    @staticmethod
-     def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
+     def _find_neighbors(x: float, neighbors: NeighborsType) -> Any:
        if x in neighbors:
            return neighbors[x]
        pos = neighbors.bisect_left(x)
@@ -427,7 +439,7 @@ def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
        x_right = keys[pos] if pos != len(neighbors) else None
        return x_left, x_right

-     def _update_neighbors(self, x: float, neighbors: SortedDict) -> None:
+     def _update_neighbors(self, x: float, neighbors: NeighborsType) -> None:
        if x not in neighbors:  # The point is new
            x_left, x_right = self._find_neighbors(x, neighbors)
            neighbors[x] = [x_left, x_right]
@@ -461,9 +473,7 @@ def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
                self._bbox[1][1] = max(self._bbox[1][1], y)
                self._scale[1] = self._bbox[1][1] - self._bbox[1][0]

-     def tell(
-         self, x: float, y: Union[Float, Sequence[numbers.Number], np.ndarray]
-     ) -> None:
+     def tell(self, x: float, y: Union[Float, Sequence[Float], np.ndarray]) -> None:
        if x in self.data:
            # The point is already evaluated before
            return
@@ -506,7 +516,17 @@ def tell_pending(self, x: float) -> None:
        self._update_neighbors(x, self.neighbors_combined)
        self._update_losses(x, real=False)

-     def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> None:
+     def tell_many(
+         self,
+         xs: Sequence[Float],
+         ys: Union[
+             Sequence[Float],
+             Sequence[Sequence[Float]],
+             Sequence[np.ndarray],
+         ],
+         *,
+         force: bool = False
+     ) -> None:
        if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
            # Only run this more efficient method if there are
            # at least 2 points and the amount of points added are
@@ -526,8 +546,8 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
        points_combined = np.hstack([points_pending, points])

        # Generate neighbors
-         self.neighbors = _get_neighbors_from_list(points)
-         self.neighbors_combined = _get_neighbors_from_list(points_combined)
+         self.neighbors = _get_neighbors_from_array(points)
+         self.neighbors_combined = _get_neighbors_from_array(points_combined)

        # Update scale
        self._bbox[0] = [points_combined.min(), points_combined.max()]
@@ -574,7 +594,7 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
            # have an inf loss.
            self._update_interpolated_loss_in_interval(*ival)

-     def ask(self, n: int, tell_pending: bool = True) -> Any:
+     def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[float]]:
        """Return 'n' points that are expected to maximally reduce the loss."""
        points, loss_improvements = self._ask_points_without_adding(n)

@@ -584,7 +604,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:

        return points, loss_improvements

-     def _ask_points_without_adding(self, n: int) -> Any:
+     def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
        """Return 'n' points that are expected to maximally reduce the loss.
        Without altering the state of the learner"""
        # Find out how to divide the n points over the intervals
@@ -648,7 +668,7 @@ def _ask_points_without_adding(self, n: int) -> Any:
            quals[(*xs, n + 1)] = loss_qual * n / (n + 1)

        points = list(
-             itertools.chain.from_iterable(linspace(a, b, n) for ((a, b), n) in quals)
+             itertools.chain.from_iterable(linspace(*ival, n) for (*ival, n) in quals)
        )

        loss_improvements = list(
@@ -663,11 +683,13 @@ def _ask_points_without_adding(self, n: int) -> Any:

        return points, loss_improvements

-     def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
+     def _loss(
+         self, mapping: Dict[Interval, float], ival: Interval
+     ) -> Tuple[float, Interval]:
        loss = mapping[ival]
        return finite_loss(ival, loss, self._scale[0])

-     def plot(self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"):
+     def plot(self, *, scatter_or_line: str = "scatter"):
        """Returns a plot of the evaluated data.

        Parameters
@@ -734,7 +756,7 @@ def __setstate__(self, state):
        self.losses_combined.update(losses_combined)


- def loss_manager(x_scale: float) -> ItemSortedDict:
+ def loss_manager(x_scale: float) -> Dict[Interval, float]:
    def sort_key(ival, loss):
        loss, ival = finite_loss(ival, loss, x_scale)
        return -loss, ival
@@ -743,8 +765,8 @@ def sort_key(ival, loss):
    return sorted_dict


- def finite_loss(ival: Any, loss: float, x_scale: float) -> Any:
-     """Get the socalled finite_loss of an interval in order to be able to
+ def finite_loss(ival: Interval, loss: float, x_scale: float) -> Tuple[float, Interval]:
+     """Get the so-called finite_loss of an interval in order to be able to
    sort intervals that have infinite loss."""
    # If the loss is infinite we return the
    # distance between the two points.
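
A minimal usage sketch for context (not part of the diff above): it shows where the `loss_per_interval` argument, now annotated as `Callable[[XsTypeN, YsTypeN], Float]`, plugs into `Learner1D`. The toy function `f` and the loss tolerance are illustrative assumptions; the adaptive calls follow the library's documented public API.

import numpy as np
import adaptive
from adaptive.learner.learner1D import curvature_loss_function

def f(x):
    # Toy function with a sharp feature; purely illustrative.
    return np.tanh(50 * x) + 0.1 * x

# curvature_loss_function() is decorated with @uses_nth_neighbors(1) in the
# diff above, so the learner hands it 4-point stencils (XsType1 / YsType1).
learner = adaptive.Learner1D(
    f, bounds=(-1, 1), loss_per_interval=curvature_loss_function()
)

# Run synchronously until the largest interval loss drops below 0.01.
adaptive.runner.simple(learner, goal=lambda l: l.loss() < 0.01)
print(len(learner.data), "points sampled")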