Skip to content

Commit 614f599

Browse files
committed
add type annotations
1 parent dfeb5ef commit 614f599

17 files changed

+608
-390
lines changed

adaptive/_version.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import subprocess
55
from collections import namedtuple
66
from distutils.command.build_py import build_py as build_py_orig
7+
from typing import Dict
78

89
from setuptools.command.sdist import sdist as sdist_orig
910

@@ -19,7 +20,7 @@
1920
STATIC_VERSION_FILE = "_static_version.py"
2021

2122

22-
def get_version(version_file=STATIC_VERSION_FILE):
23+
def get_version(version_file: str = STATIC_VERSION_FILE) -> str:
2324
version_info = get_static_version_info(version_file)
2425
version = version_info["version"]
2526
if version == "__use_git__":
@@ -33,7 +34,7 @@ def get_version(version_file=STATIC_VERSION_FILE):
3334
return version
3435

3536

36-
def get_static_version_info(version_file=STATIC_VERSION_FILE):
37+
def get_static_version_info(version_file: str = STATIC_VERSION_FILE) -> Dict[str, str]:
3738
version_info = {}
3839
with open(os.path.join(package_root, version_file), "rb") as f:
3940
exec(f.read(), {}, version_info)
@@ -44,7 +45,7 @@ def version_is_from_git(version_file=STATIC_VERSION_FILE):
4445
return get_static_version_info(version_file)["version"] == "__use_git__"
4546

4647

47-
def pep440_format(version_info):
48+
def pep440_format(version_info: Version) -> str:
4849
release, dev, labels = version_info
4950

5051
version_parts = [release]
@@ -61,7 +62,7 @@ def pep440_format(version_info):
6162
return "".join(version_parts)
6263

6364

64-
def get_version_from_git():
65+
def get_version_from_git() -> Version:
6566
try:
6667
p = subprocess.Popen(
6768
["git", "rev-parse", "--show-toplevel"],

adaptive/learner/average_learner.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from math import sqrt
2+
from typing import Callable, Dict, List, Optional, Tuple
23

34
import numpy as np
45

@@ -30,7 +31,12 @@ class AverageLearner(BaseLearner):
3031
Number of evaluated points.
3132
"""
3233

33-
def __init__(self, function, atol=None, rtol=None):
34+
def __init__(
35+
self,
36+
function: Callable,
37+
atol: Optional[float] = None,
38+
rtol: Optional[float] = None,
39+
) -> None:
3440
if atol is None and rtol is None:
3541
raise Exception("At least one of `atol` and `rtol` should be set.")
3642
if atol is None:
@@ -48,10 +54,10 @@ def __init__(self, function, atol=None, rtol=None):
4854
self.sum_f_sq = 0
4955

5056
@property
51-
def n_requested(self):
57+
def n_requested(self) -> int:
5258
return self.npoints + len(self.pending_points)
5359

54-
def ask(self, n, tell_pending=True):
60+
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[int], List[float]]:
5561
points = list(range(self.n_requested, self.n_requested + n))
5662

5763
if any(p in self.data or p in self.pending_points for p in points):
@@ -68,7 +74,7 @@ def ask(self, n, tell_pending=True):
6874
self.tell_pending(p)
6975
return points, loss_improvements
7076

71-
def tell(self, n, value):
77+
def tell(self, n: int, value: float) -> None:
7278
if n in self.data:
7379
# The point has already been added before.
7480
return
@@ -79,16 +85,16 @@ def tell(self, n, value):
7985
self.sum_f_sq += value ** 2
8086
self.npoints += 1
8187

82-
def tell_pending(self, n):
88+
def tell_pending(self, n: int) -> None:
8389
self.pending_points.add(n)
8490

8591
@property
86-
def mean(self):
92+
def mean(self) -> float:
8793
"""The average of all values in `data`."""
8894
return self.sum_f / self.npoints
8995

9096
@property
91-
def std(self):
97+
def std(self) -> float:
9298
"""The corrected sample standard deviation of the values
9399
in `data`."""
94100
n = self.npoints
@@ -101,7 +107,7 @@ def std(self):
101107
return sqrt(numerator / (n - 1))
102108

103109
@cache_latest
104-
def loss(self, real=True, *, n=None):
110+
def loss(self, real: bool = True, *, n=None) -> float:
105111
if n is None:
106112
n = self.npoints if real else self.n_requested
107113
else:
@@ -113,7 +119,7 @@ def loss(self, real=True, *, n=None):
113119
standard_error / self.atol, standard_error / abs(self.mean) / self.rtol
114120
)
115121

116-
def _loss_improvement(self, n):
122+
def _loss_improvement(self, n: int) -> float:
117123
loss = self.loss()
118124
if np.isfinite(loss):
119125
return loss - self.loss(n=self.npoints + n)
@@ -139,8 +145,8 @@ def plot(self):
139145
vals = hv.Points(vals)
140146
return hv.operation.histogram(vals, num_bins=num_bins, dimension=1)
141147

142-
def _get_data(self):
148+
def _get_data(self) -> Tuple[Dict[int, float], int, float, float]:
143149
return (self.data, self.npoints, self.sum_f, self.sum_f_sq)
144150

145-
def _set_data(self, data):
151+
def _set_data(self, data: Tuple[Dict[int, float], int, float, float]) -> None:
146152
self.data, self.npoints, self.sum_f, self.sum_f_sq = data

adaptive/learner/balancing_learner.py

+32-17
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from contextlib import suppress
55
from functools import partial
66
from operator import itemgetter
7+
from typing import Any, Callable, Dict, List, Set, Tuple, Union
78

89
import numpy as np
910

@@ -12,7 +13,7 @@
1213
from adaptive.utils import cache_latest, named_product, restore
1314

1415

15-
def dispatch(child_functions, arg):
16+
def dispatch(child_functions: List[Callable], arg: Any) -> Union[Any]:
1617
index, x = arg
1718
return child_functions[index](x)
1819

@@ -68,7 +69,9 @@ class BalancingLearner(BaseLearner):
6869
behave in an undefined way. Change the `strategy` in that case.
6970
"""
7071

71-
def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
72+
def __init__(
73+
self, learners: List[BaseLearner], *, cdims=None, strategy="loss_improvements"
74+
) -> None:
7275
self.learners = learners
7376

7477
# Naively we would make 'function' a method, but this causes problems
@@ -89,21 +92,21 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
8992
self.strategy = strategy
9093

9194
@property
92-
def data(self):
95+
def data(self) -> Dict[Tuple[int, Any], Any]:
9396
data = {}
9497
for i, l in enumerate(self.learners):
9598
data.update({(i, p): v for p, v in l.data.items()})
9699
return data
97100

98101
@property
99-
def pending_points(self):
102+
def pending_points(self) -> Set[Tuple[int, Any]]:
100103
pending_points = set()
101104
for i, l in enumerate(self.learners):
102105
pending_points.update({(i, p) for p in l.pending_points})
103106
return pending_points
104107

105108
@property
106-
def npoints(self):
109+
def npoints(self) -> int:
107110
return sum(l.npoints for l in self.learners)
108111

109112
@property
@@ -135,7 +138,9 @@ def strategy(self, strategy):
135138
' strategy="npoints", or strategy="cycle" is implemented.'
136139
)
137140

138-
def _ask_and_tell_based_on_loss_improvements(self, n):
141+
def _ask_and_tell_based_on_loss_improvements(
142+
self, n: int
143+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
139144
selected = [] # tuples ((learner_index, point), loss_improvement)
140145
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
141146
for _ in range(n):
@@ -158,7 +163,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
158163
points, loss_improvements = map(list, zip(*selected))
159164
return points, loss_improvements
160165

161-
def _ask_and_tell_based_on_loss(self, n):
166+
def _ask_and_tell_based_on_loss(
167+
self, n: int
168+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
162169
selected = [] # tuples ((learner_index, point), loss_improvement)
163170
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
164171
for _ in range(n):
@@ -179,7 +186,9 @@ def _ask_and_tell_based_on_loss(self, n):
179186
points, loss_improvements = map(list, zip(*selected))
180187
return points, loss_improvements
181188

182-
def _ask_and_tell_based_on_npoints(self, n):
189+
def _ask_and_tell_based_on_npoints(
190+
self, n: int
191+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
183192
selected = [] # tuples ((learner_index, point), loss_improvement)
184193
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
185194
for _ in range(n):
@@ -195,7 +204,9 @@ def _ask_and_tell_based_on_npoints(self, n):
195204
points, loss_improvements = map(list, zip(*selected))
196205
return points, loss_improvements
197206

198-
def _ask_and_tell_based_on_cycle(self, n):
207+
def _ask_and_tell_based_on_cycle(
208+
self, n: int
209+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
199210
points, loss_improvements = [], []
200211
for _ in range(n):
201212
index = next(self._cycle)
@@ -206,7 +217,9 @@ def _ask_and_tell_based_on_cycle(self, n):
206217

207218
return points, loss_improvements
208219

209-
def ask(self, n, tell_pending=True):
220+
def ask(
221+
self, n: int, tell_pending: bool = True
222+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
210223
"""Chose points for learners."""
211224
if n == 0:
212225
return [], []
@@ -217,20 +230,20 @@ def ask(self, n, tell_pending=True):
217230
else:
218231
return self._ask_and_tell(n)
219232

220-
def tell(self, x, y):
233+
def tell(self, x: Tuple[int, Any], y: Any) -> None:
221234
index, x = x
222235
self._ask_cache.pop(index, None)
223236
self._loss.pop(index, None)
224237
self._pending_loss.pop(index, None)
225238
self.learners[index].tell(x, y)
226239

227-
def tell_pending(self, x):
240+
def tell_pending(self, x: Tuple[int, Any]) -> None:
228241
index, x = x
229242
self._ask_cache.pop(index, None)
230243
self._loss.pop(index, None)
231244
self.learners[index].tell_pending(x)
232245

233-
def _losses(self, real=True):
246+
def _losses(self, real: bool = True) -> List[float]:
234247
losses = []
235248
loss_dict = self._loss if real else self._pending_loss
236249

@@ -242,7 +255,7 @@ def _losses(self, real=True):
242255
return losses
243256

244257
@cache_latest
245-
def loss(self, real=True):
258+
def loss(self, real: bool = True) -> Union[float]:
246259
losses = self._losses(real)
247260
return max(losses)
248261

@@ -325,7 +338,9 @@ def remove_unfinished(self):
325338
learner.remove_unfinished()
326339

327340
@classmethod
328-
def from_product(cls, f, learner_type, learner_kwargs, combos):
341+
def from_product(
342+
cls, f, learner_type, learner_kwargs, combos
343+
) -> "BalancingLearner":
329344
"""Create a `BalancingLearner` with learners of all combinations of
330345
named variables’ values. The `cdims` will be set correctly, so calling
331346
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -372,7 +387,7 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
372387
learners.append(learner)
373388
return cls(learners, cdims=arguments)
374389

375-
def save(self, fname, compress=True):
390+
def save(self, fname: Callable, compress: bool = True) -> None:
376391
"""Save the data of the child learners into pickle files
377392
in a directory.
378393
@@ -410,7 +425,7 @@ def save(self, fname, compress=True):
410425
for l in self.learners:
411426
l.save(fname(l), compress=compress)
412427

413-
def load(self, fname, compress=True):
428+
def load(self, fname: Callable, compress: bool = True) -> None:
414429
"""Load the data of the child learners from pickle files
415430
in a directory.
416431

adaptive/learner/base_learner.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import abc
22
from contextlib import suppress
33
from copy import deepcopy
4+
from typing import Any, Callable, Dict
45

56
from adaptive.utils import _RequireAttrsABCMeta, load, save
67

78

8-
def uses_nth_neighbors(n):
9+
def uses_nth_neighbors(n: int) -> Callable:
910
"""Decorator to specify how many neighboring intervals the loss function uses.
1011
1112
Wraps loss functions to indicate that they expect intervals together
@@ -84,7 +85,7 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
8485
npoints: int
8586
pending_points: set
8687

87-
def tell(self, x, y):
88+
def tell(self, x: Any, y) -> None:
8889
"""Tell the learner about a single value.
8990
9091
Parameters
@@ -94,7 +95,7 @@ def tell(self, x, y):
9495
"""
9596
self.tell_many([x], [y])
9697

97-
def tell_many(self, xs, ys):
98+
def tell_many(self, xs: Any, ys: Any) -> None:
9899
"""Tell the learner about some values.
99100
100101
Parameters
@@ -161,7 +162,7 @@ def copy_from(self, other):
161162
"""
162163
self._set_data(other._get_data())
163164

164-
def save(self, fname, compress=True):
165+
def save(self, fname: str, compress: bool = True) -> None:
165166
"""Save the data of the learner into a pickle file.
166167
167168
Parameters
@@ -175,7 +176,7 @@ def save(self, fname, compress=True):
175176
data = self._get_data()
176177
save(fname, data, compress)
177178

178-
def load(self, fname, compress=True):
179+
def load(self, fname: str, compress: bool = True) -> None:
179180
"""Load the data of a learner from a pickle file.
180181
181182
Parameters
@@ -190,8 +191,8 @@ def load(self, fname, compress=True):
190191
data = load(fname, compress)
191192
self._set_data(data)
192193

193-
def __getstate__(self):
194+
def __getstate__(self) -> Dict[str, Any]:
194195
return deepcopy(self.__dict__)
195196

196-
def __setstate__(self, state):
197+
def __setstate__(self, state: Dict[str, Any]) -> None:
197198
self.__dict__ = state

0 commit comments

Comments
 (0)