Skip to content

Commit 17e84e4

Browse files
committed
Merge remote-tracking branch 'origin/master' into mypy
2 parents 3c9d902 + 9860573 commit 17e84e4

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

adaptive/learner/sequence_learner.py

+25-15
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

33
from copy import copy
4-
from typing import Any, Callable, Iterable
4+
from typing import Any, Callable, Iterable, Tuple
55

66
import cloudpickle
7-
import numpy as np
87
from sortedcontainers import SortedDict, SortedSet
98

109
from adaptive.learner.base_learner import BaseLearner
10+
from adaptive.types import Int
1111
from adaptive.utils import assign_defaults, partial_function_from_dataframe
1212

1313
try:
@@ -18,6 +18,14 @@
1818
except ModuleNotFoundError:
1919
with_pandas = False
2020

21+
try:
22+
from typing import TypeAlias
23+
except ImportError:
24+
from typing_extensions import TypeAlias
25+
26+
27+
PointType: TypeAlias = Tuple[Int, Any]
28+
2129

2230
class _IgnoreFirstArgument:
2331
"""Remove the first argument from the call signature.
@@ -32,9 +40,7 @@ class _IgnoreFirstArgument:
3240
def __init__(self, function: Callable) -> None:
3341
self.function = function # type: ignore
3442

35-
def __call__(
36-
self, index_point: tuple[int, float | np.ndarray], *args, **kwargs
37-
) -> float:
43+
def __call__(self, index_point: PointType, *args, **kwargs):
3844
index, point = index_point
3945
return self.function(point, *args, **kwargs)
4046

@@ -85,7 +91,9 @@ def new(self) -> SequenceLearner:
8591
"""Return a new `~adaptive.SequenceLearner` without the data."""
8692
return SequenceLearner(self._original_function, self.sequence)
8793

88-
def ask(self, n: int, tell_pending: bool = True) -> tuple[Any, list[float]]:
94+
def ask(
95+
self, n: int, tell_pending: bool = True
96+
) -> tuple[list[PointType], list[float]]:
8997
indices = []
9098
points = []
9199
loss_improvements = []
@@ -105,31 +113,31 @@ def ask(self, n: int, tell_pending: bool = True) -> tuple[Any, list[float]]:
105113

106114
def loss(self, real: bool = True) -> float:
107115
if not (self._to_do_indices or self.pending_points):
108-
return 0
116+
return 0.0
109117
else:
110118
npoints = self.npoints + (0 if real else len(self.pending_points))
111119
return (self._ntotal - npoints) / self._ntotal
112120

113-
def remove_unfinished(self):
121+
def remove_unfinished(self) -> None:
114122
for i in self.pending_points:
115123
self._to_do_indices.add(i)
116124
self.pending_points = set()
117125

118-
def tell(self, point: tuple[int, Any], value: Any) -> None:
126+
def tell(self, point: PointType, value: Any) -> None:
119127
index, point = point
120128
self.data[index] = value
121129
self.pending_points.discard(index)
122130
self._to_do_indices.discard(index)
123131

124-
def tell_pending(self, point: Any) -> None:
132+
def tell_pending(self, point: PointType) -> None:
125133
index, point = point
126134
self.pending_points.add(index)
127135
self._to_do_indices.discard(index)
128136

129-
def done(self):
137+
def done(self) -> bool:
130138
return not self._to_do_indices and not self.pending_points
131139

132-
def result(self):
140+
def result(self) -> list[Any]:
133141
"""Get the function values in the same order as ``sequence``."""
134142
if not self.done():
135143
raise Exception("Learner is not yet complete.")
@@ -217,16 +225,18 @@ def load_dataframe(
217225
y_name : str, optional
218226
The ``y_name`` used in ``to_dataframe``, by default "y"
219227
"""
220-
self.tell_many(df[[index_name, x_name]].values, df[y_name].values)
228+
indices = df[index_name].values
229+
xs = df[x_name].values
230+
self.tell_many(zip(indices, xs), df[y_name].values)
221231
if with_default_function_args:
222232
self.function = partial_function_from_dataframe(
223233
self._original_function, df, function_prefix
224234
)
225235

226-
def _get_data(self) -> SortedDict:
236+
def _get_data(self) -> dict[int, Any]:
227237
return self.data
228238

229-
def _set_data(self, data: SortedDict) -> None:
239+
def _set_data(self, data: dict[int, Any]) -> None:
230240
if data:
231241
indices, values = zip(*data.items())
232242
# the points aren't used by tell, so we can safely pass None

0 commit comments

Comments
 (0)