Commit 55c9362

[DRAFT, Example] Add MCTS example
ghstack-source-id: 5dc5cbd
Pull Request resolved: #2796
1 parent 10e2f69 commit 55c9362

File tree

examples/trees/mcts.py
torchrl/data/map/tree.py

2 files changed: +227 -0 lines changed

examples/trees/mcts.py

+220
@@ -0,0 +1,220 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict

pgn_or_fen = "fen"

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=True,
)


def transform_reward(td):
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        td["reward"] = 0
    elif reward == 1 and td["turn"]:
        td["reward"] = -td["reward"]
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
# white win = 1
# draw = 0
# black win = -1
env.append_transform(transform_reward)
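
# Rough illustration of the resulting mapping (assuming, as in python-chess,
# that "turn" is True when it is white's move, so after a mating move it points
# at the losing side):
#   reward 0.5 (draw)                                      -> 0
#   reward 1, turn == False (black to move, white just won) -> +1
#   reward 1, turn == True  (white to move, black just won) -> -1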

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys + ["_visits", "_reward_sum"]
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys
forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]

C = 2.0**0.5
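
# "_visits" and "_reward_sum" are registered as extra reward keys so the forest
# keeps per-node visit counts and cumulative returns for UCB1; C = sqrt(2) is
# the classical UCB1 exploration constant (larger values explore more).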


def traversal_priority_UCB1(tree):
    if tree.rollout[-1]["next", "_visits"] == 0:
        res = float("inf")
    else:
        if tree.parent.rollout is None:
            parent_visits = 0
            for child in tree.parent.subtree:
                parent_visits += child.rollout[-1]["next", "_visits"]
        else:
            parent_visits = tree.parent.rollout[-1]["next", "_visits"]
        assert parent_visits > 0

        value_avg = (
            tree.rollout[-1]["next", "_reward_sum"]
            / tree.rollout[-1]["next", "_visits"]
        )

        # If it's black's turn, flip the reward, since black wants to optimize
        # for the lowest reward.
        if not tree.rollout[0]["turn"]:
            value_avg = -value_avg

        res = (
            value_avg
            + C
            * torch.sqrt(torch.log(parent_visits) / tree.rollout[-1]["next", "_visits"])
        ).item()

    return res
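
# The priority above is plain UCB1: the node's mean value plus an exploration
# bonus, value_avg + C * sqrt(ln(parent_visits) / child_visits). A worked
# example with hypothetical counts: a child visited 3 times with reward sum 2
# under a parent visited 10 times scores 2/3 + sqrt(2) * sqrt(ln(10) / 3)
# ≈ 0.67 + 1.24 ≈ 1.91; unvisited children are always explored first via the
# inf branch.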


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
    """Runs one MCTS iteration on ``tree``: descend by UCB1 priority to a leaf,
    expand the leaf into one child per legal action when it has already been
    visited (or is the root) and is not terminal, simulate a rollout of at most
    ``max_rollout_steps`` from the chosen state, then backpropagate the rollout
    reward into the ``_visits``/``_reward_sum`` stats of the nodes selected on
    the way down."""
    done = False
    trees_visited = []

    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"]

            if (td_tree["_visits"] > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree.clone()).update(action)).update(
                        TensorDict(
                            {
                                ("next", "_visits"): 0,
                                ("next", "_reward_sum"): env.reward_spec.zeros(),
                            }
                        )
                    )

                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                    )
                    subtrees.append(new_node)

                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

            if rollout_state["done"]:
                rollout_reward = rollout_state["reward"]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
            priorities = torch.tensor(
                [traversal_priority_UCB1(subtree) for subtree in tree.subtree]
            )
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            trees_visited.append(tree)

    for tree in trees_visited:
        td = tree.rollout[-1]["next"]
        td["_visits"] += 1
        td["_reward_sum"] += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
    if root not in forest:
        for action in env.all_actions(root.clone()):
            td = env.step(env.reset(root.clone()).update(action)).update(
                TensorDict(
                    {
                        ("next", "_visits"): 0,
                        ("next", "_reward_sum"): env.reward_spec.zeros(),
                    }
                )
            )
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)

    return tree
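
# Minimal usage sketch (``some_fen`` is a hypothetical FEN string for the
# position to analyse; ``env`` and ``forest`` are the ones configured above):
#
#   root = env.reset(TensorDict({"fen": some_fen}))
#   tree = traverse_MCTS(forest, root, env, num_steps=100, max_rollout_steps=10)
#   print(tree.to_string(tree_format_fn))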


def tree_format_fn(tree):
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        td["_reward_sum"].item(),
        td["_visits"].item(),
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)

    # print('------------------------------')
    # print(tree.to_string(tree_format_fn))
    # print('------------------------------')

    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.rollout[-1]["next", "_reward_sum"]
        visits = subtree.rollout[-1]["next", "_visits"]
        value_avg = (reward_sum / visits).item()
        if not subtree.rollout[0]["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


# White has M1, best move Rd8#. Any other move loses to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"

torchrl/data/map/tree.py

+7
@@ -330,6 +330,8 @@ def maybe_flatten_list(maybe_nested_list):
             return TensorDict.lazy_stack(
                 [self._from_tensordict(r) for r in parent_result]
             )
+        if parent_result is None:
+            return None
         return self._from_tensordict(parent_result)
 
     @property
@@ -1227,6 +1229,11 @@ def valid_paths(cls, tree: Tree):
     def __len__(self):
         return len(self.data_map)
 
+    def __contains__(self, root: TensorDictBase):
+        if self.node_map is None:
+            return False
+        return root.select(*self.node_map.in_keys) in self.node_map
+
     def to_string(self, td_root, node_format_fn):
         """Generates a string representation of a tree in the forest.
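The added MCTSForest.__contains__ is what backs the `if root not in forest:` check in the example: the root's observation keys (the node_map input keys) are selected and looked up in the node map, and an empty forest (node_map still None) reports nothing as contained. A rough sketch of the behavior, assuming the env and the configured forest from the example above, before any search has been run on this position:

    root = env.reset(TensorDict({"fen": fen0}))
    root in forest   # False: no node for this position yet
    traverse_MCTS(forest, root, env, num_steps=1, max_rollout_steps=1)
    root in forest   # True: the root step has been extended into the forest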
