Skip to content

Commit 9bfbeb8

Browse files
committed
add type-hints to Learner1D
1 parent c820704 commit 9bfbeb8

File tree

1 file changed

+88
-48
lines changed

1 file changed

+88
-48
lines changed

adaptive/learner/learner1D.py

Lines changed: 88 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,38 @@
1+
import collections.abc
12
import itertools
23
import math
3-
from collections.abc import Iterable
4+
import numbers
45
from copy import deepcopy
6+
from typing import (
7+
Any,
8+
Callable,
9+
Dict,
10+
Iterable,
11+
List,
12+
Literal,
13+
Optional,
14+
Sequence,
15+
Tuple,
16+
Union,
17+
)
518

619
import cloudpickle
720
import numpy as np
8-
import sortedcollections
9-
import sortedcontainers
21+
from sortedcollections.recipes import ItemSortedDict
22+
from sortedcontainers.sorteddict import SortedDict
1023

1124
from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
1225
from adaptive.learner.learnerND import volume
1326
from adaptive.learner.triangulation import simplex_volume_in_embedding
1427
from adaptive.notebook_integration import ensure_holoviews
28+
from adaptive.types import Float
1529
from adaptive.utils import cache_latest
1630

31+
Point = Tuple[Float, Float]
32+
1733

1834
@uses_nth_neighbors(0)
19-
def uniform_loss(xs, ys):
35+
def uniform_loss(xs: Point, ys: Any) -> Float:
2036
"""Loss function that samples the domain uniformly.
2137
2238
Works with `~adaptive.Learner1D` only.
@@ -36,17 +52,20 @@ def uniform_loss(xs, ys):
3652

3753

3854
@uses_nth_neighbors(0)
39-
def default_loss(xs, ys):
55+
def default_loss(
56+
xs: Point,
57+
ys: Union[Tuple[Iterable[Float], Iterable[Float]], Point],
58+
) -> float:
4059
"""Calculate loss on a single interval.
4160
4261
Currently returns the rescaled length of the interval. If one of the
4362
y-values is missing, returns 0 (so the intervals with missing data are
4463
never touched. This behavior should be improved later.
4564
"""
4665
dx = xs[1] - xs[0]
47-
if isinstance(ys[0], Iterable):
48-
dy = [abs(a - b) for a, b in zip(*ys)]
49-
return np.hypot(dx, dy).max()
66+
if isinstance(ys[0], collections.abc.Iterable):
67+
dy_vec = [abs(a - b) for a, b in zip(*ys)]
68+
return np.hypot(dx, dy_vec).max()
5069
else:
5170
dy = ys[1] - ys[0]
5271
return np.hypot(dx, dy)
@@ -60,15 +79,21 @@ def abs_min_log_loss(xs, ys):
6079

6180

6281
@uses_nth_neighbors(1)
63-
def triangle_loss(xs, ys):
82+
def triangle_loss(
83+
xs: Sequence[Optional[Float]],
84+
ys: Union[
85+
Iterable[Optional[Float]],
86+
Iterable[Union[Iterable[Float], None]],
87+
],
88+
) -> float:
6489
xs = [x for x in xs if x is not None]
6590
ys = [y for y in ys if y is not None]
6691

6792
if len(xs) == 2: # we do not have enough points for a triangle
6893
return xs[1] - xs[0]
6994

7095
N = len(xs) - 2 # number of constructed triangles
71-
if isinstance(ys[0], Iterable):
96+
if isinstance(ys[0], collections.abc.Iterable):
7297
pts = [(x, *y) for x, y in zip(xs, ys)]
7398
vol = simplex_volume_in_embedding
7499
else:
@@ -114,7 +139,9 @@ def resolution_loss(xs, ys):
114139
return resolution_loss
115140

116141

117-
def curvature_loss_function(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
142+
def curvature_loss_function(
143+
area_factor: float = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
144+
) -> Callable:
118145
# XXX: add a doc-string
119146
@uses_nth_neighbors(1)
120147
def curvature_loss(xs, ys):
@@ -133,7 +160,7 @@ def curvature_loss(xs, ys):
133160
return curvature_loss
134161

135162

136-
def linspace(x_left, x_right, n):
163+
def linspace(x_left: float, x_right: float, n: int) -> List[float]:
137164
"""This is equivalent to
138165
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
139166
but it is 15-30 times faster for small 'n'."""
@@ -145,17 +172,17 @@ def linspace(x_left, x_right, n):
145172
return [x_left + step * i for i in range(1, n)]
146173

147174

148-
def _get_neighbors_from_list(xs):
175+
def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
149176
xs = np.sort(xs)
150177
xs_left = np.roll(xs, 1).tolist()
151178
xs_right = np.roll(xs, -1).tolist()
152179
xs_left[0] = None
153180
xs_right[-1] = None
154181
neighbors = {x: [x_L, x_R] for x, x_L, x_R in zip(xs, xs_left, xs_right)}
155-
return sortedcontainers.SortedDict(neighbors)
182+
return SortedDict(neighbors)
156183

157184

158-
def _get_intervals(x, neighbors, nth_neighbors):
185+
def _get_intervals(x: float, neighbors: SortedDict, nth_neighbors: int) -> Any:
159186
nn = nth_neighbors
160187
i = neighbors.index(x)
161188
start = max(0, i - nn - 1)
@@ -208,8 +235,13 @@ class Learner1D(BaseLearner):
208235
decorator for more information.
209236
"""
210237

211-
def __init__(self, function, bounds, loss_per_interval=None):
212-
self.function = function
238+
def __init__(
239+
self,
240+
function: Callable,
241+
bounds: Tuple[float, float],
242+
loss_per_interval: Optional[Callable] = None,
243+
) -> None:
244+
self.function = function # type: ignore
213245

214246
if hasattr(loss_per_interval, "nth_neighbors"):
215247
self.nth_neighbors = loss_per_interval.nth_neighbors
@@ -228,8 +260,8 @@ def __init__(self, function, bounds, loss_per_interval=None):
228260

229261
# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
230262
# properties.
231-
self.neighbors = sortedcontainers.SortedDict()
232-
self.neighbors_combined = sortedcontainers.SortedDict()
263+
self.neighbors = SortedDict()
264+
self.neighbors_combined = SortedDict()
233265

234266
# Bounding box [[minx, maxx], [miny, maxy]].
235267
self._bbox = [list(bounds), [np.inf, -np.inf]]
@@ -247,10 +279,10 @@ def __init__(self, function, bounds, loss_per_interval=None):
247279

248280
self.bounds = list(bounds)
249281

250-
self._vdim = None
282+
self._vdim: Optional[int] = None
251283

252284
@property
253-
def vdim(self):
285+
def vdim(self) -> int:
254286
"""Length of the output of ``learner.function``.
255287
If the output is unsized (when it's a scalar)
256288
then `vdim = 1`.
@@ -275,35 +307,37 @@ def to_numpy(self):
275307
return np.array([(x, *np.atleast_1d(y)) for x, y in sorted(self.data.items())])
276308

277309
@property
278-
def npoints(self):
310+
def npoints(self) -> int:
279311
"""Number of evaluated points."""
280312
return len(self.data)
281313

282314
@cache_latest
283-
def loss(self, real=True):
315+
def loss(self, real: bool = True) -> float:
284316
losses = self.losses if real else self.losses_combined
285317
if not losses:
286318
return np.inf
287319
max_interval, max_loss = losses.peekitem(0)
288320
return max_loss
289321

290-
def _scale_x(self, x):
322+
def _scale_x(self, x: Optional[float]) -> Optional[float]:
291323
if x is None:
292324
return None
293325
return x / self._scale[0]
294326

295-
def _scale_y(self, y):
327+
def _scale_y(
328+
self, y: Optional[Union[Float, np.ndarray]]
329+
) -> Optional[Union[Float, np.ndarray]]:
296330
if y is None:
297331
return None
298332
y_scale = self._scale[1] or 1
299333
return y / y_scale
300334

301-
def _get_point_by_index(self, ind):
335+
def _get_point_by_index(self, ind: int) -> Optional[float]:
302336
if ind < 0 or ind >= len(self.neighbors):
303337
return None
304338
return self.neighbors.keys()[ind]
305339

306-
def _get_loss_in_interval(self, x_left, x_right):
340+
def _get_loss_in_interval(self, x_left: float, x_right: float) -> float:
307341
assert x_left is not None and x_right is not None
308342

309343
if x_right - x_left < self._dx_eps:
@@ -323,7 +357,9 @@ def _get_loss_in_interval(self, x_left, x_right):
323357
# we need to compute the loss for this interval
324358
return self.loss_per_interval(xs_scaled, ys_scaled)
325359

326-
def _update_interpolated_loss_in_interval(self, x_left, x_right):
360+
def _update_interpolated_loss_in_interval(
361+
self, x_left: float, x_right: float
362+
) -> None:
327363
if x_left is None or x_right is None:
328364
return
329365

@@ -339,7 +375,7 @@ def _update_interpolated_loss_in_interval(self, x_left, x_right):
339375
self.losses_combined[a, b] = (b - a) * loss / dx
340376
a = b
341377

342-
def _update_losses(self, x, real=True):
378+
def _update_losses(self, x: float, real: bool = True) -> None:
343379
"""Update all losses that depend on x"""
344380
# When we add a new point x, we should update the losses
345381
# (x_left, x_right) are the "real" neighbors of 'x'.
@@ -382,7 +418,7 @@ def _update_losses(self, x, real=True):
382418
self.losses_combined[x, b] = float("inf")
383419

384420
@staticmethod
385-
def _find_neighbors(x, neighbors):
421+
def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
386422
if x in neighbors:
387423
return neighbors[x]
388424
pos = neighbors.bisect_left(x)
@@ -391,14 +427,14 @@ def _find_neighbors(x, neighbors):
391427
x_right = keys[pos] if pos != len(neighbors) else None
392428
return x_left, x_right
393429

394-
def _update_neighbors(self, x, neighbors):
430+
def _update_neighbors(self, x: float, neighbors: SortedDict) -> None:
395431
if x not in neighbors: # The point is new
396432
x_left, x_right = self._find_neighbors(x, neighbors)
397433
neighbors[x] = [x_left, x_right]
398434
neighbors.get(x_left, [None, None])[1] = x
399435
neighbors.get(x_right, [None, None])[0] = x
400436

401-
def _update_scale(self, x, y):
437+
def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
402438
"""Update the scale with which the x and y-values are scaled.
403439
404440
For a learner where the function returns a single scalar the scale
@@ -425,7 +461,9 @@ def _update_scale(self, x, y):
425461
self._bbox[1][1] = max(self._bbox[1][1], y)
426462
self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
427463

428-
def tell(self, x, y):
464+
def tell(
465+
self, x: float, y: Union[Float, Sequence[numbers.Number], np.ndarray]
466+
) -> None:
429467
if x in self.data:
430468
# The point is already evaluated before
431469
return
@@ -460,15 +498,15 @@ def tell(self, x, y):
460498

461499
self._oldscale = deepcopy(self._scale)
462500

463-
def tell_pending(self, x):
501+
def tell_pending(self, x: float) -> None:
464502
if x in self.data:
465503
# The point is already evaluated before
466504
return
467505
self.pending_points.add(x)
468506
self._update_neighbors(x, self.neighbors_combined)
469507
self._update_losses(x, real=False)
470508

471-
def tell_many(self, xs, ys, *, force=False):
509+
def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> None:
472510
if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
473511
# Only run this more efficient method if there are
474512
# at least 2 points and the amount of points added are
@@ -536,7 +574,7 @@ def tell_many(self, xs, ys, *, force=False):
536574
# have an inf loss.
537575
self._update_interpolated_loss_in_interval(*ival)
538576

539-
def ask(self, n, tell_pending=True):
577+
def ask(self, n: int, tell_pending: bool = True) -> Any:
540578
"""Return 'n' points that are expected to maximally reduce the loss."""
541579
points, loss_improvements = self._ask_points_without_adding(n)
542580

@@ -546,7 +584,7 @@ def ask(self, n, tell_pending=True):
546584

547585
return points, loss_improvements
548586

549-
def _ask_points_without_adding(self, n):
587+
def _ask_points_without_adding(self, n: int) -> Any:
550588
"""Return 'n' points that are expected to maximally reduce the loss.
551589
Without altering the state of the learner"""
552590
# Find out how to divide the n points over the intervals
@@ -573,7 +611,8 @@ def _ask_points_without_adding(self, n):
573611
# Add bound intervals to quals if bounds were missing.
574612
if len(self.data) + len(self.pending_points) == 0:
575613
# We don't have any points, so return a linspace with 'n' points.
576-
return np.linspace(*self.bounds, n).tolist(), [np.inf] * n
614+
a, b = self.bounds
615+
return np.linspace(a, b, n).tolist(), [np.inf] * n
577616

578617
quals = loss_manager(self._scale[0])
579618
if len(missing_bounds) > 0:
@@ -609,7 +648,7 @@ def _ask_points_without_adding(self, n):
609648
quals[(*xs, n + 1)] = loss_qual * n / (n + 1)
610649

611650
points = list(
612-
itertools.chain.from_iterable(linspace(*ival, n) for (*ival, n) in quals)
651+
itertools.chain.from_iterable(linspace(a, b, n) for ((a, b), n) in quals)
613652
)
614653

615654
loss_improvements = list(
@@ -624,11 +663,11 @@ def _ask_points_without_adding(self, n):
624663

625664
return points, loss_improvements
626665

627-
def _loss(self, mapping, ival):
666+
def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
628667
loss = mapping[ival]
629668
return finite_loss(ival, loss, self._scale[0])
630669

631-
def plot(self, *, scatter_or_line="scatter"):
670+
def plot(self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"):
632671
"""Returns a plot of the evaluated data.
633672
634673
Parameters
@@ -663,17 +702,18 @@ def plot(self, *, scatter_or_line="scatter"):
663702

664703
return p.redim(x=dict(range=plot_bounds))
665704

666-
def remove_unfinished(self):
705+
def remove_unfinished(self) -> None:
667706
self.pending_points = set()
668707
self.losses_combined = deepcopy(self.losses)
669708
self.neighbors_combined = deepcopy(self.neighbors)
670709

671-
def _get_data(self):
710+
def _get_data(self) -> Dict[float, float]:
672711
return self.data
673712

674-
def _set_data(self, data):
713+
def _set_data(self, data: Dict[float, float]) -> None:
675714
if data:
676-
self.tell_many(*zip(*data.items()))
715+
xs, ys = zip(*data.items())
716+
self.tell_many(xs, ys)
677717

678718
def __getstate__(self):
679719
return (
@@ -694,16 +734,16 @@ def __setstate__(self, state):
694734
self.losses_combined.update(losses_combined)
695735

696736

697-
def loss_manager(x_scale):
737+
def loss_manager(x_scale: float) -> ItemSortedDict:
698738
def sort_key(ival, loss):
699739
loss, ival = finite_loss(ival, loss, x_scale)
700740
return -loss, ival
701741

702-
sorted_dict = sortedcollections.ItemSortedDict(sort_key)
742+
sorted_dict = ItemSortedDict(sort_key)
703743
return sorted_dict
704744

705745

706-
def finite_loss(ival, loss, x_scale):
746+
def finite_loss(ival: Any, loss: float, x_scale: float) -> Any:
707747
"""Get the socalled finite_loss of an interval in order to be able to
708748
sort intervals that have infinite loss."""
709749
# If the loss is infinite we return the

0 commit comments

Comments
 (0)