@@ -51,7 +51,7 @@ class BalancingLearner(BaseLearner):
51
51
function : callable
52
52
A function that calls the functions of the underlying learners.
53
53
Its signature is ``function(learner_index, point)``.
54
- strategy : 'loss_improvements' (default), 'loss', or 'npoints'
54
+ strategy : 'loss_improvements' (default), 'loss', 'npoints', or 'cycle'.
55
55
The points that the `BalancingLearner` choses can be either based on:
56
56
the best 'loss_improvements', the smallest total 'loss' of the
57
57
child learners, the number of points per learner, using 'npoints',
@@ -112,6 +112,7 @@ def strategy(self, strategy):
112
112
self ._ask_and_tell = self ._ask_and_tell_based_on_npoints
113
113
elif strategy == "cycle" :
114
114
self ._ask_and_tell = self ._ask_and_tell_based_on_cycle
115
+ self ._cycle = itertools .cycle (range (len (self .learners )))
115
116
else :
116
117
raise ValueError (
117
118
'Only strategy="loss_improvements", strategy="loss",'
@@ -179,9 +180,6 @@ def _ask_and_tell_based_on_npoints(self, n):
179
180
return points , loss_improvements
180
181
181
182
def _ask_and_tell_based_on_cycle (self , n ):
182
- if not hasattr (self , "_cycle" ):
183
- self ._cycle = itertools .cycle (range (len (self .learners )))
184
-
185
183
points , loss_improvements = [], []
186
184
for _ in range (n ):
187
185
index = next (self ._cycle )
0 commit comments