@@ -119,20 +119,28 @@ def strategy(self, strategy):
                 ' strategy="npoints", or strategy="cycle" is implemented.'
             )
 
+    def _to_select(self, total_points):
+        to_select = []
+        for index, learner in enumerate(self.learners):
+            # Take the points from the cache
+            if index not in self._ask_cache:
+                self._ask_cache[index] = learner.ask(n=1, tell_pending=False)
+            points, loss_improvements = self._ask_cache[index]
+            if not points:
+                # cannot ask for more points
+                return to_select
+            to_select.append(
+                ((index, points[0]), (loss_improvements[0], -total_points[index]))
+            )
+        return to_select
+
     def _ask_and_tell_based_on_loss_improvements(self, n):
         selected = []  # tuples ((learner_index, point), loss_improvement)
         total_points = [l.npoints + len(l.pending_points) for l in self.learners]
         for _ in range(n):
-            to_select = []
-            for index, learner in enumerate(self.learners):
-                # Take the points from the cache
-                if index not in self._ask_cache:
-                    self._ask_cache[index] = learner.ask(n=1, tell_pending=False)
-                points, loss_improvements = self._ask_cache[index]
-                to_select.append(
-                    ((index, points[0]), (loss_improvements[0], -total_points[index]))
-                )
-
+            to_select = self._to_select(total_points)
+            if not to_select:
+                break
             # Choose the optimal improvement.
             (index, point), (loss_improvement, _) = max(to_select, key=itemgetter(1))
             total_points[index] += 1
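A minimal standalone sketch (not part of the patch, with made-up candidate values) of how the (loss_improvement, -total_points[index]) key used with max(to_select, key=itemgetter(1)) picks a point: equal loss improvements are broken in favour of the learner that currently has fewer points.

from operator import itemgetter

# Hypothetical to_select entries: ((learner_index, point), (loss_improvement, -total_points))
to_select = [
    ((0, 0.5), (0.30, -10)),  # learner 0: improvement 0.30, 10 points so far
    ((1, 1.5), (0.30, -4)),   # learner 1: same improvement, only 4 points
    ((2, 2.5), (0.10, -2)),   # learner 2: lower improvement
]

# Tuples compare lexicographically, so ties on loss_improvement fall back to
# -total_points, i.e. the learner with the fewest points wins the tie.
(index, point), (loss_improvement, _) = max(to_select, key=itemgetter(1))
print(index, point, loss_improvement)  # -> 1 1.5 0.3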