Commit 4cc00a7

use 'from __future__ import annotations' (#346)
1 parent d1b0b2a commit 4cc00a7

5 files changed: 54 additions, 70 deletions
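
The change itself is mechanical: with `from __future__ import annotations` (PEP 563) at the top of a module, annotations are stored as strings and never evaluated at definition time, so PEP 604 unions (`float | None`) and PEP 585 builtin generics (`list[int]`, `dict[str, float]`) can replace `Optional`, `Union`, `List`, `Dict`, and `Tuple` in annotations while still supporting Python 3.7/3.8. A minimal sketch of the idea (illustrative function, not from this commit):

    from __future__ import annotations


    def clip(x: float, bounds: tuple[float, float] | None = None) -> float:
        # The annotations above are kept as strings (PEP 563), so the
        # tuple[...] subscript and the | union are never evaluated at
        # runtime and therefore also work on Python 3.7 and 3.8.
        if bounds is None:
            return x
        lo, hi = bounds
        return min(max(x, lo), hi)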

adaptive/learner/average_learner.py

Lines changed: 8 additions & 6 deletions

@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from math import sqrt
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable
 
 import cloudpickle
 import numpy as np
@@ -38,8 +40,8 @@ class AverageLearner(BaseLearner):
     def __init__(
         self,
         function: Callable[[int], Real],
-        atol: Optional[float] = None,
-        rtol: Optional[float] = None,
+        atol: float | None = None,
+        rtol: float | None = None,
         min_npoints: int = 2,
     ) -> None:
         if atol is None and rtol is None:
@@ -68,7 +70,7 @@ def to_numpy(self):
         """Data as NumPy array of size (npoints, 2) with seeds and values."""
         return np.array(sorted(self.data.items()))
 
-    def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[int], List[Float]]:
+    def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]]:
         points = list(range(self.n_requested, self.n_requested + n))
 
         if any(p in self.data or p in self.pending_points for p in points):
@@ -159,10 +161,10 @@ def plot(self):
         vals = hv.Points(vals)
         return hv.operation.histogram(vals, num_bins=num_bins, dimension="y")
 
-    def _get_data(self) -> Tuple[Dict[int, Real], int, Real, Real]:
+    def _get_data(self) -> tuple[dict[int, Real], int, Real, Real]:
         return (self.data, self.npoints, self.sum_f, self.sum_f_sq)
 
-    def _set_data(self, data: Tuple[Dict[int, Real], int, Real, Real]) -> None:
+    def _set_data(self, data: tuple[dict[int, Real], int, Real, Real]) -> None:
         self.data, self.npoints, self.sum_f, self.sum_f_sq = data
 
     def __getstate__(self):
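
One side effect worth noting (general PEP 563 behaviour, not specific to this file): after the future import, `__annotations__` contains raw strings, so any runtime introspection has to resolve them explicitly with `typing.get_type_hints`. A hedged sketch with a simplified stand-in for the `ask` signature above:

    from __future__ import annotations

    import typing


    def ask(n: int, tell_pending: bool = True) -> tuple[list[int], list[float]]:
        # Simplified stand-in for AverageLearner.ask, used only to show
        # what the stored annotations look like under PEP 563.
        return list(range(n)), [0.0] * n


    print(ask.__annotations__["return"])  # the string "tuple[list[int], list[float]]"
    print(typing.get_type_hints(ask))     # resolved types; evaluating list[int] needs Python 3.9+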

adaptive/learner/average_learner1D.py

Lines changed: 19 additions & 28 deletions

@@ -1,19 +1,11 @@
+from __future__ import annotations
+
 import math
 import sys
 from collections import defaultdict
 from copy import deepcopy
 from math import hypot
-from typing import (
-    Callable,
-    DefaultDict,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-)
+from typing import Callable, DefaultDict, Iterable, List, Sequence, Tuple
 
 import numpy as np
 import scipy.stats
@@ -27,7 +19,7 @@
 Point = Tuple[int, Real]
 Points = List[Point]
 
-__all__: List[str] = ["AverageLearner1D"]
+__all__: list[str] = ["AverageLearner1D"]
 
 
 class AverageLearner1D(Learner1D):
@@ -75,11 +67,10 @@ class AverageLearner1D(Learner1D):
 
     def __init__(
         self,
-        function: Callable[[Tuple[int, Real]], Real],
-        bounds: Tuple[Real, Real],
-        loss_per_interval: Optional[
-            Callable[[Sequence[Real], Sequence[Real]], float]
-        ] = None,
+        function: Callable[[tuple[int, Real]], Real],
+        bounds: tuple[Real, Real],
+        loss_per_interval: None
+        | (Callable[[Sequence[Real], Sequence[Real]], float]) = None,
         delta: float = 0.2,
         alpha: float = 0.005,
         neighbor_sampling: float = 0.3,
@@ -115,15 +106,15 @@ def __init__(
         self._number_samples = SortedDict()
         # This set contains the points x that have less than min_samples
         # samples or less than a (neighbor_sampling*100)% of their neighbors
-        self._undersampled_points: Set[Real] = set()
+        self._undersampled_points: set[Real] = set()
         # Contains the error in the estimate of the
        # mean at each point x in the form {x0: error(x0), ...}
-        self.error: Dict[Real, float] = decreasing_dict()
+        self.error: dict[Real, float] = decreasing_dict()
         #  Distance between two neighboring points in the
         # form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
-        self._distances: Dict[Real, float] = decreasing_dict()
+        self._distances: dict[Real, float] = decreasing_dict()
         # {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
-        self.rescaled_error: Dict[Real, float] = decreasing_dict()
+        self.rescaled_error: dict[Real, float] = decreasing_dict()
 
     @property
     def nsamples(self) -> int:
@@ -136,7 +127,7 @@ def min_samples_per_point(self) -> int:
             return 0
         return min(self._number_samples.values())
 
-    def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
+    def ask(self, n: int, tell_pending: bool = True) -> tuple[Points, list[float]]:
         """Return 'n' points that are expected to maximally reduce the loss."""
         # If some point is undersampled, resample it
         if len(self._undersampled_points):
@@ -165,7 +156,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
 
         return points, loss_improvements
 
-    def _ask_for_more_samples(self, x: Real, n: int) -> Tuple[Points, List[float]]:
+    def _ask_for_more_samples(self, x: Real, n: int) -> tuple[Points, list[float]]:
         """When asking for n points, the learner returns n times an existing point
         to be resampled, since in general n << min_samples and this point will
         need to be resampled many more times"""
@@ -184,7 +175,7 @@ def _ask_for_more_samples(self, x: Real, n: int) -> Tuple[Points, List[float]]:
         loss_improvements = [loss_improvement / n] * n
         return points, loss_improvements
 
-    def _ask_for_new_point(self, n: int) -> Tuple[Points, List[float]]:
+    def _ask_for_new_point(self, n: int) -> tuple[Points, list[float]]:
         """When asking for n new points, the learner returns n times a single
         new point, since in general n << min_samples and this point will need
         to be resampled many more times"""
@@ -398,7 +389,7 @@ def tell_many(self, xs: Points, ys: Sequence[Real]) -> None:
             # simultaneously, before we move on to a new x
             self.tell_many_at_point(x, seed_y_mapping)
 
-    def tell_many_at_point(self, x: Real, seed_y_mapping: Dict[int, Real]) -> None:
+    def tell_many_at_point(self, x: Real, seed_y_mapping: dict[int, Real]) -> None:
         """Tell the learner about many samples at a certain location x.
 
         Parameters
@@ -454,10 +445,10 @@ def tell_many_at_point(self, x: Real, seed_y_mapping: Dict[int, Real]) -> None:
             self._update_interpolated_loss_in_interval(*interval)
         self._oldscale = deepcopy(self._scale)
 
-    def _get_data(self) -> Dict[Real, Real]:
+    def _get_data(self) -> dict[Real, Real]:
         return self._data_samples
 
-    def _set_data(self, data: Dict[Real, Real]) -> None:
+    def _set_data(self, data: dict[Real, Real]) -> None:
         if data:
             for x, samples in data.items():
                 self.tell_many_at_point(x, samples)
@@ -491,7 +482,7 @@ def plot(self):
         return p.redim(x=dict(range=plot_bounds))
 
 
-def decreasing_dict() -> Dict:
+def decreasing_dict() -> dict:
     """This initialization orders the dictionary from large to small values"""
 
     def sorting_rule(key, value):
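
Note that this module keeps `List` and `Tuple` in the typing import even though the annotations now use the builtin forms: aliases such as `Point = Tuple[int, Real]` and `Points = List[Point]` are ordinary assignments executed at import time, so the future import does not help them, and `tuple[int, Real]` would raise on Python 3.8 and older. A small illustration of that distinction (assumed example, not part of the diff):

    from __future__ import annotations

    from typing import List, Tuple

    # Runtime assignments: evaluated at import time, so the typing forms stay.
    Point = Tuple[int, float]
    Points = List[Point]


    def first(points: list[Point]) -> Point:
        # Pure annotation: stored as a string and never evaluated, so the
        # builtin list[...] syntax is fine even on Python 3.7/3.8.
        return points[0]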

adaptive/learner/base_learner.py

Lines changed: 0 additions & 3 deletions

@@ -110,12 +110,10 @@ def tell_many(self, xs, ys):
     def tell_pending(self, x):
         """Tell the learner that 'x' has been requested such
         that it's not suggested again."""
-        pass
 
     @abc.abstractmethod
     def remove_unfinished(self):
         """Remove uncomputed data from the learner."""
-        pass
 
     @abc.abstractmethod
     def loss(self, real=True):
@@ -142,7 +140,6 @@ def ask(self, n, tell_pending=True):
             `pending_points`. Set this to False if you do not
             want to modify the state of the learner.
         """
-        pass
 
     @abc.abstractmethod
     def _get_data(self):
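
The removed `pass` statements were redundant: a docstring is itself a complete function body, and these methods remain abstract because of the `@abc.abstractmethod` decorator, not because of anything in the body. A minimal sketch of the pattern (illustrative class name):

    import abc


    class Learner(abc.ABC):
        @abc.abstractmethod
        def remove_unfinished(self):
            """Remove uncomputed data from the learner."""
            # No `pass` needed: the docstring already forms the body, and
            # abstractness comes from the decorator above.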

adaptive/learner/learner1D.py

Lines changed: 27 additions & 31 deletions

@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import collections.abc
 import itertools
 import math
 from copy import copy, deepcopy
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
 
 import cloudpickle
 import numpy as np
@@ -170,7 +172,7 @@ def curvature_loss(xs: XsType1, ys: YsType1) -> Float:
    return curvature_loss
 
 
-def linspace(x_left: Real, x_right: Real, n: Int) -> List[Float]:
+def linspace(x_left: Real, x_right: Real, n: Int) -> list[Float]:
     """This is equivalent to
     'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
     but it is 15-30 times faster for small 'n'."""
@@ -194,7 +196,7 @@ def _get_neighbors_from_array(xs: np.ndarray) -> NeighborsType:
 
 def _get_intervals(
     x: float, neighbors: NeighborsType, nth_neighbors: int
-) -> List[Tuple[float, float]]:
+) -> list[tuple[float, float]]:
     nn = nth_neighbors
     i = neighbors.index(x)
     start = max(0, i - nn - 1)
@@ -249,9 +251,9 @@ class Learner1D(BaseLearner):
 
     def __init__(
         self,
-        function: Callable[[Real], Union[Float, np.ndarray]],
-        bounds: Tuple[Real, Real],
-        loss_per_interval: Optional[Callable[[XsTypeN, YsTypeN], Float]] = None,
+        function: Callable[[Real], Float | np.ndarray],
+        bounds: tuple[Real, Real],
+        loss_per_interval: Callable[[XsTypeN, YsTypeN], Float] | None = None,
     ):
         self.function = function  # type: ignore
 
@@ -267,8 +269,8 @@ def __init__(
         # the learners behavior in the tests.
         self._recompute_losses_factor = 2
 
-        self.data: Dict[Real, Real] = {}
-        self.pending_points: Set[Real] = set()
+        self.data: dict[Real, Real] = {}
+        self.pending_points: set[Real] = set()
 
         # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
         # properties.
@@ -292,7 +294,7 @@ def __init__(
         self.bounds = list(bounds)
         self.__missing_bounds = set(self.bounds)  # cache of missing bounds
 
-        self._vdim: Optional[int] = None
+        self._vdim: int | None = None
 
     @property
     def vdim(self) -> int:
@@ -334,20 +336,18 @@ def loss(self, real: bool = True) -> float:
             max_interval, max_loss = losses.peekitem(0)
         return max_loss
 
-    def _scale_x(self, x: Optional[Float]) -> Optional[Float]:
+    def _scale_x(self, x: Float | None) -> Float | None:
         if x is None:
             return None
         return x / self._scale[0]
 
-    def _scale_y(
-        self, y: Union[Float, np.ndarray, None]
-    ) -> Union[Float, np.ndarray, None]:
+    def _scale_y(self, y: Float | np.ndarray | None) -> Float | np.ndarray | None:
         if y is None:
             return None
         y_scale = self._scale[1] or 1
         return y / y_scale
 
-    def _get_point_by_index(self, ind: int) -> Optional[float]:
+    def _get_point_by_index(self, ind: int) -> float | None:
         if ind < 0 or ind >= len(self.neighbors):
             return None
         return self.neighbors.keys()[ind]
@@ -449,7 +449,7 @@ def _update_neighbors(self, x: float, neighbors: NeighborsType) -> None:
         neighbors.get(x_left, [None, None])[1] = x
         neighbors.get(x_right, [None, None])[0] = x
 
-    def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
+    def _update_scale(self, x: float, y: Float | np.ndarray) -> None:
         """Update the scale with which the x and y-values are scaled.
 
         For a learner where the function returns a single scalar the scale
@@ -476,7 +476,7 @@ def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
             self._bbox[1][1] = max(self._bbox[1][1], y)
             self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
 
-    def tell(self, x: float, y: Union[Float, Sequence[Float], np.ndarray]) -> None:
+    def tell(self, x: float, y: Float | Sequence[Float] | np.ndarray) -> None:
         if x in self.data:
             # The point is already evaluated before
             return
@@ -522,13 +522,9 @@ def tell_pending(self, x: float) -> None:
     def tell_many(
         self,
         xs: Sequence[Float],
-        ys: Union[
-            Sequence[Float],
-            Sequence[Sequence[Float]],
-            Sequence[np.ndarray],
-        ],
+        ys: (Sequence[Float] | Sequence[Sequence[Float]] | Sequence[np.ndarray]),
         *,
-        force: bool = False
+        force: bool = False,
     ) -> None:
         if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
             # Only run this more efficient method if there are
@@ -597,7 +593,7 @@ def tell_many(
             # have an inf loss.
             self._update_interpolated_loss_in_interval(*ival)
 
-    def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[float]]:
+    def ask(self, n: int, tell_pending: bool = True) -> tuple[list[float], list[float]]:
         """Return 'n' points that are expected to maximally reduce the loss."""
         points, loss_improvements = self._ask_points_without_adding(n)
 
@@ -607,7 +603,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[float]]:
 
         return points, loss_improvements
 
-    def _missing_bounds(self) -> List[Real]:
+    def _missing_bounds(self) -> list[Real]:
         missing_bounds = []
         for b in copy(self.__missing_bounds):
             if b in self.data:
@@ -616,7 +612,7 @@ def _missing_bounds(self) -> List[Real]:
                 missing_bounds.append(b)
         return sorted(missing_bounds)
 
-    def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
+    def _ask_points_without_adding(self, n: int) -> tuple[list[float], list[float]]:
         """Return 'n' points that are expected to maximally reduce the loss.
         Without altering the state of the learner"""
         # Find out how to divide the n points over the intervals
@@ -691,8 +687,8 @@ def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
         return points, loss_improvements
 
     def _loss(
-        self, mapping: Dict[Interval, float], ival: Interval
-    ) -> Tuple[float, Interval]:
+        self, mapping: dict[Interval, float], ival: Interval
+    ) -> tuple[float, Interval]:
         loss = mapping[ival]
         return finite_loss(ival, loss, self._scale[0])
 
@@ -736,10 +732,10 @@ def remove_unfinished(self) -> None:
         self.losses_combined = deepcopy(self.losses)
         self.neighbors_combined = deepcopy(self.neighbors)
 
-    def _get_data(self) -> Dict[float, float]:
+    def _get_data(self) -> dict[float, float]:
         return self.data
 
-    def _set_data(self, data: Dict[float, float]) -> None:
+    def _set_data(self, data: dict[float, float]) -> None:
         if data:
             xs, ys = zip(*data.items())
             self.tell_many(xs, ys)
@@ -763,7 +759,7 @@ def __setstate__(self, state):
         self.losses_combined.update(losses_combined)
 
 
-def loss_manager(x_scale: float) -> Dict[Interval, float]:
+def loss_manager(x_scale: float) -> dict[Interval, float]:
     def sort_key(ival, loss):
         loss, ival = finite_loss(ival, loss, x_scale)
         return -loss, ival
@@ -772,7 +768,7 @@ def sort_key(ival, loss):
     return sorted_dict
 
 
-def finite_loss(ival: Interval, loss: float, x_scale: float) -> Tuple[float, Interval]:
+def finite_loss(ival: Interval, loss: float, x_scale: float) -> tuple[float, Interval]:
     """Get the so-called finite_loss of an interval in order to be able to
     sort intervals that have infinite loss."""
     # If the loss is infinite we return the
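
A general caveat that applies to all of these rewrites (a note on PEP 563, not something the diff changes): the future import only affects annotations. Wherever a type expression is actually evaluated, such as an `isinstance` check inside a function body, the old spellings are still required on Python versions before 3.10. A hedged sketch, loosely modelled on `_scale_y` above:

    from __future__ import annotations

    import numpy as np


    def _scale_y(y: float | np.ndarray | None) -> float | np.ndarray | None:
        # The annotation is never evaluated, so the | unions are safe on 3.7+.
        if y is None:
            return None
        # This IS evaluated at runtime, so a tuple of types is used here instead
        # of `float | np.ndarray`, which would only work on Python 3.10+.
        assert isinstance(y, (float, np.ndarray))
        return y / 2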

adaptive/runner.py

Lines changed: 0 additions & 2 deletions

@@ -286,12 +286,10 @@ def elapsed_time(self):
 
         Is called in `overhead`.
         """
-        pass
 
     @abc.abstractmethod
     def _submit(self, x):
         """Is called in `_get_futures`."""
-        pass
 
     @property
     def tracebacks(self):
