Skip to content

Commit 3d359c3

Browse files
committed
improve typing
1 parent 9bfbeb8 commit 3d359c3

File tree

1 file changed

+87
-65
lines changed

1 file changed

+87
-65
lines changed

adaptive/learner/learner1D.py

+87-65
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,8 @@
11
import collections.abc
22
import itertools
33
import math
4-
import numbers
54
from copy import deepcopy
6-
from typing import (
7-
Any,
8-
Callable,
9-
Dict,
10-
Iterable,
11-
List,
12-
Literal,
13-
Optional,
14-
Sequence,
15-
Tuple,
16-
Union,
17-
)
5+
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
186

197
import cloudpickle
208
import numpy as np
@@ -25,14 +13,42 @@
2513
from adaptive.learner.learnerND import volume
2614
from adaptive.learner.triangulation import simplex_volume_in_embedding
2715
from adaptive.notebook_integration import ensure_holoviews
28-
from adaptive.types import Float
16+
from adaptive.types import Float, Int, Real
2917
from adaptive.utils import cache_latest
3018

31-
Point = Tuple[Float, Float]
19+
# -- types --
20+
21+
# Commonly used types
22+
Interval = Union[Tuple[float, float], Tuple[float, float, int]]
23+
NeighborsType = Dict[float, List[Optional[float]]]
24+
25+
# Types for loss_per_interval functions
26+
NoneFloat = Union[Float, None]
27+
NoneArray = Union[np.ndarray, None]
28+
XsType0 = Tuple[Float, Float]
29+
YsType0 = Union[Tuple[Float, Float], Tuple[np.ndarray, np.ndarray]]
30+
XsType1 = Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat]
31+
YsType1 = Union[
32+
Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat],
33+
Tuple[NoneArray, NoneArray, NoneArray, NoneArray],
34+
]
35+
XsTypeN = Tuple[NoneFloat, ...]
36+
YsTypeN = Union[Tuple[NoneFloat, ...], Tuple[NoneArray, ...]]
37+
38+
39+
__all__ = [
40+
"uniform_loss",
41+
"default_loss",
42+
"abs_min_log_loss",
43+
"triangle_loss",
44+
"resolution_loss_function",
45+
"curvature_loss_function",
46+
"Learner1D",
47+
]
3248

3349

3450
@uses_nth_neighbors(0)
35-
def uniform_loss(xs: Point, ys: Any) -> Float:
51+
def uniform_loss(xs: XsType0, ys: YsType0) -> Float:
3652
"""Loss function that samples the domain uniformly.
3753
3854
Works with `~adaptive.Learner1D` only.
@@ -52,10 +68,7 @@ def uniform_loss(xs: Point, ys: Any) -> Float:
5268

5369

5470
@uses_nth_neighbors(0)
55-
def default_loss(
56-
xs: Point,
57-
ys: Union[Tuple[Iterable[Float], Iterable[Float]], Point],
58-
) -> float:
71+
def default_loss(xs: XsType0, ys: YsType0) -> Float:
5972
"""Calculate loss on a single interval.
6073
6174
Currently returns the rescaled length of the interval. If one of the
@@ -64,28 +77,23 @@ def default_loss(
6477
"""
6578
dx = xs[1] - xs[0]
6679
if isinstance(ys[0], collections.abc.Iterable):
67-
dy_vec = [abs(a - b) for a, b in zip(*ys)]
80+
dy_vec = np.array([abs(a - b) for a, b in zip(*ys)])
6881
return np.hypot(dx, dy_vec).max()
6982
else:
7083
dy = ys[1] - ys[0]
7184
return np.hypot(dx, dy)
7285

7386

7487
@uses_nth_neighbors(0)
75-
def abs_min_log_loss(xs, ys):
88+
def abs_min_log_loss(xs: XsType0, ys: YsType0) -> Float:
7689
"""Calculate loss of a single interval that prioritizes the absolute minimum."""
77-
ys = [np.log(np.abs(y).min()) for y in ys]
90+
ys = tuple(np.log(np.abs(y).min()) for y in ys)
7891
return default_loss(xs, ys)
7992

8093

8194
@uses_nth_neighbors(1)
82-
def triangle_loss(
83-
xs: Sequence[Optional[Float]],
84-
ys: Union[
85-
Iterable[Optional[Float]],
86-
Iterable[Union[Iterable[Float], None]],
87-
],
88-
) -> float:
95+
def triangle_loss(xs: XsType1, ys: YsType1) -> Float:
96+
assert len(xs) == 4
8997
xs = [x for x in xs if x is not None]
9098
ys = [y for y in ys if y is not None]
9199

@@ -102,7 +110,9 @@ def triangle_loss(
102110
return sum(vol(pts[i : i + 3]) for i in range(N)) / N
103111

104112

105-
def resolution_loss_function(min_length=0, max_length=1):
113+
def resolution_loss_function(
114+
min_length: Real = 0, max_length: Real = 1
115+
) -> Callable[[XsType0, YsType0], Float]:
106116
"""Loss function that is similar to the `default_loss` function, but you
107117
can set the maximum and minimum size of an interval.
108118
@@ -125,7 +135,7 @@ def resolution_loss_function(min_length=0, max_length=1):
125135
"""
126136

127137
@uses_nth_neighbors(0)
128-
def resolution_loss(xs, ys):
138+
def resolution_loss(xs: XsType0, ys: YsType0) -> Float:
129139
loss = uniform_loss(xs, ys)
130140
if loss < min_length:
131141
# Return zero such that this interval won't be chosen again
@@ -140,11 +150,11 @@ def resolution_loss(xs, ys):
140150

141151

142152
def curvature_loss_function(
143-
area_factor: float = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
144-
) -> Callable:
153+
area_factor: Real = 1, euclid_factor: Real = 0.02, horizontal_factor: Real = 0.02
154+
) -> Callable[[XsType1, YsType1], Float]:
145155
# XXX: add a doc-string
146156
@uses_nth_neighbors(1)
147-
def curvature_loss(xs, ys):
157+
def curvature_loss(xs: XsType1, ys: YsType1) -> Float:
148158
xs_middle = xs[1:3]
149159
ys_middle = ys[1:3]
150160

@@ -160,7 +170,7 @@ def curvature_loss(xs, ys):
160170
return curvature_loss
161171

162172

163-
def linspace(x_left: float, x_right: float, n: int) -> List[float]:
173+
def linspace(x_left: Real, x_right: Real, n: Int) -> List[Float]:
164174
"""This is equivalent to
165175
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
166176
but it is 15-30 times faster for small 'n'."""
@@ -172,7 +182,7 @@ def linspace(x_left: float, x_right: float, n: int) -> List[float]:
172182
return [x_left + step * i for i in range(1, n)]
173183

174184

175-
def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
185+
def _get_neighbors_from_array(xs: np.ndarray) -> NeighborsType:
176186
xs = np.sort(xs)
177187
xs_left = np.roll(xs, 1).tolist()
178188
xs_right = np.roll(xs, -1).tolist()
@@ -182,7 +192,9 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
182192
return SortedDict(neighbors)
183193

184194

185-
def _get_intervals(x: float, neighbors: SortedDict, nth_neighbors: int) -> Any:
195+
def _get_intervals(
196+
x: float, neighbors: NeighborsType, nth_neighbors: int
197+
) -> List[Tuple[float, float]]:
186198
nn = nth_neighbors
187199
i = neighbors.index(x)
188200
start = max(0, i - nn - 1)
@@ -237,10 +249,10 @@ class Learner1D(BaseLearner):
237249

238250
def __init__(
239251
self,
240-
function: Callable,
241-
bounds: Tuple[float, float],
242-
loss_per_interval: Optional[Callable] = None,
243-
) -> None:
252+
function: Callable[[Real], Union[Float, np.ndarray]],
253+
bounds: Tuple[Real, Real],
254+
loss_per_interval: Optional[Callable[[XsTypeN, YsTypeN], Float]] = None,
255+
):
244256
self.function = function # type: ignore
245257

246258
if hasattr(loss_per_interval, "nth_neighbors"):
@@ -255,13 +267,13 @@ def __init__(
255267
# the learners behavior in the tests.
256268
self._recompute_losses_factor = 2
257269

258-
self.data = {}
259-
self.pending_points = set()
270+
self.data: Dict[Real, Real] = {}
271+
self.pending_points: Set[Real] = set()
260272

261273
# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
262274
# properties.
263-
self.neighbors = SortedDict()
264-
self.neighbors_combined = SortedDict()
275+
self.neighbors: NeighborsType = SortedDict()
276+
self.neighbors_combined: NeighborsType = SortedDict()
265277

266278
# Bounding box [[minx, maxx], [miny, maxy]].
267279
self._bbox = [list(bounds), [np.inf, -np.inf]]
@@ -319,14 +331,14 @@ def loss(self, real: bool = True) -> float:
319331
max_interval, max_loss = losses.peekitem(0)
320332
return max_loss
321333

322-
def _scale_x(self, x: Optional[float]) -> Optional[float]:
334+
def _scale_x(self, x: Optional[Float]) -> Optional[Float]:
323335
if x is None:
324336
return None
325337
return x / self._scale[0]
326338

327339
def _scale_y(
328-
self, y: Optional[Union[Float, np.ndarray]]
329-
) -> Optional[Union[Float, np.ndarray]]:
340+
self, y: Union[Float, np.ndarray, None]
341+
) -> Union[Float, np.ndarray, None]:
330342
if y is None:
331343
return None
332344
y_scale = self._scale[1] or 1
@@ -418,7 +430,7 @@ def _update_losses(self, x: float, real: bool = True) -> None:
418430
self.losses_combined[x, b] = float("inf")
419431

420432
@staticmethod
421-
def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
433+
def _find_neighbors(x: float, neighbors: NeighborsType) -> Any:
422434
if x in neighbors:
423435
return neighbors[x]
424436
pos = neighbors.bisect_left(x)
@@ -427,7 +439,7 @@ def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
427439
x_right = keys[pos] if pos != len(neighbors) else None
428440
return x_left, x_right
429441

430-
def _update_neighbors(self, x: float, neighbors: SortedDict) -> None:
442+
def _update_neighbors(self, x: float, neighbors: NeighborsType) -> None:
431443
if x not in neighbors: # The point is new
432444
x_left, x_right = self._find_neighbors(x, neighbors)
433445
neighbors[x] = [x_left, x_right]
@@ -461,9 +473,7 @@ def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
461473
self._bbox[1][1] = max(self._bbox[1][1], y)
462474
self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
463475

464-
def tell(
465-
self, x: float, y: Union[Float, Sequence[numbers.Number], np.ndarray]
466-
) -> None:
476+
def tell(self, x: float, y: Union[Float, Sequence[Float], np.ndarray]) -> None:
467477
if x in self.data:
468478
# The point is already evaluated before
469479
return
@@ -506,7 +516,17 @@ def tell_pending(self, x: float) -> None:
506516
self._update_neighbors(x, self.neighbors_combined)
507517
self._update_losses(x, real=False)
508518

509-
def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> None:
519+
def tell_many(
520+
self,
521+
xs: Sequence[Float],
522+
ys: Union[
523+
Sequence[Float],
524+
Sequence[Sequence[Float]],
525+
Sequence[np.ndarray],
526+
],
527+
*,
528+
force: bool = False
529+
) -> None:
510530
if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
511531
# Only run this more efficient method if there are
512532
# at least 2 points and the amount of points added are
@@ -526,8 +546,8 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
526546
points_combined = np.hstack([points_pending, points])
527547

528548
# Generate neighbors
529-
self.neighbors = _get_neighbors_from_list(points)
530-
self.neighbors_combined = _get_neighbors_from_list(points_combined)
549+
self.neighbors = _get_neighbors_from_array(points)
550+
self.neighbors_combined = _get_neighbors_from_array(points_combined)
531551

532552
# Update scale
533553
self._bbox[0] = [points_combined.min(), points_combined.max()]
@@ -574,7 +594,7 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
574594
# have an inf loss.
575595
self._update_interpolated_loss_in_interval(*ival)
576596

577-
def ask(self, n: int, tell_pending: bool = True) -> Any:
597+
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[float]]:
578598
"""Return 'n' points that are expected to maximally reduce the loss."""
579599
points, loss_improvements = self._ask_points_without_adding(n)
580600

@@ -584,7 +604,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
584604

585605
return points, loss_improvements
586606

587-
def _ask_points_without_adding(self, n: int) -> Any:
607+
def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
588608
"""Return 'n' points that are expected to maximally reduce the loss.
589609
Without altering the state of the learner"""
590610
# Find out how to divide the n points over the intervals
@@ -648,7 +668,7 @@ def _ask_points_without_adding(self, n: int) -> Any:
648668
quals[(*xs, n + 1)] = loss_qual * n / (n + 1)
649669

650670
points = list(
651-
itertools.chain.from_iterable(linspace(a, b, n) for ((a, b), n) in quals)
671+
itertools.chain.from_iterable(linspace(*ival, n) for (*ival, n) in quals)
652672
)
653673

654674
loss_improvements = list(
@@ -663,11 +683,13 @@ def _ask_points_without_adding(self, n: int) -> Any:
663683

664684
return points, loss_improvements
665685

666-
def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
686+
def _loss(
687+
self, mapping: Dict[Interval, float], ival: Interval
688+
) -> Tuple[float, Interval]:
667689
loss = mapping[ival]
668690
return finite_loss(ival, loss, self._scale[0])
669691

670-
def plot(self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"):
692+
def plot(self, *, scatter_or_line: str = "scatter"):
671693
"""Returns a plot of the evaluated data.
672694
673695
Parameters
@@ -734,7 +756,7 @@ def __setstate__(self, state):
734756
self.losses_combined.update(losses_combined)
735757

736758

737-
def loss_manager(x_scale: float) -> ItemSortedDict:
759+
def loss_manager(x_scale: float) -> Dict[Interval, float]:
738760
def sort_key(ival, loss):
739761
loss, ival = finite_loss(ival, loss, x_scale)
740762
return -loss, ival
@@ -743,8 +765,8 @@ def sort_key(ival, loss):
743765
return sorted_dict
744766

745767

746-
def finite_loss(ival: Any, loss: float, x_scale: float) -> Any:
747-
"""Get the socalled finite_loss of an interval in order to be able to
768+
def finite_loss(ival: Interval, loss: float, x_scale: float) -> Tuple[float, Interval]:
769+
"""Get the so-called finite_loss of an interval in order to be able to
748770
sort intervals that have infinite loss."""
749771
# If the loss is infinite we return the
750772
# distance between the two points.

0 commit comments

Comments
 (0)