Skip to content

Commit d25f3d1

Browse files
committed
set self._cycle in the strategy setter
1 parent beadd5f commit d25f3d1

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class BalancingLearner(BaseLearner):
5151
function : callable
5252
A function that calls the functions of the underlying learners.
5353
Its signature is ``function(learner_index, point)``.
54-
strategy : 'loss_improvements' (default), 'loss', or 'npoints'
54+
strategy : 'loss_improvements' (default), 'loss', 'npoints', or 'cycle'.
5555
The points that the `BalancingLearner` choses can be either based on:
5656
the best 'loss_improvements', the smallest total 'loss' of the
5757
child learners, the number of points per learner, using 'npoints',
@@ -112,6 +112,7 @@ def strategy(self, strategy):
112112
self._ask_and_tell = self._ask_and_tell_based_on_npoints
113113
elif strategy == "cycle":
114114
self._ask_and_tell = self._ask_and_tell_based_on_cycle
115+
self._cycle = itertools.cycle(range(len(self.learners)))
115116
else:
116117
raise ValueError(
117118
'Only strategy="loss_improvements", strategy="loss",'
@@ -179,9 +180,6 @@ def _ask_and_tell_based_on_npoints(self, n):
179180
return points, loss_improvements
180181

181182
def _ask_and_tell_based_on_cycle(self, n):
182-
if not hasattr(self, "_cycle"):
183-
self._cycle = itertools.cycle(range(len(self.learners)))
184-
185183
points, loss_improvements = [], []
186184
for _ in range(n):
187185
index = next(self._cycle)

0 commit comments

Comments
 (0)