Skip to content

Commit bb9dd07

Browse files
authored
Merge pull request #193 from python-adaptive/sequence_learner
Add a SequenceLearner
2 parents 8b4b583 + de6eb19 commit bb9dd07

File tree

8 files changed

+256
-12
lines changed

8 files changed

+256
-12
lines changed

adaptive/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Learner2D,
1515
LearnerND,
1616
make_datasaver,
17+
SequenceLearner,
1718
)
1819
from adaptive.notebook_integration import (
1920
active_plotting_tasks,
@@ -36,6 +37,7 @@
3637
"Learner2D",
3738
"LearnerND",
3839
"make_datasaver",
40+
"SequenceLearner",
3941
"active_plotting_tasks",
4042
"live_plot",
4143
"notebook_extension",

adaptive/learner/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from adaptive.learner.learner1D import Learner1D
1111
from adaptive.learner.learner2D import Learner2D
1212
from adaptive.learner.learnerND import LearnerND
13+
from adaptive.learner.sequence_learner import SequenceLearner
1314

1415
__all__ = [
1516
"AverageLearner",
@@ -21,6 +22,7 @@
2122
"Learner1D",
2223
"Learner2D",
2324
"LearnerND",
25+
"SequenceLearner",
2426
]
2527

2628
with suppress(ImportError):

adaptive/learner/sequence_learner.py

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from copy import copy
2+
3+
from sortedcontainers import SortedSet, SortedDict
4+
5+
from adaptive.learner.base_learner import BaseLearner
6+
7+
8+
class _IgnoreFirstArgument:
9+
"""Remove the first argument from the call signature.
10+
11+
The SequenceLearner's function receives a tuple ``(index, point)``
12+
but the original function only takes ``point``.
13+
14+
This is the same as `lambda x: function(x[1])`, however, that is not
15+
pickable.
16+
"""
17+
18+
def __init__(self, function):
19+
self.function = function
20+
21+
def __call__(self, index_point, *args, **kwargs):
22+
index, point = index_point
23+
return self.function(point, *args, **kwargs)
24+
25+
def __getstate__(self):
26+
return self.function
27+
28+
def __setstate__(self, function):
29+
self.__init__(function)
30+
31+
32+
class SequenceLearner(BaseLearner):
    r"""A learner that will learn a sequence. It simply returns
    the points in the provided sequence when asked.

    This is useful when your problem cannot be formulated in terms of
    another adaptive learner, but you still want to use Adaptive's
    routines to run, save, and plot.

    Parameters
    ----------
    function : callable
        The function to learn. Must take a single element of ``sequence``.
    sequence : sequence
        The sequence to learn.

    Attributes
    ----------
    data : dict
        The data as a mapping from "index of element in sequence" => value.

    Notes
    -----
    From primitive tests, the `~adaptive.SequenceLearner` appears to have a
    similar performance to `ipyparallel`\s ``load_balanced_view().map``. With
    the added benefit of having results in the local kernel already.
    """

    def __init__(self, function, sequence):
        self._original_function = function
        # The learner internally works with (index, point) tuples; this
        # wrapper lets the user's single-argument function accept them.
        self.function = _IgnoreFirstArgument(function)
        # Indices of sequence elements that have not been asked for yet.
        self._to_do_indices = SortedSet(range(len(sequence)))
        self._ntotal = len(sequence)
        # Shallow-copy so later mutation of the caller's sequence cannot
        # change which points this learner hands out.
        self.sequence = copy(sequence)
        self.data = SortedDict()
        self.pending_points = set()

    def ask(self, n, tell_pending=True):
        """Return at most ``n`` new ``(index, point)`` tuples.

        Parameters
        ----------
        n : int
            Maximum number of points to return.
        tell_pending : bool
            If True (default), mark the returned points as pending.

        Returns
        -------
        points, loss_improvements : list of tuples, list of floats
        """
        points = []
        loss_improvements = []
        # SortedSet iterates in increasing index order, so points are
        # handed out in sequence order.
        for index in self._to_do_indices:
            if len(points) >= n:
                break
            points.append((index, self.sequence[index]))
            # Each point contributes equally to the total loss.
            loss_improvements.append(1 / self._ntotal)

        if tell_pending:
            for p in points:
                self.tell_pending(p)

        return points, loss_improvements

    def _get_data(self):
        return self.data

    def _set_data(self, data):
        if data:
            indices, values = zip(*data.items())
            # The point values aren't used by 'tell', only the indices,
            # so we can safely pass None.
            points = [(i, None) for i in indices]
            self.tell_many(points, values)

    def loss(self, real=True):
        """Return the fraction of the sequence that is not yet evaluated."""
        if not (self._to_do_indices or self.pending_points):
            return 0.0
        else:
            npoints = self.npoints + (0 if real else len(self.pending_points))
            return (self._ntotal - npoints) / self._ntotal

    def remove_unfinished(self):
        """Move all pending points back into the to-do set."""
        for i in self.pending_points:
            self._to_do_indices.add(i)
        self.pending_points = set()

    def tell(self, point, value):
        """Record ``value`` for ``point``, a ``(index, element)`` tuple."""
        index, point = point
        self.data[index] = value
        self.pending_points.discard(index)
        self._to_do_indices.discard(index)

    def tell_pending(self, point):
        """Mark ``point``, a ``(index, element)`` tuple, as being computed."""
        index, point = point
        self.pending_points.add(index)
        self._to_do_indices.discard(index)

    def done(self):
        """Return True when every element of the sequence has a result."""
        return not self._to_do_indices and not self.pending_points

    def result(self):
        """Get the function values in the same order as ``sequence``.

        Raises
        ------
        RuntimeError
            If not all points of the sequence have been evaluated yet.
        """
        if not self.done():
            # RuntimeError is a subclass of Exception, so existing
            # 'except Exception' handlers still catch this.
            raise RuntimeError("Learner is not yet complete.")
        return list(self.data.values())

    @property
    def npoints(self):
        # Number of evaluated points.
        return len(self.data)

adaptive/tests/test_learners.py

+47-12
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Learner1D,
2525
Learner2D,
2626
LearnerND,
27+
SequenceLearner,
2728
)
2829
from adaptive.runner import simple
2930

@@ -116,26 +117,30 @@ def quadratic(x, m: uniform(0, 10), b: uniform(0, 1)):
116117

117118

118119
@learn_with(Learner1D, bounds=(-1, 1))
120+
@learn_with(SequenceLearner, sequence=np.linspace(-1, 1, 201))
119121
def linear_with_peak(x, d: uniform(-1, 1)):
120122
a = 0.01
121123
return x + a ** 2 / (a ** 2 + (x - d) ** 2)
122124

123125

124126
@learn_with(LearnerND, bounds=((-1, 1), (-1, 1)))
125127
@learn_with(Learner2D, bounds=((-1, 1), (-1, 1)))
128+
@learn_with(SequenceLearner, sequence=np.random.rand(1000, 2))
126129
def ring_of_fire(xy, d: uniform(0.2, 1)):
127130
a = 0.2
128131
x, y = xy
129132
return x + math.exp(-(x ** 2 + y ** 2 - d ** 2) ** 2 / a ** 4)
130133

131134

132135
@learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1)))
136+
@learn_with(SequenceLearner, sequence=np.random.rand(1000, 3))
133137
def sphere_of_fire(xyz, d: uniform(0.2, 1)):
134138
a = 0.2
135139
x, y, z = xyz
136140
return x + math.exp(-(x ** 2 + y ** 2 + z ** 2 - d ** 2) ** 2 / a ** 4) + z ** 2
137141

138142

143+
@learn_with(SequenceLearner, sequence=range(1000))
139144
@learn_with(AverageLearner, rtol=1)
140145
def gaussian(n):
141146
return random.gauss(0, 1)
@@ -247,7 +252,7 @@ def f(x):
247252
simple(learner, goal=lambda l: l.npoints > 10)
248253

249254

250-
@run_with(Learner1D, Learner2D, LearnerND)
255+
@run_with(Learner1D, Learner2D, LearnerND, SequenceLearner)
251256
def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
252257
"""Adding already existing data is an idempotent operation.
253258
@@ -264,7 +269,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
264269
N = random.randint(10, 30)
265270
control.ask(N)
266271
xs, _ = learner.ask(N)
267-
points = [(x, f(x)) for x in xs]
272+
points = [(x, learner.function(x)) for x in xs]
268273

269274
for p in points:
270275
control.tell(*p)
@@ -277,13 +282,24 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
277282
M = random.randint(10, 30)
278283
pls = zip(*learner.ask(M))
279284
cpls = zip(*control.ask(M))
280-
# Point ordering is not defined, so compare as sets
281-
assert set(pls) == set(cpls)
285+
if learner_type is SequenceLearner:
286+
# The SequenceLearner's points might not be hashable
287+
points, values = zip(*pls)
288+
indices, points = zip(*points)
289+
290+
cpoints, cvalues = zip(*cpls)
291+
cindices, cpoints = zip(*cpoints)
292+
assert (np.array(points) == np.array(cpoints)).all()
293+
assert values == cvalues
294+
assert indices == cindices
295+
else:
296+
# Point ordering is not defined, so compare as sets
297+
assert set(pls) == set(cpls)
282298

283299

284300
# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
285301
# but we xfail it now, as Learner2D will be deprecated anyway
286-
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner)
302+
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner, SequenceLearner)
287303
def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
288304
"""Adding data for a point that was not returned by 'ask'."""
289305
# XXX: learner, control and bounds are not defined
@@ -300,17 +316,29 @@ def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
300316
N = random.randint(10, 30)
301317
xs, _ = control.ask(N)
302318

303-
ys = [f(x) for x in xs]
319+
ys = [learner.function(x) for x in xs]
304320
for x, y in zip(xs, ys):
305321
control.tell(x, y)
306322
learner.tell(x, y)
307323

308324
M = random.randint(10, 30)
309325
pls = zip(*learner.ask(M))
310326
cpls = zip(*control.ask(M))
311-
# Point ordering within a single call to 'ask'
312-
# is not guaranteed to be the same by the API.
313-
assert set(pls) == set(cpls)
327+
328+
if learner_type is SequenceLearner:
329+
# The SequenceLearner's points might not be hasable
330+
points, values = zip(*pls)
331+
indices, points = zip(*points)
332+
333+
cpoints, cvalues = zip(*cpls)
334+
cindices, cpoints = zip(*cpoints)
335+
assert (np.array(points) == np.array(cpoints)).all()
336+
assert values == cvalues
337+
assert indices == cindices
338+
else:
339+
# Point ordering within a single call to 'ask'
340+
# is not guaranteed to be the same by the API.
341+
assert set(pls) == set(cpls)
314342

315343

316344
@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND), AverageLearner)
@@ -334,7 +362,7 @@ def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
334362
N = random.randint(10, 30)
335363
control.ask(N)
336364
xs, _ = learner.ask(N)
337-
points = [(x, f(x)) for x in xs]
365+
points = [(x, learner.function(x)) for x in xs]
338366

339367
for p in points:
340368
control.tell(*p)
@@ -366,7 +394,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(
366394
xs, loss_improvements = learner.ask(N)
367395

368396
for x in xs:
369-
learner.tell(x, f(x))
397+
learner.tell(x, learner.function(x))
370398

371399
M = random.randint(50, 100)
372400
_, loss_improvements = learner.ask(M)
@@ -429,7 +457,12 @@ def test_learner_performance_is_invariant_under_scaling(
429457

430458

431459
@run_with(
432-
Learner1D, Learner2D, LearnerND, AverageLearner, with_all_loss_functions=False
460+
Learner1D,
461+
Learner2D,
462+
LearnerND,
463+
AverageLearner,
464+
SequenceLearner,
465+
with_all_loss_functions=False,
433466
)
434467
def test_balancing_learner(learner_type, f, learner_kwargs):
435468
"""Test if the BalancingLearner works with the different types of learners."""
@@ -474,6 +507,7 @@ def test_balancing_learner(learner_type, f, learner_kwargs):
474507
AverageLearner,
475508
maybe_skip(SKOptLearner),
476509
IntegratorLearner,
510+
SequenceLearner,
477511
with_all_loss_functions=False,
478512
)
479513
def test_saving(learner_type, f, learner_kwargs):
@@ -504,6 +538,7 @@ def test_saving(learner_type, f, learner_kwargs):
504538
AverageLearner,
505539
maybe_skip(SKOptLearner),
506540
IntegratorLearner,
541+
SequenceLearner,
507542
with_all_loss_functions=False,
508543
)
509544
def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
adaptive.SequenceLearner
2+
========================
3+
4+
.. autoclass:: adaptive.SequenceLearner
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:

docs/source/reference/adaptive.rst

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Learners
1414
adaptive.learner.learner1D
1515
adaptive.learner.learner2D
1616
adaptive.learner.learnerND
17+
adaptive.learner.sequence_learner
1718
adaptive.learner.skopt_learner
1819

1920
Runners

0 commit comments

Comments
 (0)