|
3 | 3 | import itertools
|
4 | 4 | import sys
|
5 | 5 | from collections import defaultdict
|
6 |
| -from collections.abc import Iterable, Sequence |
| 6 | +from collections.abc import Generator, Iterable, Sequence |
7 | 7 | from contextlib import suppress
|
8 | 8 | from functools import partial
|
9 | 9 | from operator import itemgetter
|
@@ -126,11 +126,10 @@ def __init__(
|
126 | 126 | self._cdims_default = cdims
|
127 | 127 |
|
128 | 128 | if len({learner.__class__ for learner in self.learners}) > 1:
|
129 |
| - raise TypeError( |
130 |
| - "A BalacingLearner can handle only one type" " of learners." |
131 |
| - ) |
| 129 | + raise TypeError("A BalacingLearner can handle only one type of learners.") |
132 | 130 |
|
133 | 131 | self.strategy: STRATEGY_TYPE = strategy
|
| 132 | + self._gen: Generator | None = None |
134 | 133 |
|
135 | 134 | def new(self) -> BalancingLearner:
|
136 | 135 | """Create a new `BalancingLearner` with the same parameters."""
|
@@ -288,27 +287,19 @@ def _ask_and_tell_based_on_cycle(
|
288 | 287 | def _ask_and_tell_based_on_sequential(
|
289 | 288 | self, n: int
|
290 | 289 | ) -> tuple[list[tuple[Int, Any]], list[float]]:
|
| 290 | + if self._gen is None: |
| 291 | + self._gen = _sequential_generator(self.learners) |
| 292 | + |
291 | 293 | points: list[tuple[Int, Any]] = []
|
292 | 294 | loss_improvements: list[float] = []
|
293 | 295 | learner_index = 0
|
294 | 296 |
|
295 |
| - while len(points) < n: |
296 |
| - learner = self.learners[learner_index] |
297 |
| - if learner.done(): # type: ignore[attr-defined] |
298 |
| - if learner_index == len(self.learners) - 1: |
299 |
| - break |
300 |
| - learner_index += 1 |
301 |
| - continue |
302 |
| - |
303 |
| - point, loss_improvement = learner.ask(n=1) |
304 |
| - if not point: # if learner is exhausted, we don't get points |
305 |
| - if learner_index == len(self.learners) - 1: |
306 |
| - break |
307 |
| - learner_index += 1 |
308 |
| - continue |
309 |
| - points.append((learner_index, point[0])) |
310 |
| - loss_improvements.append(loss_improvement[0]) |
311 |
| - self.tell_pending((learner_index, point[0])) |
| 297 | + for learner_index, point, loss_improvement in self._gen: |
| 298 | + points.append((learner_index, point)) |
| 299 | + loss_improvements.append(loss_improvement) |
| 300 | + self.tell_pending((learner_index, point)) |
| 301 | + if len(points) >= n: |
| 302 | + break |
312 | 303 |
|
313 | 304 | return points, loss_improvements
|
314 | 305 |
|
@@ -629,3 +620,27 @@ def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
|
629 | 620 | def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
|
630 | 621 | learners, cdims, strategy = state
|
631 | 622 | self.__init__(learners, cdims=cdims, strategy=strategy) # type: ignore[misc]
|
| 623 | + |
| 624 | + |
| 625 | +def _sequential_generator( |
| 626 | + learners: list[BaseLearner], |
| 627 | +) -> Generator[tuple[int, Any, float], None, None]: |
| 628 | + learner_index = 0 |
| 629 | + if not hasattr(learners[0], "done"): |
| 630 | + msg = "All learners must have a `done` method to use the 'sequential' strategy." |
| 631 | + raise ValueError(msg) |
| 632 | + while True: |
| 633 | + learner = learners[learner_index] |
| 634 | + if learner.done(): # type: ignore[attr-defined] |
| 635 | + if learner_index == len(learners) - 1: |
| 636 | + return |
| 637 | + learner_index += 1 |
| 638 | + continue |
| 639 | + |
| 640 | + point, loss_improvement = learner.ask(n=1) |
| 641 | + if not point: # if learner is exhausted, we don't get points |
| 642 | + if learner_index == len(learners) - 1: |
| 643 | + return |
| 644 | + learner_index += 1 |
| 645 | + continue |
| 646 | + yield learner_index, point[0], loss_improvement[0] |
0 commit comments