Skip to content

Commit 578ba4b

Browse files
Added UCB-1 Tuned
1 parent 168c293 commit 578ba4b

File tree

1 file changed

+80
-1
lines changed

1 file changed

+80
-1
lines changed

torchrl/modules/mcts/scores.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,88 @@ def update_weights(
203203
node.set(self.weights_key, weights)
204204

205205

206+
class UCB1TunedScore(MCTSScore):
    """UCB1-Tuned selection score (Auer et al., 2002).

    Refines UCB1 by replacing the fixed exploration constant with a
    per-child variance estimate: the exploration width uses
    ``min(1/4, V_i(n))`` where ``V_i`` is the empirical reward variance of
    child ``i`` plus a bias-correction term. Children that have never been
    visited receive a very large score so they are expanded first.

    Args:
        win_count_key: entry holding the per-child sum of rewards.
        visits_key: entry holding the per-child visit count ``n_i``.
        total_visits_key: entry holding the parent visit count ``N``.
        sum_squared_rewards_key: entry holding the per-child sum of
            squared rewards (needed for the variance estimate).
        score_key: entry the computed score is written to.
        exploration_constant: the ``C`` in the bias-correction term
            ``sqrt(C * ln(N) / n_i)``; 2.0 per Auer et al. for rewards
            in [0, 1].
    """

    def __init__(
        self,
        *,
        win_count_key: NestedKey = "win_count",
        visits_key: NestedKey = "visits",
        total_visits_key: NestedKey = "total_visits",
        sum_squared_rewards_key: NestedKey = "sum_squared_rewards",
        score_key: NestedKey = "score",
        exploration_constant: float = 2.0,
    ):
        super().__init__()
        self.win_count_key = win_count_key
        self.visits_key = visits_key
        self.total_visits_key = total_visits_key
        self.sum_squared_rewards_key = sum_squared_rewards_key
        self.score_key = score_key
        self.exploration_constant = exploration_constant

        self.in_keys = [
            self.win_count_key,
            self.visits_key,
            self.total_visits_key,
            self.sum_squared_rewards_key,
        ]
        self.out_keys = [self.score_key]

    def forward(self, node: TensorDictBase) -> TensorDictBase:
        """Compute the UCB1-Tuned score for each child and store it in-place.

        Reads the reward sums, visit counts and squared-reward sums from
        ``node``, writes the resulting scores under ``self.score_key`` and
        returns the same tensordict.
        """
        q_sum_i = node.get(self.win_count_key)
        n_i = node.get(self.visits_key)
        n_parent = node.get(self.total_visits_key)
        sum_sq_rewards_i = node.get(self.sum_squared_rewards_key)

        # Align the parent count with the per-child tensors so the log term
        # broadcasts over the child dimension when needed.
        if n_parent.ndim > 0 and n_parent.ndim < q_sum_i.ndim:
            n_parent_expanded = n_parent.unsqueeze(-1)
        else:
            n_parent_expanded = n_parent

        # Guard against log(0) before the parent has been visited.
        safe_n_parent_for_log = torch.clamp(n_parent_expanded, min=1.0)
        log_n_parent = torch.log(safe_n_parent_for_log)

        scores = torch.zeros_like(q_sum_i)

        visited_mask = n_i > 0

        if torch.any(visited_mask):
            q_sum_i_v = q_sum_i[visited_mask]
            n_i_v = n_i[visited_mask]
            sum_sq_rewards_i_v = sum_sq_rewards_i[visited_mask]

            log_n_parent_v = log_n_parent.expand_as(n_i)[visited_mask]

            avg_reward_i_v = q_sum_i_v / n_i_v

            # V_i(n) = (sum of squared rewards)/n_i - mean^2 + sqrt(C*ln(N)/n_i)
            empirical_variance_v = (sum_sq_rewards_i_v / n_i_v) - avg_reward_i_v.pow(2)
            bias_correction_v = (
                self.exploration_constant * log_n_parent_v / n_i_v
            ).sqrt()

            # BUG FIX: the original called `v_i_v.clamp(min=0)` without using
            # the result (clamp is out-of-place), so a numerically negative
            # variance estimate could survive and make the sqrt below NaN.
            v_i_v = (empirical_variance_v + bias_correction_v).clamp(min=0)

            # Exploration width uses min(1/4, V_i(n)); 1/4 is the maximum
            # variance of a [0, 1]-bounded reward.
            min_variance_term_v = torch.min(torch.full_like(v_i_v, 0.25), v_i_v)
            exploration_component_v = (
                log_n_parent_v / n_i_v * min_variance_term_v
            ).sqrt()

            scores[visited_mask] = avg_reward_i_v + exploration_component_v

        unvisited_mask = ~visited_mask
        if torch.any(unvisited_mask):
            # Large finite value (not inf) so unvisited children are always
            # preferred without poisoning downstream arithmetic.
            scores[unvisited_mask] = torch.finfo(scores.dtype).max / 10.0

        node.set(self.score_key, scores)
        return node
281+
282+
206283
class MCTSScores(Enum):
207284
PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value
208285
UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002
209-
UCB1_TUNED = "UCB1-Tuned"
286+
UCB1_TUNED = functools.partial(
287+
UCB1TunedScore, exploration_constant=2.0
288+
) # Auer et al. (2002) C=2 for rewards in [0,1]
210289
EXP3 = functools.partial(EXP3Score, gamma=0.1)
211290
PUCT_VARIANT = "PUCT-Variant"

0 commit comments

Comments
 (0)