Skip to content

Commit 4209e85

Browse files
committed
Implement using generator
1 parent 3523343 commit 4209e85

File tree

1 file changed

+36
-21
lines changed

1 file changed

+36
-21
lines changed

adaptive/learner/balancing_learner.py

+36-21
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import itertools
44
import sys
55
from collections import defaultdict
6-
from collections.abc import Iterable, Sequence
6+
from collections.abc import Generator, Iterable, Sequence
77
from contextlib import suppress
88
from functools import partial
99
from operator import itemgetter
@@ -126,11 +126,10 @@ def __init__(
126126
self._cdims_default = cdims
127127

128128
if len({learner.__class__ for learner in self.learners}) > 1:
129-
raise TypeError(
130-
"A BalacingLearner can handle only one type" " of learners."
131-
)
129+
raise TypeError("A BalacingLearner can handle only one type of learners.")
132130

133131
self.strategy: STRATEGY_TYPE = strategy
132+
self._gen: Generator | None = None
134133

135134
def new(self) -> BalancingLearner:
136135
"""Create a new `BalancingLearner` with the same parameters."""
@@ -288,27 +287,19 @@ def _ask_and_tell_based_on_cycle(
288287
def _ask_and_tell_based_on_sequential(
289288
self, n: int
290289
) -> tuple[list[tuple[Int, Any]], list[float]]:
290+
if self._gen is None:
291+
self._gen = _sequential_generator(self.learners)
292+
291293
points: list[tuple[Int, Any]] = []
292294
loss_improvements: list[float] = []
293295
learner_index = 0
294296

295-
while len(points) < n:
296-
learner = self.learners[learner_index]
297-
if learner.done(): # type: ignore[attr-defined]
298-
if learner_index == len(self.learners) - 1:
299-
break
300-
learner_index += 1
301-
continue
302-
303-
point, loss_improvement = learner.ask(n=1)
304-
if not point: # if learner is exhausted, we don't get points
305-
if learner_index == len(self.learners) - 1:
306-
break
307-
learner_index += 1
308-
continue
309-
points.append((learner_index, point[0]))
310-
loss_improvements.append(loss_improvement[0])
311-
self.tell_pending((learner_index, point[0]))
297+
for learner_index, point, loss_improvement in self._gen:
298+
points.append((learner_index, point))
299+
loss_improvements.append(loss_improvement)
300+
self.tell_pending((learner_index, point))
301+
if len(points) >= n:
302+
break
312303

313304
return points, loss_improvements
314305

@@ -629,3 +620,27 @@ def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
629620
def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
630621
learners, cdims, strategy = state
631622
self.__init__(learners, cdims=cdims, strategy=strategy) # type: ignore[misc]
623+
624+
625+
def _sequential_generator(
626+
learners: list[BaseLearner],
627+
) -> Generator[tuple[int, Any, float], None, None]:
628+
learner_index = 0
629+
if not hasattr(learners[0], "done"):
630+
msg = "All learners must have a `done` method to use the 'sequential' strategy."
631+
raise ValueError(msg)
632+
while True:
633+
learner = learners[learner_index]
634+
if learner.done(): # type: ignore[attr-defined]
635+
if learner_index == len(learners) - 1:
636+
return
637+
learner_index += 1
638+
continue
639+
640+
point, loss_improvement = learner.ask(n=1)
641+
if not point: # if learner is exhausted, we don't get points
642+
if learner_index == len(learners) - 1:
643+
return
644+
learner_index += 1
645+
continue
646+
yield learner_index, point[0], loss_improvement[0]

0 commit comments

Comments
 (0)