@@ -203,9 +203,88 @@ def update_weights(
        node.set(self.weights_key, weights)


+class UCB1TunedScore(MCTSScore):
+    """UCB1-Tuned score (Auer et al., 2002).
+
+    score_i = mean_i + sqrt((ln N / n_i) * min(1/4, V_i)), where
+    V_i = sum(r_i^2) / n_i - mean_i^2 + sqrt(C * ln N / n_i) and C is
+    ``exploration_constant`` (2.0 for rewards bounded in [0, 1]).
+    """
+
+    def __init__(
+        self,
+        *,
+        win_count_key: NestedKey = "win_count",
+        visits_key: NestedKey = "visits",
+        total_visits_key: NestedKey = "total_visits",
+        sum_squared_rewards_key: NestedKey = "sum_squared_rewards",
+        score_key: NestedKey = "score",
+        exploration_constant: float = 2.0,
+    ):
+        super().__init__()
+        self.win_count_key = win_count_key
+        self.visits_key = visits_key
+        self.total_visits_key = total_visits_key
+        self.sum_squared_rewards_key = sum_squared_rewards_key
+        self.score_key = score_key
+        self.exploration_constant = exploration_constant
+
+        self.in_keys = [
+            self.win_count_key,
+            self.visits_key,
+            self.total_visits_key,
+            self.sum_squared_rewards_key,
+        ]
+        self.out_keys = [self.score_key]
+
+    def forward(self, node: TensorDictBase) -> TensorDictBase:
+        q_sum_i = node.get(self.win_count_key)
+        n_i = node.get(self.visits_key)
+        n_parent = node.get(self.total_visits_key)
+        sum_sq_rewards_i = node.get(self.sum_squared_rewards_key)
+
+        # Broadcast the parent visit count against the per-child tensors.
+        if n_parent.ndim > 0 and n_parent.ndim < q_sum_i.ndim:
+            n_parent_expanded = n_parent.unsqueeze(-1)
+        else:
+            n_parent_expanded = n_parent
+
+        safe_n_parent_for_log = torch.clamp(n_parent_expanded, min=1.0)
+        log_n_parent = torch.log(safe_n_parent_for_log)
+
+        scores = torch.zeros_like(q_sum_i, device=q_sum_i.device)
+
+        visited_mask = n_i > 0
+
+        if torch.any(visited_mask):
+            q_sum_i_v = q_sum_i[visited_mask]
+            n_i_v = n_i[visited_mask]
+            sum_sq_rewards_i_v = sum_sq_rewards_i[visited_mask]
+
+            log_n_parent_v = log_n_parent.expand_as(n_i)[visited_mask]
+
+            avg_reward_i_v = q_sum_i_v / n_i_v
+
+            # Variance estimate V_i plus its upper-confidence correction term.
+            empirical_variance_v = (sum_sq_rewards_i_v / n_i_v) - avg_reward_i_v.pow(2)
+            bias_correction_v = (
+                self.exploration_constant * log_n_parent_v / n_i_v
+            ).sqrt()
+
+            v_i_v = empirical_variance_v + bias_correction_v
+            # Guard against a slightly negative variance estimate from floating-point error.
+            v_i_v = v_i_v.clamp(min=0)
+
+            # UCB1-Tuned caps the variance term at 1/4, the maximum variance of a [0, 1] reward.
+            min_variance_term_v = torch.min(torch.full_like(v_i_v, 0.25), v_i_v)
+            exploration_component_v = (
+                log_n_parent_v / n_i_v * min_variance_term_v
+            ).sqrt()
+
+            scores[visited_mask] = avg_reward_i_v + exploration_component_v
+
+        # Unvisited children get a very large finite score so they are selected first.
+        unvisited_mask = ~visited_mask
+        if torch.any(unvisited_mask):
+            scores[unvisited_mask] = torch.finfo(scores.dtype).max / 10.0
+
+        node.set(self.score_key, scores)
+        return node
+
+
class MCTSScores(Enum):
    PUCT = functools.partial(PUCTScore, c=5)  # AlphaGo default value
    UCB = functools.partial(UCBScore, c=math.sqrt(2))  # default from Auer et al. 2002
-    UCB1_TUNED = "UCB1-Tuned"
+    UCB1_TUNED = functools.partial(
+        UCB1TunedScore, exploration_constant=2.0
+    )  # Auer et al. (2002): C=2 for rewards in [0, 1]
    EXP3 = functools.partial(EXP3Score, gamma=0.1)
    PUCT_VARIANT = "PUCT-Variant"
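For reference, a minimal sketch (not part of the diff) of how the new score could be exercised on a toy node, assuming UCB1TunedScore is importable from this module and using its default key names; the statistics below are illustrative only:

import torch
from tensordict import TensorDict

# Toy statistics for three children; child 1 has never been visited.
node = TensorDict(
    {
        "win_count": torch.tensor([3.0, 0.0, 5.0]),            # summed rewards per child
        "visits": torch.tensor([4.0, 0.0, 10.0]),              # n_i per child
        "total_visits": torch.tensor(14.0),                    # parent visit count N
        "sum_squared_rewards": torch.tensor([2.5, 0.0, 3.0]),  # sum of squared rewards per child
    },
    batch_size=[],
)

score = UCB1TunedScore(exploration_constant=2.0)
node = score.forward(node)
# The unvisited child receives a very large score and is selected first; visited
# children are ranked by mean reward plus the variance-aware exploration bonus.
print(node.get("score"))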